点击上方“图灵人工智能”,选择“星标”公众号
您想知道的人工智能干货,第一时间送达
版权声明
转自CreateAMind,仅用于学术分享,如有侵权留言删除Data reconstruction from machine learning models via inverse estimation and Bayesian inference
通过逆向估计与贝叶斯推断从机器学习模型中重构数据https://www.nature.com/articles/s41598-025-96215-z#citeas
本研究探讨了通过逆估计(inverse estimation)和贝叶斯推断从机器学习模型中进行数据重建的任务,旨在仅基于已训练的模型恢复原始数据集。我们提出了一种新颖的理论框架,用于研究影响数据重建质量的各种因素。具体而言,我们推导出若干表达式,通过考察关键变量关于独立变量的偏导数的并发行为,量化这些变量的变化如何影响真实后验分布与估计后验分布之间的差异。这种基于导数的方法建立了变量之间的理论关联,表明重建数据的保真度主要由两个因素决定:(1)所假设先验的准确性,以及(2)机器学习模型本身的准确性。在多个基准数据集和机器学习算法上的实验结果验证了这些理论预测,进一步证实了我们理论框架的有效性与鲁棒性。在实际应用中,我们的数据重建方法能够生成合成模型,这些模型可高度复现原始模型的性能。本研究推动了在机器学习背景下对数据重建和模型内省(model introspection)的理论理解与实践技术的发展。
从机器学习模型中恢复原始数据集的能力,在机器学习研究与应用的多个领域具有深远意义。在诸如联邦学习(federated learning)和迁移学习(transfer learning)等新兴范式中,由于原始数据受限或分布于去中心化的设备上,仅通过模型重建训练数据的能力显得尤为宝贵。数据重建有助于严格的模型验证和误差分析,使研究人员能够识别出仅通过传统性能指标难以察觉的偏差、异常或模型失效问题。此外,数据重建涉及推断与真实数据相似的样本,这为日益发展的合成数据生成领域做出了贡献。
从机器学习模型中重建数据的问题可被视为一个“逆问题”(inverse problem),这一概念在系统生物学等领域已被广泛应用。在系统生物学中,研究人员常面临类似挑战,即从可观测现象中推断生物模型中不可观测的微观或介观参数。例如,Hartoyo 等人(2019年)和 Hartoyo 等人(2020年)应用贝叶斯推断研究脑电图(EEG)中两个重要但尚未被充分理解的特征:α振荡和α阻断现象。通过对神经群体模型进行逆估计,他们揭示了每种现象背后的生理驱动机制,尽管不同个体的脑电特征存在显著差异,但仍发现了跨个体的共性规律。我们的工作将这一范式扩展到机器学习模型中。在此背景下,与传统的“正向问题”(即训练样本决定模型对目标类别的预测)不同,逆问题是从一个目标类别出发,重构出能使模型产生该目标类别预测的训练样本。
问题表述
本研究聚焦于求解机器学习模型中的逆问题,具体目标是在给定一个已训练模型的前提下,恢复其原始训练数据。我们的研究局限于对采用z-score标准化方法预处理过的数据集所训练的机器学习模型进行分析。z-score标准化是一种标准的数据预处理方法,尤其适用于数据集中各特征具有不同量纲或尺度的情况。通过z-score标准化,各特征被转换到统一的尺度上,这有助于提升机器学习模型的性能和稳定性,特别是对于依赖距离计算或基于梯度优化的模型而言尤为重要。
我们采用贝叶斯推断框架来应对从机器学习模型中进行数据重建的挑战。如定义7所述,贝叶斯推断为估计后验分布 q(Θ | c) 提供了一种原则性方法,其中 Θ 代表原始数据集的参数(或特征),c 表示观测到的目标类别。通过贝叶斯法则将先验分布 q(Θ) 与似然函数 p(c|Θ) 相结合,我们得到后验分布 q(Θ | c),该分布表示在观测到目标类别 c 后,我们对参数的信念更新。
研究概述与亮点
本研究探讨了一个新颖且尚未被深入研究的问题——从机器学习模型中进行数据重建。据我们所知,这一问题此前尚未被系统性地研究过。这一有趣而复杂的问题为机器学习中的逆问题研究开辟了新的方向。
本文提出了一种新颖的理论框架,用于模拟和逼近逆估计过程(详见“理论框架”一节)。该框架采用基于导数的方法,建立了关键变量之间的理论关联,特别是:重建数据的保真度、所假设先验的准确性,以及底层机器学习模型的准确性。通过分析这些关系的偏导数,该方法从理论上表明,先验假设和模型预测中的不准确性会共同放大数据重建中的偏差。通过对这些关系进行严谨的数学推导,我们为机器学习中逆问题的研究开辟了新的研究路径,提供了可推广至更广泛场景的宝贵理论预期。
尽管我们的理论模型和推导是理想化的,可能无法完全捕捉实际应用中的所有细节,但它们指导了“实验方法”一节中描述的实证方法,并在多个基准数据集和机器学习算法上验证了理论预测。如“实验结果”一节所示,实验结果与理论预测高度一致,进一步增强了我们理论框架的鲁棒性。
在实际应用层面,该方法还能够基于推断出的合成数据训练生成合成模型。实验结果表明,这些合成模型在相同数据集上的测试性能与原始模型高度接近。这种理论与实践并重的研究方式,不仅加深了我们对逆估计过程的理论理解,也为数据重建和模型内省提供了实用的技术手段。
贝叶斯推断设置
在我们的研究设定中,真实先验和真实后验是可以通过经验获取的,因为它们天然存在于数据集中。某一特定参数 θi 的边缘真实先验 p(θi) 通过在整个数据集中计算该参数的经验分布得到,反映了该参数在不依赖于特定输出情况下的整体变异性。相比之下,边缘真实后验 p(θi | c) 是在给定特定类别 c 的条件下推导得出的。真实后验的推导方法是:先提取数据集中属于类别 c 的子集,然后计算该子集中参数 θi 的分布。图1展示了心脏病数据集(Heart Disease dataset)中部分参数的边缘化真实先验和后验分布。这使我们能够获得“真实情况”(ground truth),从而可以将我们所假设的先验和估计的后验与数据集中观察到的实际分布进行比较,进而深入理解我们贝叶斯推断方法的性能与鲁棒性。
贝叶斯推断的核心在于使用先验分布 q(Θ),它编码了我们在观测任何数据之前对参数的信念。我们对每个参数 θi 采用边缘均匀先验分布 q(θi) = U(θi | −3, 3),如定义1所述,并在图2中以心脏病数据集的部分参数为例,将其与真实分布进行了对比。选择均匀先验的做法借鉴了先前研究工作(Hartoyo 等人,2019, 2020)在贝叶斯推断框架中的方法,该方法已被证明在从复杂生理数据中提取有意义信息方面非常有效。这一选择也与早期贝叶斯学派先驱(如贝叶斯本人和拉普拉斯)所推荐的历史性做法一致,尤其是在对参数分布缺乏先验知识的情况下。区间 [−3, 3] 的选择基于数据的z-score标准化:在均值为0、标准差为1的正态分布假设下,±3个标准差范围涵盖了约99.7%的数据。这体现了“三西格玛法则”(Three-sigma rule),即几乎所有数据值都落在均值的三个标准差范围内,超出该范围的事件极不可能发生。因此,我们将该区间称为参数的合理范围(plausible range)。值得注意的是,实际的先验分布形状并不需要符合某种特定的参数形式,这一点在图1和图2中已明显体现——这些图中所示的边缘先验分布形状差异显著,包括单峰、双峰甚至四峰分布。
与机器学习模型相关的似然函数 p(c | Θ) 描述了观测到的目标类别 c 如何依赖于参数 Θ。我们使用 scikit-learn 环境中分类器类的 predict_proba
方法来表示与模型相关的经验似然函数,该方法计算在给定输入参数 Θ 的条件下,目标类别 c 上的概率分布。
实验结果
在本节中,我们展示了一系列实证研究结果,用以验证前文理论框架部分所提出的理论预期。这些结果来源于在多个基准数据集和多种机器学习算法上进行的全面实验,实验设计详见“实验方法”一节。
图3展示了针对心脏病数据集上训练的深度神经网络模型,通过逆估计分析得到的部分参数的边缘估计后验分布与边缘真实后验分布的对比。该分析突出了估计后验分布与真实后验分布在多大程度上保持一致或出现偏差。图中附带的KL散度(KLD)值为这种偏差提供了数值度量。
图4展示了在心脏病数据集上训练的深度神经网络模型进行逆估计分析时的几个关键相关性。各子图分别显示了:(a) 边缘真实先验与假设先验之间的KL散度,与边缘真实后验和估计后验之间平均KL散度的相关性;(b) 基础模型的预测误差与边缘真实后验和估计后验之间平均KL散度的相关性;(c) 基础模型与合成模型的训练准确率之间的相关性;以及 (d) 两者在测试集上的准确率之间的相关性。有趣的是,在本案例中,合成模型在相同测试集上的表现通常优于基础模型,这一点从子图(d)中数据点在对角线周围的分布情况可以看出。通常而言,相关系数超过0.5被视为强相关,超过0.7则被视为非常强的相关。根据这一通用标准,我们可以观察到,在针对心脏病数据集训练的深度神经网络模型的逆估计分析中,所有相关性均表现出非常强的关联。其他数据集和机器学习算法对应的类似图表见附录A。
在图4所示相关性分析的基础上,表1进一步扩展了这一分析,展示了在不同数据集和机器学习算法下关键指标之间的相关性。这些相关性始终表现出强或非常强的关联。所有实验结果均可完全复现,完整代码和数据集的链接见“数据可用性”一节。
基于表1的分析结果,图5以可视化方式总结了我们分析中识别出的四个关键关系的皮尔逊相关系数(ρ)分布情况。箱线图显示,所有相关系数的中位数均高于0.7,再次证实了这些关系的强度。尽管存在少数离群值,但它们极为罕见,并未影响整体结果的一致性,反而突显了所考察的不同评估指标之间高度一致的关联模式。
讨论
我们的实验结果验证了本研究的理论基础。具体而言,实验中观察到的相关性支持了以下理论预期:
真实先验与假设先验之间的KL散度(KLD)与真实后验和估计后验之间平均KLD的相关性:实验中观察到这两组KLD值之间存在强相关性,这与推论1的理论预期一致。该推论强调了一种预期关系:先验信念的偏差(通过先验的KLD衡量)会反映在后验分布的差异上(通过后验的KLD衡量)。
模型预测误差与真实后验和估计后验之间平均KLD的相关性:我们的实验结果揭示了这两者之间存在强相关性。这一发现与推论2的理论预测相符,即后验估计的保真度与机器学习模型的准确性密切相关。
基于以上两点,我们可以推断:如果先验假设和模型预测中的误差均被最小化至零(或接近零),那么真实后验与估计后验之间的KL散度也将趋近于零。这一理论推断支持了定义9——即当所假设的先验与真实先验完全一致,且似然函数也准确时,估计的后验将与真实后验完全重合。
为了实证验证这一理论预期,我们使用MCMC采样方法进行了一项实验:在心脏病数据集上,将真实先验分布作为假设先验进行后验估计。结果如图6所示,与理论预测一致。与图3中使用均匀先验估计后验的结果相比,当前结果明显改善:估计的后验在视觉上与真实后验非常接近,甚至在模态(如单峰、双峰等)上也高度一致。这种视觉上的一致性进一步由后验侧极低的KL散度值所证实,表明估计后验与真实后验之间的差异极小。这些结果在“真实情况”(ground truth)条件下强化了我们理论框架的鲁棒性:准确的先验假设和完美的模型性能能够实现精确的后验估计。
理论预期与实验结果之间的一致性,对机器学习模型中的逆估计具有重要意义。所观察到的相关性为这些系统的设计与评估提供了可操作的洞见,特别是在需要精确后验估计的场景中尤为重要。
- 真实与假设先验的KLD与真实和估计后验的KLD之间的强相关性,凸显了准确先验假设的重要性。这一发现强调了持续验证和优化先验的必要性,因为提高先验的准确性可直接带来更可靠的数据重建。
尽管本研究出于简化和数学可处理性的考虑采用了均匀先验,但这并不意味着我们主张均匀先验在所有场景下都是最优选择。相反,使用均匀先验是为了建立一个基准,用以展示本框架的能力,并为改进先验选择提供洞见。在实际应用中,先验的选择应与数据特征相匹配。对于结构化数据(如网络或时间序列),可设计包含依赖关系的先验,例如基于图的先验或自回归先验;对于非结构化数据,基于领域知识的先验或从数据中经验推导出的先验通常能更好地匹配真实潜在分布。为了确保先验在数据重建中的准确性和有效性,其优化应充分考虑这些特性。
当先验假设与真实分布存在显著偏差时,可采用先进技术来减轻不准确性的影响。例如,分层先验(hierarchical priors)引入可随数据调整的超参数;经验贝叶斯方法(empirical Bayes)则直接从观测数据中估计先验。这些方法提供了灵活性和鲁棒性,减少了对预设假设的依赖。类似地,当先验知识有限时,非信息性或弱信息性先验(non-informative or weakly informative priors)非常有价值,因为它们能最小化对后验的不当影响。先验正则化(prior regularization)和神经先验(neural priors)等技术进一步扩展了建模复杂分布的工具集。这些方法突显了选择既能保持灵活性又能保证鲁棒性的先验的重要性,为在多样化场景中实现更可靠的数据重建铺平了道路。
- 模型预测误差与真实和估计后验之间KLD的强相关性表明,模型准确性对可靠的数据重建至关重要。提高模型性能不仅有助于提升正向问题中的预测准确率,也有助于保障逆问题中重建数据的完整性。
在我们的框架中,模型误差对数据重建的影响本质上与其“以训练为导向”的特性相关。如问题设定所述,重建任务的目标是:给定一个已训练的模型,恢复其原始训练数据。该框架直接从训练集中提取真实先验和真实后验,因此训练数据成为重建过程的“真实情况”基准。因此,重建质量取决于模型的训练准确率(或训练误差),其中欠拟合(偏差)是主要关注点。欠拟合意味着模型未能充分捕捉训练数据中的模式,从而限制了其忠实重建数据的能力。
然而,过拟合(方差)也是一个重要问题,尤其是在我们将基于重建数据训练的合成模型在测试集上评估其泛化能力时。过拟合发生在模型对训练数据拟合过度,包括噪声或无关模式,这会削弱其对未见数据的泛化能力,并在重建的分布中引入不一致或伪影。
平衡偏差与方差对模型性能至关重要,也直接影响数据重建的质量。为应对欠拟合,可采用增加模型容量(如增加参数、层数或深度)或优化训练过程(如调整学习率、使用先进优化技术)等策略,以帮助模型更好地捕捉数据模式。另一方面,过拟合可通过正则化技术(如Dropout或权重衰减)来缓解,这些方法在高维或小样本数据集中尤为有效。这些技术通过防止模型拟合训练数据中的噪声来促进泛化。此外,数据增强技术也有助于抑制过拟合,特别是在数据集较小或多样性不足的情况下。通过人工增加训练数据的规模和变异性,数据增强可降低模型对特定模式的记忆倾向,从而提升其泛化能力。平衡偏差与方差的有效策略还包括集成方法,如Bagging(如随机森林)和Boosting(如XGBoost),它们通过结合多个模型的预测来提升泛化性能。这些方法可同时降低方差与偏差,从而提升整体模型表现。交叉验证也有助于优化偏差与方差的平衡,为评估模型在未见数据上的表现提供手段。通过上述策略优化偏差与方差的平衡,不仅对提升预测准确性至关重要,也对提高数据重建质量具有关键作用。
我们的框架通过分析先验和模型预测中的不准确性如何影响后验分布与重建数据,实现了对不确定性传播的刻画。这一分析通过系统性实验实现:一方面,让先验不确定性基于数据集的边缘真实先验自然变化(而假设的边缘先验保持为均匀分布);另一方面,系统性地改变模型准确率以控制模型不确定性。后验KL散度作为稳健指标,量化了这些传播的不确定性对重建质量的影响。结果表明,输入中的不确定性会通过框架逐层传播,并直接影响重建数据的质量。
尽管后验KL散度(posterior KLD)为不确定性传播提供了清晰且可解释的度量,传统的不确定性量化技术仍可作为该方法的有益补充。例如,可信区间(credible intervals)能够直观地表示不确定性,而贝叶斯模型平均(Bayesian model averaging)则可通过整合多个模型的预测结果,降低对单一模型的依赖。敏感性分析(sensitivity analysis)可进一步量化先验不确定性和模型不确定性各自的相对贡献,从而指导有针对性的改进。这些扩展性方法虽不在本研究的当前范围内,但在未来工作中有望增强该框架的鲁棒性与适用性。
基础模型与合成模型在相同原始数据集上的准确率之间表现出强相关性,这凸显了我们方法在生成能够高度模拟基础模型性能的合成模型方面的有效性。更重要的是,合成模型的准确率不仅与基础模型高度相关,而且在绝对数值上也具有可比性。这种绝对可比性在心脏病数据集的图4d相关性图中清晰可见,也在附录A中其他数据集的对应图表中得到体现。在多个数据集上均表现出的强相关性和相近的性能指标,进一步验证了我们逆估计方法在从机器学习模型中进行数据重建方面的实际有效性,表明我们基于理论的方法能够可靠地恢复并保留原始数据的内在特征。
理论框架
分析方法与研究路径
本框架旨在通过关注关键变量在某一特定独立变量变化下的行为,揭示它们之间的相关性。我们提出:当一个共同的独立变量发生变化而其他所有因素保持不变时,若两个因变量表现出同步变化,则意味着它们之间存在相关性。具体而言,若这两个因变量对该共同独立变量的偏导数同为正或同为负,则认为它们是相关的。这些偏导数符号的一致性可作为关键变量之间潜在函数关系的代理指标,可靠地反映各因素对重建数据保真度的方向性影响。
这一方法在高维或非线性系统中尤为有价值,因为在这些系统中,直接推导关键变量之间关系的闭式表达式在计算上不可行,甚至在解析上不可能实现。通过引入一个共同的独立变量作为共享代理,我们能够在统一的框架下近似这些相关性,确保各关键变量之间的可比性,从而在无需显式函数形式的前提下,获得数学上严谨且计算上可行的洞见。
我们承认,在变量交互表现出强烈非单调行为或反馈回路的情境下,非线性依赖关系可能非常重要,此时单调模式不足以捕捉系统的复杂性。然而,当简单模式足以描述所关注的关系时,它们因其更高的可解释性和更易分析的特性而更受青睐。在本研究中,基于导数的方法使我们能够提取出此类单调或线性关系,这些关系足以捕捉最关键的趋向性,从而提供有意义的洞见。此外,这种分析方法与我们的实证评估保持一致——在实证中我们使用皮尔逊相关系数,该指标本身也假设变量间存在线性关系。理论与实证方法之间的一致性确保了本研究的分析在整体上保持连贯且具有实际意义。
我们的方法也符合既有的数学实践传统,即在不建模系统全部复杂性的前提下提取关键特征。例如,在稳定性分析中,特征值实部的符号足以判断扰动随时间是衰减还是增长;在基于二次型的优化问题中,特征值符号所决定的正定性即可保证极值的存在,而无需对函数曲面进行详细建模。类似地,在本框架中,偏导数的符号作为潜在函数关系的代理,使我们能够在不完全建模非线性依赖复杂性的情况下,近似相关性和方向性趋势。
为建立这些相关性,本框架引入了若干关键定义与定理,作为分析的基础,并最终导出两个主要推论。所有定理的完整证明以及支持性引理均见附录B。两个关键推论揭示了以下重要关系:
- 真实先验与假设先验之间的KL散度(KLD)与重建数据偏差之间的相关性,如推论1所述;
- 模型预测误差与重建数据偏差之间的相关性,如推论2所示。
我们的数学框架使用了理论上的先验、似然和后验分布,这些分布旨在定义或近似本文其他部分实验工作中所使用的经验对应量。在整个论文中,当我们比较这两种方法时,会明确将其称为“理论”或“经验”的先验、似然和后验,以清晰区分两者。
逆估计问题本质上是多变量的。然而,在本理论框架中,我们主要关注单变量情形,原因有二:第一,该问题的实证研究通常也采用单变量方法,即数据常被逐个变量地分析和可视化。即使是KL散度的经验评估,通常也是针对边缘分布进行计算的,因为准确估计高维分布存在巨大挑战。因此,为与实证方法保持一致,本框架也主要聚焦于单变量分布。第二,本框架旨在通过分析关键变量对单一独立参数的偏导数行为来揭示它们之间的相关性。由于偏导数本质上是针对单变量变化定义的,这种单变量视角使分析更具可操作性,有助于我们在逆估计的背景下推导出更清晰的理论预期。
关键定义与定理
在本小节中,我们建立支撑本研究理论框架的基础定义与定理。这些定义与定理为后续小节中将探讨的关系提供了必要的形式化结构。
我们框架的一个方面涉及使用均匀假设的先验,这确保了估计后验的分析可以简化为单变量贝叶斯分析,而不引入不准确性,正如定理1所建立的。另一个方面涉及对真实先验和真实后验的分析,其中参数之间独立性的假设,如定义11中所规定的,引入了一种权衡。正如定理2所形式化的,这一假设允许真实后验的分析简化为每个参数的单变量贝叶斯分析。这种简化的不准确性源于独立性假设与分布中实际存在的依赖性之间的偏差程度。因此,结果的准确性取决于独立性假设在多大程度上近似真实先验和似然。当参数表现出弱或无依赖性时,这种简化非常准确;相反,随着参数依赖性的增强,近似引入了更大的不准确性。然而,这种在简单性和准确性之间的折衷是本框架中采用的常见方法。类似的独立性假设也在诸如朴素贝叶斯等广泛使用的算法中做出,这些算法在实践中表现良好。这表明所提出的简化方法虽然不完美,但在许多情况下是一个理论上合理且分析上可处理的方法,提供了有意义的见解。通过减小先验差异来降低数据恢复偏差
研究重点与假设:在本小节中,我们探究两个关键因变量之间是否存在相关性:所假设的先验分布与真实先验分布之间的差异,以及恢复的数据分布与真实数据分布之间的偏差。我们试图回答的核心问题是:在不同条件下,所假设先验的准确性变化是否会影响数据恢复的准确性。这些不同的数据集条件通过真实先验边际分布的不同标准差来表示,这些标准差作为我们数据集中实际存在的多样性和变异性的代理指标。通过分析这些变量之间的关系,我们旨在从理论上揭示:更准确的先验假设是否会导致更优的数据恢复结果。实证方法
数据集和机器学习算法
我们使用广泛认可的基准数据集评估我们的方法,这些数据集用于二元和多变量分类:Sirtung 小分子数据集、视网膜神经节细胞刺激类型数据集、贫血类型数据集(具有五个类别标签)、脑卒中预测数据集、心脏病数据集、威斯康星乳腺癌数据集,以及基于笔迹识别的手写数字数据集(具有三个和五个类别标签)。对于机器学习部分,我们采用各种算法,包括深度神经网络、决策树分类器、支持向量机(SVM)、朴素贝叶斯、逻辑回归、随机森林分类器、投票分类器和梯度提升机来训练和评估这些数据集上的模型。
基础模型的构建
对于特定的数据集和特定的机器学习算法,我们系统地构建一组具有不同准确度水平的基础模型。每个基础模型都是使用应用于原始数据集或其系统失真版本的机器学习算法训练的。
数据集的失真是通过在训练集中以一组定义的增量移动特征值来实现的,同时在所有模型中保持相同的测试集。这个过程允许我们生成多个版本的训练数据,每个版本导致具有不同预测准确度级别的模型。通过应用诸如 [0, -0.2, -0.4, -0.6, -0.8, -1, 0.2, 0.4, 0.6, 0.8, 1] 这样的移位,我们为每种算法创建了11个不同的模型。原始的、未失真的数据集作为基线,确保所有模型的真实先验和真实后验保持一致以供参考。
采样方法
算法
计算复杂度
Metropolis 算法在采用随机游走提议机制时的计算复杂度受到目标测度的维度和规则性的影响,而在我们的情况下,目标测度即为所推断的后验分布。该复杂度以多项式形式增长,为 O(d²κ⁺¹)⁵⁷,其中 d 表示问题的维度,κ 反映后验分布的规则性。多项式时间复杂度在大多数实际场景中通常被认为是高效的。这种效率使得基于 Metropolis 算法的 MCMC 采样成为我们逆向估计框架中一种可行的方法。
合成模型的构建
在本研究中,我们引入了“合成模型”的概念。合成模型是指使用与对应基础模型完全相同的机器学习算法和超参数,在合成数据集上训练得到的机器学习模型。这些合成数据集是通过我们的重构过程从原始训练数据集中推导出的重建版本,作为训练合成模型的输入。随后,合成模型在与基础模型相同的原始测试集上进行评估,从而可以直接比较合成模型与基础模型的性能。
实验设置
我们的实验设置旨在系统性地探索与逆向估计和机器学习模型性能相关的各种指标之间的关系。具体步骤如下:
我们首先选取 8 个广泛认可的基准数据集,如“数据集与机器学习算法”小节中所述。对于每个数据集,我们提取其边缘真实先验分布和后验分布,作为真实基准(ground truth)。
每个数据集分别与两种机器学习算法配对:深度神经网络,以及从“数据集与机器学习算法”小节中提到的其他算法中选择的一种算法。这共形成 16 个数据集-算法组合。
对于每个数据集-算法组合,我们使用“基础模型构建”小节中所述的系统性失真技术,构建 11 个具有不同准确率水平的基础模型。这一过程生成一组多样化的模型,每个模型具有不同的预测准确率。
对每个基础模型,我们执行逆向估计,通过估计每个类别条件下的后验数据分布来重构数据集。该过程采用“采样方法”小节中所述的 MCMC 采样方法实现。所得到的重构数据集被称为合成数据集。
随后,每个合成数据集用于训练一个合成模型,方法如“合成模型构建”小节所述。
因此,整个实验共包含 8 个数据集 × 2 种算法 × 11 种准确率 = 176 个逆向估计实例,同时生成 176 个基础模型、176 个合成数据集和 176 个合成模型。
在本实验设置中,评估以下指标:
每个数据集中每个参数的边缘真实先验与假设先验之间的 KLD(KL 散度)。
每个数据集中每个参数的边缘真实后验与估计后验之间的 KLD,对应于每个逆向估计实例。我们将同一参数在不同类别间的 KLD 进行平均。
在特定逆向估计实例中,同一数据集内不同参数的先验 KLD(第 1 项)与后验 KLD(第 2 项)之间的相关性(如图 4a 中心脏病数据集所示)。
每个基础模型的训练准确率和测试准确率。
基础模型的准确率(第 4 项)与后验 KLD(第 2 项)在不同逆向估计实例之间的相关性(如图 4b 中心脏病数据集所示)。
每个合成模型的训练准确率和测试准确率。
合成模型的准确率(第 6 项)与基础模型的准确率(第 4 项)在不同逆向估计实例之间的相关性(如图 4c 和 4d 中心脏病数据集所示)。
该实验设计提供了一个实证框架,用于分析先验假设、模型准确率与数据重构质量之间的相互作用,从而在真实世界数据集上有效检验我们理论框架的预测。此外,它还评估了合成模型的性能,以检验这些模型在多大程度上能够复现基础模型的行为。


文章精选: