社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

半监督深度学习训练和实现小Tricks

小白学视觉 • 3 年前 • 339 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

来源:本文授权转载知乎作者糯米稻谷,https://zhuanlan.zhihu.com/p/100252944。未经作者许可,不得二次转载。 

临近毕业,忽然想实现一下一些比较经典的半监督深度学习算法。恰逢一周前有个师妹问我有没有关于半监督深度学习训练技巧的文章,我一想,好像没有,便想自己写一篇。

Talk is cheap, show you my code:

https://github.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorchgithub.com

实现了几个比较经典的半监督深度学习算法,包括 PseudoLabel 2013[1], PI model[2], Tempens[2], MeanTeacher[3], ICT[4] 和 MixMatch[5]。代码刚整理完,可能还有些小问题,请多见谅~(没想到整理代码比写代码还累...orz)

言归正传,本文旨在分享一些我在训练和实现半监督深度学习模型时所发现的一些小Tricks,并没有什么太多理论,只是一些经验,也没有太多实验去支持(毕竟也就写了一周,repo上的一些代码我自己都没完整跑过...orz),大家就看看,希望能对大家有所启发。

注:下面给出的所有模型的实验结果,几乎都在同一实验参数下得出,以方便公平比较,并不是按照论文给出的最佳参数进行训练的。若想复现论文结果,请自行调参。

1. 默认优化参数

第一个我想分享的是,默认优化参数,如下:

  • 优化器:SGD +lr =0.1+momentum =0.9;或 Adam +lr =1e-3
  • 学习率策略:Cosine 学习率退火;或 学习率 warmup
  • Batch中有标签和无标签数据比例:1:1,即在一个batch中,无标签样本量和有标签样本量相等

没有太多的惊喜成分,就是常用的优化配置,但在初始构建一个半监督深度学习模型的时候,也够用了。想要更多的提升,也可以从这个默认的配置进行寻参。

我通常直接使用 SGD+Cosine退火,使用Cosine退火而不是多步衰减是因为懒得调参,效果也相差不大。

当你的半监督模型需要稀疏地更新样本的时候,例如CCL[6],可以尝试 Adam 优化器,或许会有更好的效果。

很多情况下,例如样本量比较少(250个有标签样本的Cifar10实验时),我都建议尝试学习率warmup策略,可能会有奇效(e.g. PI model)。学习率warmup策略会使得模型稳定收敛到不错的性能。

2.两种数据输入方式

讲这个内容前,请容我介绍一些半监督深度学习代价的基本形式:

其中, 为有监督代价函数,通常为交叉熵代价。 为无监督代价函数,也是不同半监督深度学习模型的区别所在。

由于 是无监督代价函数,可以作用于有标签数据和无标签数据,因此无监督代价函数的输入出现了两种不同实现方法。在一个batch中,令 为输入 的数据。

  • 方式v1

中有标签数据和无标签数据按固定比例混合输入。即 ,其中 为batch中的有标签数据和无标签数据。

  • 方式v2:

,即 为整个数据集的随机采样,因此不能保证有标签数据和无标签数据的比例,甚至会出现没有有标签数据的情况。

从目前CIFAR-10实验情况来看,方式v2更好(并不保证orz):

Test error of CIFAR-10 with 4000 labeled samples

3.使用 epoch pseudo labels

Pseudo-Label 2013 提出,使用网络的预测最大概率的类别,作为无标签数据的伪标签。具体解释参考我前面的文章《半监督深度学习小结》,https://zhuanlan.zhihu.com/p/33196506)。

此前我理解为,使用当前的预测类别,作为当前无标签数据的伪标签,我称为 iteration pseudo labels,但效果不太理想。

后来,Tempens[2]源码的实现,是使用上一个 epoch的预测类别,作为当前epoch中无标签数据的伪标签,即 epoch pseudo labels。

CIFAR-10的实验结果对比如下:

Test error of CIFAR-10 with 4000 labeled samples

可以看到,使用 epoch pseudo labels 有非常明显的提升。因此建议大家实现基于伪标签的半监督深度模型时使用 epoch pseudo labels 而不是 iteration pseudo labels。

4. 使用 MixUp loss

最近,使用 MixUp[7]+consistency regularization 成为了新的 SOTA 方法,代表有 ICT和MixMatch。半监督中使用的 Mixup 数据增强形式如下:

即通过线性插值的方式产生新的数据。但是如何利用插值数据有4种方式,令模型为

其中,CE为交叉熵代价函数,MSE为均方误差代价函数。

CIFAR-10的实验结果对比如下:

Test error of CIFAR-10 with 4000 labeled samples

其中,ICTv1表示使用第一种数据输入方式,ICTv2表示使用第二种数据输入方式。

从实验结果上看到,L3代价项的效果是最好的,而且对代价进行插值的 L3 和 L4 明显比 L1 和L2 的效果好。

半监督深度学习从入门到放弃之路,我终于快走到结束了,希望此文能对想入门的同学有所帮助。

参考
1.Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks
2.Temporal Ensembling for Semi-Supervised Learning https://arxiv.org/pdf/1610.02242v1.pdf
3.Mean teachers are better role models https://arxiv.org/pdf/1703.01780.pdf
4.Interpolation Consistency Training for Semi-supervised Learning https://arxiv.org/pdf/1903.03825.pdf
5.MixMatch: A Holistic Approach to Semi-Supervised Learning https://arxiv.org/pdf/1905.02249v2.pdf
6.Certainty-Driven Consistency Loss for Semi-supervised Learning
7.MixUp: Beyond Empirical Risk Minimization https://arxiv.org/pdf/1710.09412.pdf

-END-

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇




下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/146348
 
339 次点击