社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

深度学习TabNet能否超越GBDT?

机器学习AI算法工程 • 3 年前 • 489 次点击  


向AI转型的程序员都关注了这个号👇👇👇

机器学习AI算法工程   公众号:datayx


随着深度神经网络的不断发展,DNN在图像、文本和语音等类型的数据上都有了广泛的应用,然而对于同样非常常见的一种数据——表格数据,DNN却似乎并没有取得像它在其他领域那么大的成功。参加过Kaggle等数据挖掘竞赛的同学应该都知道,对于采用表格数据的任务,基本都是决策树模型的主场,像XGBoost和LightGBM这类提升(Boosting)树模型已经成为了现在数据挖掘比赛中的标配。相比于DNN,这类树模型好处主要有:

  • 模型的决策流形(decision manifolds)是可以看成是超平面边界的,对于表格数据的效果很好

  • 可以根据决策树追溯推断过程,可解释性较好

  • 训练起来更快

而对于DNN,它的优势在于:

  • 类似于图像和文本,可以对表格数据进行编码(encode),从而得到一个能够表征表格数据的方法,这种表征学习(representation learning)可以用在很多地方

  • 可以减少对于特征工程(feature engineering)的依赖(相信打过比赛的同学都知道这有多重要)

  • 可以通过online learning的方式来更新模型,而树模型只能用整个数据集重新训练

然而对于传统的DNN,一味地堆叠网络层很容易导致模型过参数化(overparametrized),导致DNN在表格数据集上表现并不尽如人意。因此,如果能够设计这样一种DNN,它既吸收了树模型的长处,又继承了DNN的优点,那么这样的模型无疑是针对于表格数据的一大利器,而这次介绍的论文就巧妙地设计出了这样的模型——TabNet,它在保留DNN的end-to-end和representation learning特点的基础上,还拥有了树模型的可解释性和稀疏特征选择的优点,这使得它在具备DNN优点的同时,在表格数据任务上也可以和目前主流的树模型相媲美,接下来我们就开始具体介绍TabNet。

用DNN构造决策树

既然想要让DNN具有树模型的优点,那么我们首先需要解决的一个问题就是:如何构建一个与树模型具有相似决策流形的神经网络?下图是一个决策树流形的简单示例。






模型架构

为了理解起来比较容易,上面的那个神经网络构造得比较简单,作为一个加性模型它只有两步,Mask层是人为设置好的,特征计算用的也是一个简单的FC层,而接下来介绍的TabNet就对这些地方做了改进,它的基本结构如下所示。








  • 特征选择:Attentive transformer层可以根据上一个step的结果得到当前step的Mask矩阵,并尽量使得Mask矩阵是稀疏且不重复的。值得注意的一点是,不同样本的Mask向量可以不同,也就是说TabNet可以让不同的样本选择不同的特征(instance-wise),而这个特点是树模型所不具备的,对于XGBoost这类加性模型,一个step就是一棵树,而这棵决策树用到的特征是在所有样本上挑选出来的(例如通过计算信息增益),它没有办法做到instance-wise。

  • 特征计算:Feature transformer层实现了对于当前step步所选取特征的计算处理。还是类比于决策树,对于给定的一些特征,一棵决策树构造的是单个特征的大小关系的组合,也就是上面提到的决策流形,而之前那个简单神经网络就是通过一个FC层来模仿这个决策流形,但FC层只是构造了一组简单的线性关系,并没有考虑更加复杂的情况,因此TabNet通过更复杂的Feature transformer层来进行特征计算,个人感觉它的决策流形不一定和决策树的相似,在一些特征组合上它可能比决策树做得更好。

自监督学习

前面提到了DNN的一个好处就是可以进行表征学习,而TabNet就应用了自监督学习的方法,通过encoder-decoder框架来获得表格数据的representation,从而也有助于分类和回归任务,如下图所示:



简单来说,我们认为同一样本的不同特征之间是有关联的,因此自监督学习就是先人为mask掉一些feature,然后通过encoder-decoder模型来对mask掉的feature进行预测。我们认为通过这样的方式训练出来的encoder模型,可以有效地将表征样本的feature(可以理解为对数据进行了编码或压缩),这时再将encoder模型由于回归或分类任务,就能够事半功倍。自监督学习时的encoder模型就是上图中的模型,decoder模型如下所示:

这里的encoded representation就是encoder中没有经过FC层的加和向量,将它作为decoder的输入,decoder同样利用了Feature transformer层,只不过这次的目的是将representation向量重构为feature,然后类似地经过若干个step的加和,得到最后的重构feature。



实验

为了证明TabNet确实具有上文中提到的种种优点,这篇文章在不同的数据集上进行了各种类型的实验,这里只介绍一部分,其它实验以及具体实验细节可以看论文原文,写得也很详细。



这个模型的代码:

tensorflow版本的代码

https://github.com/google-research/google-research/tree/master/tabnet


  • pytorch版本的代码

  • https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tabnet.py


  1. Instance-wise feature selection



2. 真实数据集

Forest Cover Type:这个数据集是一个分类任务——根据cartographic变量来对森林覆盖类型进行分类,实验的baseline采用了如XGBoost等目前主流的树模型、可以自动构造高阶特征的AutoInt、以及AutoML Tables这种用了神经网络结构搜索 (Neural Architecture Search)的强力模型(node hours的数量反映了模型的复杂性),对比结果如下:


Higgs Boson:这是一个物理领域的数据集,任务是将产生希格斯玻色子的信号与背景信号分辨开来,由于这个数据集很大,因此DNN比树模型的表现更好,下面是对比结果,其中Sparse evolutionary MLP应用了目前最好的evolutionary sparsification算法,能够有效减小原始MLP模型的大小,不过可以看出,和它大小相近的TabNet-S的性能也只是稍弱一点,这说明轻量级的TabNet表现依旧很好。


3. 可解释性





4. 自监督学习

前面已经提到了,自监督学习可以提高模型的小样本学习能力,还能加快模型的收敛速度。为了验证这一点,这里我们采用Higgs Boson数据集,其中用全部样本来做自监督学习(pre-training),而只用部分样本做监督学习(fine-tuning),该方法与直接全样本监督学习的对比结果如下所示:

从结果中可以看出,通过自监督学习进行预训练之后,模型的收敛速度明显更快,小样本学习的结果也变得更好。

总结

这篇论文提出的TabNet是一种针对于表格数据的神经网络,它通过类似于加性模型的顺序注意力机制(sequential attention mechanism)实现了instance-wise的特征选择,还通过encoder-decoder框架实现了自监督学习,从而将树模型的可解释性与DNN的表征能力很好地结合到了一起,相信这种兼具两者优点的模型将会成为数据挖掘竞赛中的一大利器,也对未来的研究提供了一个很好的思路。

参考资料

[1] TabNet: Attentive Interpretable Tabular Learning

https://arxiv.org/abs/1908.07442



机器学习算法AI大数据技术

 搜索公众号添加: datanlp

长按图片,识别二维码




阅读过本文的人还看了以下文章:


TensorFlow 2.0深度学习案例实战


基于40万表格数据集TableBank,用MaskRCNN做表格检测


《基于深度学习的自然语言处理》中/英PDF


Deep Learning 中文版初版-周志华团队


【全套视频课】最全的目标检测算法系列讲解,通俗易懂!


《美团机器学习实践》_美团算法团队.pdf


《深度学习入门:基于Python的理论与实现》高清中文PDF+源码


《深度学习:基于Keras的Python实践》PDF和代码


特征提取与图像处理(第二版).pdf


python就业班学习视频,从入门到实战项目


2019最新《PyTorch自然语言处理》英、中文版PDF+源码


《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码


《深度学习之pytorch》pdf+附书源码


PyTorch深度学习快速实战入门《pytorch-handbook》


【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》


《Python数据分析与挖掘实战》PDF+完整源码


汽车行业完整知识图谱项目实战视频(全23课)


李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材


笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!


《神经网络与深度学习》最新2018版中英PDF+源码


将机器学习模型部署为REST API


FashionAI服装属性标签图像识别Top1-5方案分享


重要开源!CNN-RNN-CTC 实现手写汉字识别


yolo3 检测出图像中的不规则汉字


同样是机器学习算法工程师,你的面试为什么过不了?


前海征信大数据算法:风险概率预测


【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类


VGG16迁移学习,实现医学图像识别分类工程项目


特征工程(一)


特征工程(二) :文本数据的展开、过滤和分块


特征工程(三):特征缩放,从词袋到 TF-IDF


特征工程(四): 类别特征


特征工程(五): PCA 降维


特征工程(六): 非线性特征提取和模型堆叠


特征工程(七):图像特征提取和深度学习


如何利用全新的决策树集成级联结构gcForest做特征工程并打分?


Machine Learning Yearning 中文翻译稿


蚂蚁金服2018秋招-算法工程师(共四面)通过


全球AI挑战-场景分类的比赛源码(多模型融合)


斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)


python+flask搭建CNN在线识别手写中文网站


中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程



不断更新资源

深度学习、机器学习、数据分析、python

 搜索公众号添加: datayx  



Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/134167
 
489 次点击