社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

深度学习模型的多Loss调参技巧

小白学视觉 • 2 年前 • 208 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

编辑: kaggle竞赛宝典

在多目标多任务训练的网络中,如果最终的loss为有时为多个loss的加权和,例如 loss = a*loss_x+b*loss_y+c*loss_y+... ,这个问题在微信视频号推荐比赛里也存在。任务需要对视频号的某个视频的收藏、点击头像、转发、点赞、评论、查看评论等进行多任务建模,也就产生了多个loss。

    这里介绍在这次实践过程中测试过的几个方法。

1.GradNorm

    GradNorm:ω(t+1)=ω(t)+λβ(t),该方法主要在对各损失函数权重的梯度进行处理,利用梯度更新公式动态更新权重ω。


2.Multi-Task Learning as Multi-Objective Optimization

    在处理多个loss时,引入Pareto用一次训练的方式将问题转化为求取Pareto最优解。有兴趣的可以看看原文:https://arxiv.org/pdf/1810.04650.pdf

3.Multi-task likelihoods

    最简单的多任务Loss的线性加权:

    对于分类任务,经常通过softmax函数产生概率向量中抽取样本来构造多任务的最大似然函数。

class MultiLossLayer(nn.Module):
    """
        计算自适应损失权重
        implementation of "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics"
    """

    def __init__(self, num_loss):
        super(MultiLossLayer, self).__init__()
        self.sigmas_dota = nn.Parameter(nn.init.uniform_(torch.empty(num_loss), a=0.2, b=1.0), requires_grad=True)

    def get_loss(self, loss_set):
        factor = torch.div(1.0, torch.mul(2.0, self.sigmas_dota))
        loss_part = torch.sum(torch.mul(factor, loss_set))
        regular_part = torch.sum(torch.log(self.sigmas_dota))
        loss = loss_part + regular_part
        return loss
4.玄学调参

    上面说了太多方法调参,来点手动的经验吧。最简单的方法如下:

  • 例如 loss = a*loss_x+b*loss_y+c*loss_y ,可以在a+b+c=1前提下,固定a,b,调整c,分别在2x、4x、6x等倍数去做尝试,最后相加为1;

  • 权重缩放,固定其中一个为1,利用power(m,n)去调整尝试;

  • Weight Uncertainty 利用 Gaussian approximation 方式直接修改loss ,并同时以梯度传播的方式来更新里面的两个参数。

参考资料
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目 即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/156089
 
208 次点击