Py学习  »  Git

ICML 2022 | 基于Logit归一化的置信度校准方法

PaperWeekly • 1 年前 • 187 次点击  

©作者 | 牟宇滔
单位 | 北京邮电大学
研究方向 | 自然语言理解

神经网络经常出现过度自信问题(overconfidence),表现为对 in-distribution 和 out-of-distribution 的数据都产生比较高的概率置信度,这是 OOD 检测的一个最基础的概念。本文提出一种 Logit Normalization 方法,在训练过程中将 Logit 的范数限定为一个常数,对传统的交叉熵损失进行修正,来缓解这种 overconfidence 问题。


论文标题:

Mitigating Neural Network Overconfidence with Logit Normalization

收录会议:

ICML 2022

论文链接:

https://arxiv.org/abs/2205.09310




研究动机


之前 OOD 检测的研究主要聚焦于设计一种比 maximum softmax probability (MSP) 更好的指标来度量 OOD 不确定性。但是很少研究关注神经网络过度自信的原因,以及如何缓解神经网络的过度自信。


作者认为这才是 OOD 检测的本质问题。作者首先做了一个分析,看神经网络训练过程中,Logit 范数的变化。可以发现即使大多数训练示例被分类到正确的标签,softmax 交叉熵损失也可以继续增加 Logit 向量的大小。因此,训练期间不断增长的幅度会导致过度自信问题。


▲ 训练过程中,IND和OOD的Logit Norms都在不断增大,这导致一个较大的置信度分数,不利于区分IND和OOD


为了缓解上述问题,直接的想法就是在训练过程中将 Logit 范数限定为一个常数,同时保持 Logit 向量方向不变(本文提出的 LogitNorm 方法)。




方法


2.1 分析为什么softmax交叉熵会影响overconfidence


假设神经网络的 pre-softmax 输出为 f(向量),不失一般性可以将这个向量分解成范数*单位向量形式。



有以下推论:


1. Logit 向量的每一个维度的元素同时扩大 s 倍,不会影响 softmax 分类结果;



2. 但是对 Logit 向量的每一个维度的元素同时扩大 s 倍,会影响 softmax 置信度分数,使得置信度分数变高。



换句话说,Logit 向量的大小增大,将造成更大的 softmax 置信度分数,但是不影响分类结果接下来分析对交叉熵训练目标的影响,如果训练目标采用下面这个交叉熵损失,那么训练过程中损失值不断减小,会使得 Logit 范数不断增大,由上述推论(2)可知会得到更高的 softmax 置信度分数。



2.2 提出方法


为了解决上述提到的 softmax 交叉熵鼓励网络产生范数较大的 Logit,导致过度自信,不利于区分 IND 和 OOD 的问题。作者的 idea 是将范数大小的影响和网络优化过程进行解耦,换句话说,就是在训练过程中保持 Logit 范数为一个常数值。



在现代神经网络的背景下执行约束优化并非易事,作者也提到简单用拉格朗日乘数法可能在这种深度神经网络上效果不好(具体证明比较复杂感兴趣可以看原文)。为了解决这个问题,我们将上述带约束的目标转换为可替代的端到端可训练的损失函数。 


具体地,作者在计算 softmax 交叉熵之前做了一个 Logit Normalization 操作,鼓励 Logit 向量的方向与它的 one-hot 标签一致,但是不优化 Logit 的大小(限定为一个常数)。特别地,理想情况希望 Logit 向量优化为一个常数大小的单位向量。LogitNorm 交叉熵的数学形式如下:



上述式子可等价为:


▲ 温度系数用来调控Logit大小


这样一来网络优化的其实是一个单位向量。这能让模型得到相对保守的预测。


▲ 可见传统交叉熵做IND预训练,得到得到概率值都比较高,而LogitNorm可以得到相对平滑的概率分布


▲ 棕色是OOD类别,可见LogitNorm可以得到更多有意义的信息用来区分IND和OOD

外, 作者还对这个 LogitNorm 交叉熵的下界进行了分析:



可以看出温度系数越大,损失函数下界也随之升高。较高的损失函数下界不利于优化。实验部分对温度系数进行了讨论。




实验


3.1 主实验


这里采用最基本的 MSP 分数做 OOD 检测,可以看到 LogitNorm 在不同数据集上提升比较明显。



下图进一步展示了 IND 和 OOD 数据的 softmax 置信度分数分布,可以看出传统交叉熵会导致大多数 OOD 样本被分配一个较高的置信度分数,而 LogitNorm 可以更好地区分 IND 和 OOD。



3.2 比较LogitNorm对不同OOD检测方法的提升



3.3 LogitNorm对不同网络结构适配性





总结


本文提出一个 LogitNorm 交叉熵损失,是对传统交叉熵的改进,主要解决神经网络过拟合和模型矫正问题。虽然实验都是在 OOD 检测任务上做的,但是这个方法应该是具有比较强的通用性的,适用于一些需要知识迁移的任务。本文通过理论推导结合实验分析的方式,逐步引出方法,这个行文思路值得借鉴(最近看了不少这个类型的工作,我比较喜欢这种风格)。




更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/138765
 
187 次点击