文章首页全文地址:https://pmc.ncbi.nlm.nih.gov/articles/PMC11711098/
代码地址:https://priorlabs.ai/tabpfn-nature/
数据地址:https://zenodo.org/records/13981285
思维导图一种专为表格设计的架构
Transformer架构目前是灵活深度学习和基础模型中备受青睐的架构。Transformer模型处理序列数据,并使用所谓的注意力机制将序列项之间的信息进行整合,这使得它们能够有效地捕捉长距离依赖关系,并学习数据中的复杂关系。尽管基于Transformer的模型可以应用于表格数据,但TabPFN解决了它们固有的两个关键局限性。
- 首先,由于Transformer是为序列设计的,它们将输入数据视为单个序列,而没有利用表格结构。
- 其次,机器学习模型通常以“拟合 - 预测”模式使用,即模型在训练集上进行一次拟合,然后重复用于多个测试数据集。然而,基于Transformer的上下文学习(ICL)算法在一次前向传播中接收训练数据和测试数据,从而一次性执行训练和预测。因此,当重复使用一个已拟合的模型时,它必须重新对训练集进行计算。
为了更好地利用表格结构,受参考文献22、28的启发,我们提出一种架构,为表格中的每个单元格赋予单独的表示。我们的架构如图1b所示,采用双向注意力机制,每个单元格先关注其所在行中的其他特征(即其所在样本),然后关注其所在列中的相同特征(即所有其他样本)。这种设计使得该架构对样本和特征的顺序均具有不变性,并且与训练时遇到的表格相比,在样本数量和特征数量方面,都能更高效地训练并外推至更大的表格。
为了避免在“拟合-预测”模式下针对每个测试样本都在训练集上重复计算,我们的模型可以将对训练样本和测试样本的推理过程分开。这使我们能够在训练集上仅执行一次上下文学习(ICL),保存所得状态,并将其重复用于多个测试集的推理。在拥有10000个训练样本和10个特征的数据集上,我们优化后的训练状态缓存使得推理速度提升了 。在CPU上提速约300倍(从32秒降至0.1秒),在GPU上提速6倍。当特征数量增加10倍(达到100个)时,CPU上的提速幅度增至800倍,GPU上则提速30倍。这些测量仅聚焦于核心推理过程,不包括“推理细节”部分详述的预处理和集成步骤。GPU上提速幅度较低,是由于其大规模并行架构未得到充分利用。
我们通过以半精度计算层归一化、使用Flash Attention(快速注意力机制)、激活检查点以及按顺序计算状态,进一步优化该架构的内存和计算需求。我们的优化措施将内存需求降低至原来的四分之一,使得每个单元格的内存占用不到1000字节。这使得在单个H100 GPU上就能对多达5000万个单元格的数据集(例如500万行×10个特征)进行预测。
基于因果模型的合成数据
TabPFN的性能依赖于生成合适的合成训练数据集,这些数据集需捕捉现实世界表格数据的特征与挑战。为生成此类数据集,我们开发了一种基于结构因果模型(SCMs)的方法 。结构因果模型为表示数据背后的因果关系和生成过程提供了一个正式框架。通过使用合成数据,而非大量公开的表格数据,我们避免了基础模型常见的问题,例如隐私与版权侵犯、测试数据污染训练数据,以及数据可用性有限等问题。
如图2所示,我们的生成流程首先对诸如数据集规模、特征数量和难度等级等高级超参数进行采样,以此把控每个合成数据集的整体特性。在这些超参数的指引下,我们构建一个有向无环图,用于明确数据集背后的因果结构。
为了生成数据集中的每个样本,我们将随机生成的噪声(称为初始化数据)通过因果图的根节点进行传播。这些初始化数据是从随机正态分布或均匀分布中采样得到的,样本之间具有不同程度的非独立性,详见“初始化数据采样”部分。当这些数据在计算图的边中传递时,我们会应用一系列多样的计算映射:带有线性或非线性激活函数(如Sigmoid、ReLU(修正线性单元)、取模、正弦函数)的小型神经网络、用于生成分类特征的离散化机制,以及对基于规则的局部依赖关系进行编码的决策树结构。在每条边上,我们添加高斯噪声,为生成的数据引入不确定性。我们会保存每个节点处的中间数据表示,以便后续提取。具体细节见“计算边映射”部分。
在遍历因果图之后,我们在采样得到的特征节点和目标节点处提取中间表示,从而生成一个由特征值和相关目标值组成的样本。
图2通过将各种数据挑战和复杂性融入合成数据集,我们创建了一个训练平台,使TabPFN能够制定应对现实世界数据集中类似问题的策略。例如,以表格数据中常见的缺失值问题为例。在合成数据生成过程中,让TabPFN接触具有不同缺失值模式和比例的合成数据集,该模型就能学习到处理缺失值的有效方法,并将其推广应用到现实世界的数据集中。我们应用后处理技术,进一步增强数据的真实性,并考验所学预测算法的稳健性。这包括使用库马拉斯瓦米分布进行扭曲变换、引入复杂的非线性失真以及模拟离散特征的量化处理。具体细节见“后处理”部分。
通过这一生成过程,我们在每次模型训练时创建了一个庞大的语料库,其中包含约1亿个合成数据集,每个数据集都有独特的因果结构、特征类型和功能特性。
定性分析
我们首先分析TabPFN在简单问题上的表现,以此建立直观认识并厘清各种数据集特征的影响。由于回归问题更容易可视化,我们在定性分析中聚焦于这类问题。在图3a中,我们将TabPFN与一系列不同的标准预测器进行比较,所有方法均采用默认设置。
线性(岭)回归自然只能对线性函数进行建模,这使得其预测简单且具有可解释性,但在许多简单函数问题上会遭遇严重失败。多层感知器(MLPs) 在具有高度非平滑模式的数据集上表现较差。这在阶梯函数的情况中尤为明显。相比之下,TabPFN无需额外调整就能对平滑或非平滑的任意一种函数类型进行建模。尽管TabPFN是一种神经网络,但它对阶梯函数也能实现较好的近似。以CatBoost为代表的基于树的方法,只能拟合分段常数函数。虽然这会导致近似误差和不直观的预测,但可避免严重失败。
与所有基线模型相比,TabPFN的主要优势在于其内在的无额外成本的不确定性建模能力。 传统回归方法输出单个实值预测,而TabPFN返回一个目标分布,从而捕捉预测的不确定性。 TabPFN的这些不确定性建模能力不仅限于简单分布,还能处理复杂的多模态分布。图3b通过对双缝实验中不同狭缝距离和宽度下到达探测器屏幕的光的密度进行建模,展示了这一点 。在这个经典实验中,光子穿过两条狭缝,由于光的波状干涉行为而产生多模态强度图案。TabPFN仅需一次前向传播就能预测这些复杂图案,仅需1.2秒。相比之下,像CatBoost这样的传统方法需要在不同分位数上训练多个分位数模型,并从这些预测中重建分布。即使针对此任务专门调整了CatBoost,与TabPFN相比,其预测结果仍明显较差,见图3b。使用默认设置时,CatBoost需要169.3秒,且结果更差。从定性角度看,我们观察到TabPFN与CatBoost相比,TabPFN在预测极低密度时更为准确,且产生的伪影更少。
图3 | TabPFN和一组基线模型在简单函数上的表现。在所有图表中,我们用橙色表示真实值,蓝色表示模型预测值。a,每一列代表一个不同的简单函数,每个函数都有一个特征(沿x轴)和一个目标值(沿y轴)。TabPFN能够对多种a. 不同的函数进行建模,包括含有噪声的函数。b. TabPFN无需额外调整就能对输出的分布进行建模,这一点通过在观测1000个光子的位置后,预测双缝实验中的光强模式得以例证。这张图展示了TabPFN和一组基线模型在简单函数上的表现对比,具体解读如下:
图a部分
每一列代表一个不同的简单函数,这些函数都只有一个特征(沿x轴)和一个目标值(沿y轴)。图中展示了6种不同的函数:
- Homoscedastic noise:具有同方差噪声的函数。
- Heteroscedastic noise:具有异方差噪声的函数。
每一行代表不同的模型预测结果:
- True Function:真实函数值,用橙色点表示。
- TabPFN:TabPFN模型的预测结果,蓝色曲线。
- CatBoost:CatBoost模型的预测结果,蓝色点。
从图中可以看出,TabPFN能够较好地拟合各种复杂函数,包括含有噪声的函数。
图b部分
展示了TabPFN和CatBoost在双缝实验中预测光强模式的能力。图中对比了真实函数(True function)、TabPFN的预测结果以及CatBoost(quantile)的预测结果。
- True function:真实的光强分布,用橙色表示。
- TabPFN:TabPFN模型预测的光强分布,用蓝色表示。
- **CatBoost (quantile)**:CatBoost模型预测的光强分布,用蓝色表示。
图中展示了不同狭缝宽度(Slit width)和狭缝间距(Slit separation)下的光强分布预测结果。TabPFN能够直接对输出的分布进行建模,而CatBoost则需要通过分位数模型进行重建。结果显示,TabPFN在预测光强模式方面表现更优。
定量分析
我们在两个数据集集合上对TabPFN进行了定量评估:AutoML基准测试36和OpenML - CTR2337。这些基准测试包含了各种来自现实世界的表格型数据集,它们是根据复杂性、相关性和领域多样性精心挑选的。从这些基准测试中,我们使用了29个分类数据集和28个回归数据集,这些数据集最多有10,000个样本、500个特征和10个类别。我们还进一步评估了参考文献14、15中的其他基准测试套件,以及来自表格型游乐场系列(Tabular Playground Series)的五个Kaggle竞赛数据集。
我们将TabPFN与最先进的基线模型进行了比较,包括基于树的方法(随机森林、XGBoost(XGB)、CatBoost、LightGBM8)、线性模型、支持向量机(SVM)和多层感知机(MLP)。
评估指标方面,分类任务采用受试者工作特征曲线下面积(ROC AUC,即“一对其余” 模式下的曲线下面积)与准确率;回归任务采用决定系数()与负均方根误差(negative RMSE)。各数据集的分数均进行了归一化处理,1.0代表相对于所有基线模型的最佳性能,0.0则代表最差性能。
对于每个数据集和方法,我们使用不同的随机种子和训练 - 测试划分(90%用于训练,10%用于测试)进行10次重复实验。我们采用随机搜索结合五折交叉验证来调整超参数,时间预算范围从30秒到4小时不等。所有方法均使用8个CPU核心进行评估,TabPFN还额外使用了一块消费级GPU(RTX 2080 Ti;其他方法无法从中受益,详见扩展数据图2d)。TabPFN预先使用8块NVIDIA RTX 2080 GPU进行了为期2周的训练,这样在处理所有新数据集时,只需一次前向传播就能实现上下文学习(ICL)。这些适度的计算需求使得学术实验室也能够开展类似的研究。具体细节,请参考“详细评估协议”部分。
与最先进基线模型的比较
图4a展示了TabPFN与经过调优及默认配置的XGBoost、CatBoost和随机森林相比,具有强大的默认性能。在分类任务中,在默认设置下,TabPFN在归一化的ROC AUC指标上比表现最佳的默认基线模型CatBoost高出0.187(分别为0.939和0.752),在调优设置下高出0.13(分别为0.952和0.822)。在回归任务中,在默认设置下,TabPFN在归一化的RMSE指标上比CatBoost高出0.051(分别为0.923和0.872),在调优设置下高出0.093(分别为0.968和0.875)。在图4b中,我们展示了每个数据集的比较情况。尽管在某些数据集上CatBoost的表现优于TabPFN,但在大多数数据集上TabPFN更胜一筹。
图4c展示了随着在超参数搜索上花费更多时间,TabPFN和基线模型的性能如何提升。TabPFN的默认设置在分类任务中平均耗时2.8秒,在回归任务中平均耗时4.8秒,其性能超过了所有基线模型,即便对基线模型进行长达4小时的调优——TabPFN在分类和回归任务中的加速比分别达到5140倍和3000倍。我们在扩展数据表1和2中展示了更多指标的比较情况。
如扩展数据图2所示,与我们的主要基准测试类似,在参考文献14、15的基准测试中,TabPFN的表现也大幅超过所有基线模型。参考文献14的基准测试尤其值得关注,因为此前研究发现基于树的方法在该基准测试中表现出色。此外,我们在扩展数据表6中展示了,在最新完成的表格型游乐场系列中所有五个训练样本少于10000的Kaggle竞赛中,默认的TabPFN表现优于默认的CatBoost。
图4 | TabPFN在我们测试基准上的对比,这些基准包含多达10,000个样本和500个特征的数据集。在使用所有基线进行汇总之前,每个数据集的性能都进行了归一化处理;区间代表95%置信区间。Wilcoxon P指的是双侧Wilcoxon符号秩检验的P值54。a,TabPFN的默认版本及调优版本与我们的基线模型的平均性能。所有方法分别针对受试者工作特征曲线下面积(ROC AUC)或均方根误差(RMSE)进行了调优,因此降低了次要指标的代表性。LGBM即LightGBM;MLP即多层感知器;SVM即支持向量机。 RF代表随机森林(Random Forest);CB代表CatBoost;XGB代表XGBoost;Lin在分类任务中代表逻辑回归,在回归任务中代表岭回归。右侧的图表展示了对所考虑的最强基线模型的放大分析。b,TabPFN与其最强基线模型CatBoost的逐数据集比较。每个点是一个数据集上的平均得分。c,所考虑方法的超参数调优的影响。x轴表示使用该算法进行拟合和预测所需的平均时间。
评估多样的数据属性
在图5a和5b中,我们展示了TabPFN对于传统上基于神经网络的方法难以处理的数据集特征的鲁棒性。
图5a对TabPFN在各种数据集类型上的性能进行了分析。首先,我们添加无信息特征(从原始数据集中随机打乱顺序的特征)以及异常值(以2%的概率将每个单元格与0到异常值因子之间的随机数相乘)。结果表明,TabPFN对无信息特征和异常值具有很强的鲁棒性,而这对于神经网络通常是难点,多层感知器(MLP)基线模型的表现就体现了这一点。其次,尽管减少样本或特征会损害所有方法的性能,但即便样本量减半,TabPFN的表现仍与使用全部样本的次优方法相当。
在图5b中,我们将测试数据集划分为多个子组,并对每个子组进行分析。我们根据数据集中分类特征的存在情况、缺失值、样本数量以及特征数量来创建子组。样本数量和特征数量子组的划分方式是,使每个组包含三分之一的数据集。我们可以看到,相较于其他方法,这些特征都不会对TabPFN的性能产生显著影响。不过,我们要指出,这些结果不能作为TabPFN在超过此处所考虑的10000个样本和500个特征规模时仍能良好扩展的证据。我们在扩展数据图1中展示了另外四项消融实验。
图5 | 跨数据集的稳健性以及与调优集成模型的性能比较。a,修改后数据集的比较。我们可以看到,与基线模型相比,TabPFN对这些修改并不更脆弱。我们还发现,仅提供一半的训练样本时,TabPFN就能重现CatBoost(默认设置)的准确率。在此,我们对每个数据集的分数进行归一化处理(在一个实验的所有修改中共享一种归一化方式),以避免负异常值。b,我们根据数据特征拆分测试数据集,并分析每个子组的性能。c,分类性能。左图,TabPFN(PHE)相对于AutoGluon的胜率(排除一次平局情况);右图,每种方法在调优过程中的ROC AUC分数随时间的变化,第一个标记代表非集成方法的默认配置。d,回归性能,呈现方式与c相同,但使用RMSE指标。区间表示95%置信区间,Wilcoxon P指的是双侧Wilcoxon符号秩检验的P值54 。与调优集成方法的比较
我们将TabPFN的性能与AutoGluon 1.0(参考文献40)进行比较,AutoGluon 1.0将包括我们的基线模型在内的各种机器学习模型组合成一个堆叠集成模型,对其超参数进行调优,然后使用事后集成(PHE)生成最终预测结果。因此,与单个基线模型相比,它代表了一类不同的方法。
为了评估TabPFN是否也能通过调优集成方法得到改进,我们引入了TabPFN(PHE)。TabPFN(PHE)仅自动将TabPFN模型与PHE相结合,并使用我们搜索空间中的随机组合来调整其超参数。我们在“TabPFN(PHE)”部分详细介绍这种方法。
图5c - d比较了TabPFN、TabPFN(PHE)、AutoGluon和CatBoost的性能。对于TabPFN(PHE)和AutoGluon,我们从300秒的最小调优预算开始,因为否则AutoGluon无法可靠地返回结果。在仅2.8秒内,TabPFN(默认设置)在分类任务中的表现就超过了AutoGluon,即使AutoGluon有长达4小时的调优时间,加速比达到5140倍。TabPFN(PHE)进一步提升了性能,平均归一化ROC AUC得分为0.971,相比之下,TabPFN(默认设置)为0.939,AutoGluon为0.914。
对于回归任务,超参数调优更为重要。在此,TabPFN(PHE)在300秒的最小调优预算后,性能超过了AutoGluon(允许调优4小时),加速比达到48倍。
具备可解释性的基础模型
除了强大的预测性能,TabPFN还展现出基础模型的关键能力,如数据生成、密度估计、学习可复用嵌入以及微调。我们通过在德国信用数据集(该数据集包含信用风险信息)和基于表格表示对手写数字进行分类的mfeat - factors数据集上开展概念验证实验,来展示这些能力。
如图6a所示,TabPFN可以估计数值特征的概率密度函数以及分类特征的概率质量函数。通过计算样本密度,能够进行异常检测,以识别诸如欺诈、设备故障、医疗紧急情况或低质量数据等问题。
如图6b所示,TabPFN还能够合成模拟现实世界数据集特征的新表格数据样本。这使得数据增强或隐私保护数据共享等应用成为可能。
TabPFN的架构可产生有意义的特征表示,这些表示可重用于诸如数据插补与聚类下游任务。 我们从图6c的mfeat - factors数据集中提取并可视化学习到的嵌入,结果显示,与前两个主成分上的原始数据相比,类别区分度有所提高。
此外,我们展示了TabPFN通过在相关数据集上进行微调来提升性能的能力。与基于树的方法不同,TabPFN的神经网络架构允许针对特定数据集类别进行微调。我们使用正弦曲线数据集开展概念验证实验,在微调数据和测试数据之间设置不同的偏移量。图6d展示了一个微调结果示例。我们通过50次运行进行分析(扩展数据图4),结果表明,即使微调任务和测试任务的标签差异显著,TabPFN仍能成功迁移知识,且随着分布变得更加相似,其性能会有所提升。例如,这可以让我们针对一系列医学研究数据集进行微调,以获得一个用于医学诊断任务的改进通用模型。详情请参考“基础模型能力”部分。
最后,我们开发了一种方法,能够轻松解读TabPFN的预测结果。在高风险领域部署模型时,可解释性对于建立信任和问责机制至关重要。我们支持通过SHAP(Shapley加性解释)来计算特征重要性,这是一种基于博弈论的预测解释方法。SHAP值代表每个特征对模型输出的贡献。扩展数据图3比较了逻辑回归、CatBoost和TabPFN的特征重要性及影响。TabPFN在学习简单且可解释的特征关系的同时,实现了高精度。相比之下,逻辑回归具有可解释性,但准确性较低;而CatBoost虽然准确,但由于复杂且不光滑的决策边界,在定性层面的可解释性较差。
图6 | TabPFN作为表格型基础模型的应用展示。 a、b,在德国信用数据集上,我们进行数据密度估计(a)以及生成新的合成样本(b)。c,我们展示在手写数字数据集(mfeat - factors)上,所学习到的嵌入是每个样本的有效表示,不同类别形成不同的簇。d,我们展示针对特定任务集对TabPFN进行微调。在包含各种正弦曲线的数据集上进行微调(上图)后,我们发现该模型在另一个正弦曲线数据集上能做出更准确的预测。结论
TabPFN代表了表格数据建模领域的一项重大变革,它借助上下文学习(ICL)自主发现了一种高效算法,在样本量多达10,000且特征数量达500的数据集上,性能超越了传统的人工设计方法。这种向基于合成数据训练的基础模型的转变,为各领域的表格数据分析开辟了新的可能性。
未来潜在的研究方向包括扩展至更大规模的数据集、处理数据漂移问题 、探究在相关表格任务中的微调能力 ,以及理解我们所采用方法的理论基础。未来的工作还可以探索创建专门的先验知识,以处理诸如时
间序列、多模态数据等数据类型,或者处理诸如心电图(ECG)、神经影像数据及遗传数据等特殊模态。随着表格数据建模领域的不断发展,我们相信像TabPFN这样的基础模型将在助力研究人员方面发挥关键作用。为了推动TabPFN的广泛应用,我们在“用户指南”部分讨论了如何有效地使用它。
如果您对真实世界研究/临床因果估计方法/生信分析/影像组学人工智能算法感兴趣可以通过下方的微信加我的交流群
助教微信-程老师
助教微信-金老师欢迎关注我的视频号-每周定期直播免费文献分享会
扫一扫,添加我的视频号欢迎关注我的小红书

欢迎关注我的B站账号-
公开课及文献分享视频会更新至此
我的B站