Py学习  »  机器学习算法

如何用深度学习来做检索:度量学习中关于排序损失函数的综述(1)

机器学习研究组订阅号 • 3 年前 • 264 次点击  

作者:Ahmed Taha

编译:ronghuaiyang

导读

一篇关于度量学习损失函数的综述,这是第一部分,对比损失和三元组损失。

检索网络对于搜索和索引是必不可少的。深度学习利用各种排名损失来学习一个对象的嵌入 —— 来自同一类的对象的嵌入比来自不同类的对象的嵌入更接近。本文比较了各种著名的排名损失的公式和应用。

深度学习的检索正式的说法为度量学习(ML)。在这个学习范式中,神经网络学习一个嵌 入—— 比如一个128维的向量。这样的嵌入量化了不同对象之间的相似性,如下图所示。学习后的嵌入可以进行搜索、最近邻检索、索引等。

用排序损失训练的深度网络,使搜索和索引成为可能

这个综述比较了各种损失的公式和应用。综述分为两部分。第一部分对对比损失和三元组损失进行了对比。第二部分将介绍N-pairs损失和Angular损失。

对比损失

最古老,最简单的排序损失。这种损失使相似点和不同点之间的欧氏距离分别达到最小和最大。相似的点和不同的点被分成正样本对和负样本对。下图给出了它的公式,使用了一对点的嵌入(x_i,x_j)。当(x_i,x_j)嵌入属于同一个类时,y=0。在这种情况下,第一项使欧几里得距离D(x_i,x_j)最小,而第二项是无效的,即等于零。当嵌入项(x_i,x_j)属于不同类别时,y=1,第二项使点之间的距离最大,而第一项为零。第二项中的max(0,m-D)确保不同的嵌入间隔一定的距离,即有限的距离。在训练过程中,这一margin确保了神经网络的梯度忽略大量的远(容易)的负样本对,而利用稀缺的近(难)的负样本对。

对比损失

尽管它很受欢迎,但在大多数检索任务(通常用作基线)中,这种对比性损失的表现很不起眼。大多数高级损失需要一个三元组(x_i,x_j,x_k),其中(x_i,x_j)属于同一类,(x_i,x_k)属于不同类。这种三元组样本在无监督学习中很难获得。因此,尽管对比损失在检索方面的表现不佳,但在无监督学习和自我监督学习文献中仍普遍使用。

三元组损失

最常见的排序损失是三元组损失。它解决了对比损失的一个重要限制。如果两个点是不同的,对比损失将两个点推向相反的方向。如果其中一个点已经位于集群的中心,那么这个解决方案就不是最优的。三元组损失使用三元组而不是样本对来解决这个限制。三元组(x_i,x_j,x_k)通常被称为(锚,正样本,负样本),即(a,p,n)。三元组损失将锚和正样本拉在一起,同时将锚和负样本推离彼此。

三元组损失

与对比损失类似,三元组损失也用到了margin。max和margin m确保不同的点在距离>m的时候不会产生损失。在人脸识别、行人重识别和特征嵌入等检索应用中,三元组损失通常优于对比损失。然而,对比损失在无监督学习中仍然占主导地位。因为很难从未标记的数据中抽取有意义的三元组。三元组损失对噪声数据很敏感,因此随机负采样会影响其性能。

三元组损失的性能很大程度上依赖于三元组采样策略。因此,存在大量的三元组损失的变体。这些变体采用相同的三元组损失函数,但是具有不同的三元组抽样策略。在原始的三元组损失中,从训练数据集中随机抽取三元组样本。随机抽样的收敛速度很慢。因此,大量的论文已经研究了困难样本挖掘来寻找有用的三元组和加速收敛。接下来的段落比较了两种著名的困难样本挖掘策略:困难采样策略和半困难采样策略。

在这两种策略中,每个训练小批包含K*P个随机抽样的训练样本,每个样本来自K个类,每个类有P个样本。例如,如果训练批的大小是B=32和P=4,那么批将包含来自K=8个不同类的样本,每个类P=4个实例。现在,每个锚都有(P-1=3)可能的正样本实例和(K-1)*P=28个可能的负样本实例。

在困难采样中,只使用最远的正样本和最近的负样本。在下一个图中,n_3是锚a最近的负样本。因此,假设p是最远的正样本,损失将使用三元组(a,p,n_3)计算。这种策略的收敛速度更快,因为在训练过程中,它利用了最难的样本。然而,训练超参数(例如,学习率和批大小)需要仔细调整,以避免模式坍塌。当所有的嵌入都相同,即f(x)=0时,就会发生模式坍塌。

三元组损失元组(锚,正样本,负样本)和margin m,hard,semi-hard和easy的负样本分别用红色、青色和橙色突出显示

为了避免困难样本的训练不稳定性,半困难将每个锚点与每个正样本点配对。在下一个图中,锚点(a)将与所有五个正样本配对。对于每一个正样本,将选择一个负样本,使其离正样本较远,但在禁止范围m内。因此,对(a, p_2)将利用橙色边框内的红色负样本。这种抽样策略对模型崩溃具有更强的鲁棒性,但收敛速度比困难样本挖掘策略慢。

困难样本采样通过选择最远的正样本和最近的负样本(a, p1, n)来提升嵌入的能力。半困难样本采样选择(a, p2, n)并避免任何n位于a和p之间的元组(a, p, n)。

三元组损失的采样策略在最近的文献中得到了大量的研究。需要一篇专门的文章来涵盖所有提出的变体。前面提到的两种策略都是Tensorflow库所支持的。大多数深度学习框架都提供了对比损失和三元组损失的api。后面会讲解另外两个公认的好的排序损失:N-pairs和Angular。请继续关注!


END

英文原文:https://ahmdtaha.medium.com/retrieval-with-deep-learning-a-ranking-loss-survey-part-1-8e88a6f8e091


想要了解更多资讯,请扫描下方二维码,关注机器学习研究会

                                          


转自:AI公园

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/105993
 
264 次点击