社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

【深度学习】谷歌用「钞」能力放大招:扩展到220亿参数的巨大视觉 Transformer

机器学习初学者 • 2 年前 • 406 次点击  

作者丨科技猛兽    编辑丨极市平台

导读

 

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT- 22B,具有220亿参数。并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。 

本文目录

52 扩展到220亿参数的巨大视觉 Transformer
(来自谷歌,含 ViT 作者)
52 ViT-22B 论文解读
52.1 背景和动机
52.2 三句话概括 ViT-22B 模型的架构
52.3 ViT-22B 的实现方法:异步并行线性操作计算
52.4 数据集和超参
52.5 图像分类迁移性能
52.6 密集预测性能
52.7 ViT-22B 与人类感知的一致性

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT- 22B,具有220亿参数。并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。

52 扩展到220亿参数的巨大视觉 Transformer

论文名称:Scaling Vision Transformers to 22 Billion Parameters

论文地址:

https://arxiv.org/pdf/2302.05442.pdf

52.1 背景和动机

与自然语言处理类似,视觉预训练大模型也提高了在各种视觉任务的性能。这不仅依赖于更大的数据集,更大的可扩展的视觉架构,也同样依赖于新的训练策略。

但尽管如此,视觉模型从规模和效果上而言,远远落后于语言模型。迄今为止最大的密集视觉模型只有 4B 参数 ViT[1],而入门级的语言模型通常包含超过 10B [2][3]个参数以上,最大的密集语言模型有 540B 个参数。稀疏模型也展示出同样的趋势,其中语言模型的参数超过了一万亿[4],但最大的稀疏视觉模型只有约 15B[5]。

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT-22B,并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。ViT-22B 的质量通过分类和下游任务实验进行评估,在这些任务中它达到或提高了当前的最先进水平。

通过多模态训练一个 text tower 来匹配视觉特征,ViT- 22B 在 ImageNet 上实现了 85.9% 的 zero-shot 精度。此外,该模型是一个很好的老师——用作蒸馏目标,作者训练了一个 ViT- B 学生模型,其在 ImageNet 上达到了88.6% 的 SOTA 精度。

除了分布的改进、ViT-22 的可靠性、不确定性估计和公平性都取得了提升。更重要的是,ViT-22 的特征更好地与人类的感知保持一致,实现了之前未见过的 87% 的形状偏差 (shape bias)。

52.2 三句话概括 ViT-22B 模型的架构

ViT-22B 是一种基于 Transformer 的模型,其架构的设计类似于原始 ViT,但包含以下3个修改,以提高效率和大规模训练的稳定性。如下图1所示,用三句话概括分别是:

  • 并行设计
  • Query 和 Key 的归一化
  • 省略 bias
图1:ViT-22B 模型的架构

第1句话,把 ViT 中的 Self-Attention 层换成了一个 Self-Attention 层加一个 MLP 层的并行设计,公式如下:

注意这里作者不是简单地将两个模块相加,而是使用了一个并行化技巧,即:用于 Self-attention 中的 Query,Key,Value 计算的矩阵乘法和 MLP 的第1个线性层被融合到一个单独的操作中;用于 Self-attention 中的输出投影和 MLP 的第2个线性层也被融合到一个单独的操作中。这种方法最初是由 PaLM[6] 提出的,该技术在不降低性能的情况下将最大模型的训练速度提高了 15%。

第2句话,Self-Attention 中 Query 和 Key 的计算过程添加了归一化。在以往 ViT 扩展的工作中,作者观察到了在训练了几千步之后,training loss 发散了的情况,这毫无疑问使得我们无法训练大型 ViT 模型,尤其是在 8B 参数左右的模型中观察到这种不稳定性。

出现这种现象的原因是:注意力矩阵的值异常 (近似于 one-hot,有的地方注意力很大,其他位置几乎为零) 引起的,这导致注意力权重的熵接近于零。为了解决这个问题,作者采用[7]的方法,将 LayerNorm 应用于 Self-Attention 中 Query 和 Key 的计算过程。具体可以写成:

式中, 是 query/key 的维度, 是 layer normalization, 分别是 Query 和 Key 的权重矩阵。对 8B 参数模型的影响如图2所示,其中归一化防止了注意力矩阵的值不受控的异常而导致的训练发散。

图2:Self-Attention 中 Query 和 Key 的计算过程添加归一化对 8B 参数模型的影响

第3句话,省略 bias。遵循 PaLM 的做法,从 QKV 投影中去除 bias,并使用没有 bias 项和 centering 的 Layer Normalization[8]。这在不影响训练质量的前提下提升了训练速度。但是,与 PaLM 不同的是,作者对所有 MLP 层使用了 bias 项,因为观察到质量得到了改善,且训练速度没有下降。

ViT-22B 使用 14×14 大小的 Patch,输入图像分辨率为 224×224。类似于原始 ViT,ViT-22B 也采用了可学习的位置编码。在对高分辨率图像 (不同数量的 Patch) 进行微调期间,也对预训练的位置编码执行二维插值。ViT-22B和 ViT-G,ViT-e 的超参数对比如下图3所示。

图3:ViT-22B和 ViT-G,ViT-e 的超参数对比

ViT-22B 的 model card 如下图所示。

图4:ViT-22B 的 model card

52.3 ViT-22B 的实现方法:异步并行线性操作计算

ViT-22B 基于 JAX 框架和 FLAX,Scenic 库,它同时利用了模型和数据的并行性。作者使用了 jax.xmap 这个 API,其对所有中间体的分片 (例如权重和激活) 以及芯片间通信提供了明确的控制。

这个高效的实现如下图5和6所示。该怎么理解这个过程呢?

图5:矩阵 A 在不同设备之间按行切分
图6:矩阵 A 在不同设备之间按列切分

为了说明这个过程, 这一段我们介绍下它的原理。考虑计算 这个矩阵乘法运算。

比如说我们有 台设备, 设 分别是 的第 个 Block , 那么 分别放在第 台设备上, 如上图5和图6所示。设 。那么现在等于说我们有 个矩阵 的分块。要把这么多块均匀地放在 台 devices 上面, 就 有两种放法。为了方便说明, 这里我们假设 吧。那么现在矩阵 就被分成了 的分块。

  • 第1种: 的每一行的块放在同一台 device 上, 即: 放在同一台上面, 对应图5。

  • 第2种: 的每一列的块放在同一台 device 上, 即: 放在同一台上面, 对应图6。

对于第1种情况, 计算 需要 的通信, 因为 的维度是 , 所以就需要 float 的通信。

对于第2种情况, 计算 需要 中间值的通信, 因为 中间值的维度是 , 所以就 需要 float 的通信。

时, 采用哪一种计算方式都可以。当 时, 比如 MLP 的输出层有 , 那么这个时候采用第2种情况的计算方法。

注意, 以上通信的过程与计算过程是异步的。即, 当计算一层时, 设备可以开始通信下一层的权重, 从而最大限度地减少通信开销。

作者将芯片组织成大小为 逻辑网格,其中 是数据轴的大小, 是模型轴的大小。然后, 对于 组中的每个组, 个设备获得相同 Batch 的图像, 每个设备只保留 的激活值, 并负责计算输出

52.4 数据集和超参

ViT-22B 在 JFT的一个版本上训练,训练集包含大约 4B 图像,并采用 Sigmoid 交叉熵损失以多标签分类方式使用所有分配的标签。

ViT-22B 把图片分为 14×14 大小的块,使用 65k 的 Batch Size 训练 177k steps (3 Epochs),初始学习率,reciprocal square-root learning rate schedule,以及 10k 的线性学习率 warmup,和 30k 的的线性学习率 cooldown。上游预训练的权重衰减 head 设为3.0,body 设为 0.03。

52.5 图像分类迁移性能

Linear Probing 实验结果

作者在 ImageNet 上使用了10个周期的动量 SGD,分辨率为 224px,使用 mild random cropping 和 horizontal flipping 作为唯一的数据增强策略,而没有进一步的正则化措施。

如下图7所示是 ViT-22B 的 Linear Probing 实验结果,虽然收益不高,但在这个尺度上仍有显著的改善。而且图7还说明,,像 ViT-22B 这样的大模型的 Linear Probing 可以接近或超过具有小模型的高分辨率完全微调的性能,而 Linear Probing 通常成本更低。

图7:ImageNet 上 ViT-22B 的 Linear Probing 实验结果

作者进一步在细粒度分类数据集 iNaturalist 2017 上测试线性可分性。iNaturalist 2017 有5,089个细粒度类别,属于13个大类。与 ImageNet 不同,不同类别的图像数量是不平衡的。概念的长尾分布对分类更具挑战性。作者将 ViT- 22B 与其他 ViT 变体进行比较,并测试了 224px 和 384px 的输入分辨率。结果如图8所示,可以观察到 ViT- 22B 明显优于其他变体,特别是在标准的 224px 输入分辨率下,这表明 ViT-22B 中大量的参数对于从图像中提取详细信息是有用的。

图8:iNaturalist 2017 上 ViT-22B 的 Linear Probing 实验结果

Out-of-distribution 实验结果

定义: 在分类任务中,给定测试图片,若模型在训练阶段模型见过或类似的图片,则能正确分类;但如果与训练集完全不相关,也会被强制判定为训练集类别中的一种,这种情况是不合理的。OOD 算法希望能判断的分布状况是否与训练集一致,若一致,则称为 in-distribution (ID),否则称为 out-of-distribution (OOD)。
使用场景举例: 在 MNIST 上训练的一个分类模型,然后,输入一张“马”的图片,会被归类为数字 0~9,这是错误的。此时,MNIST 数据集就是 in-distribution,相对于 ID 而言,“马”是 out-of-distribution。

作者构建了从 JFT 到 ImageNet 的标签映射,以及从 ImageNet 到不同分布外数据集的标签映射,即 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A。ImageNet-R 和 ImageNet-A 使用相同的 ImageNet 的200个标签子空间,而 ObjectNet 有313个类别,其中只考虑与 ImageNet 标签空间重叠的113个类别。那么这样以后,JFT 上面训练出的模型的预测结果就可以转化到其他数据集上面了,这就为模型的 Out-of-distribution 能力提供了一种验证手段。

Out-of-distribution 的评估一般是首先在一个较大的数据集上面 (比如 ImageNet) 做预训练,然后在 ImageNet-R ,ImageNet-A 等数据集上直接评估其性能。实验结果如下图9所示。作者做了两种实验,其一 (图9上半部分) 是首先将 ViT-22B 模型在 JFT 数据集上做预训练,然后在 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A 数据集上评估性能。其二 (图9下半部分) 是首先将 ViT-22B 模型在 JFT 数据集上做预训练,然后在 ImageNet 数据集上微调,最后在 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A 数据集上评估性能。

图9:OOD 实验结果。标注 "ema" 的模型使用 Polyak 平均进行微调

把图9中 ObjectNet 的实验结果画出来就是图10.

图10:OOD 中 ObjectNet 实验结果

从图9和10中可以得出结论:把模型做大可以增加 Out-of-distribution 的性能。这适用于只看过 JFT 图像的 ViT-22B 模型,以及在 ImageNet 上做了微调过之后的模型。在这两种情况下,ViT-22B 在更大的模型上都延续了 OOD 性能更好的趋势。即使 ImageNet 的性能饱和,但从图10中也可以看到 ObjectNet 上的精度从 ViT-e/14 到 ViT-22B 的显著提升。

52.6 密集预测性能

密集预测任务的迁移性能也是评价一个 Backbone 模型的关键因素。作者通过语义分割和单目估计任务评价 ViT-22B 捕获的几何和空间信息的质量。

语义分割任务使用 ViT-22B 作为骨干模型,UperNet 作为分类头,在 ADE20K, Pascal Context 和 Pascal VOC 三个数据集上进行评测。如下图11所示,作者将 ViT-22B 与 DeiT-III ,ViT-G 进行了比较,且只使用了一部分的训练数据。作者使用线性解码器和端到端微调。从图11中可以观察到,当只使用少量的数据时,ViT-22B 骨干更好。例如,当只对1200张图像 (即1/16) 的 ADE20K 训练数据进行微调时,ViT-22B 达到了 44.7 mIoU 的性能,比 DeiT-III Large 提高了8.6 mIoU,比 ViT-G 提高了 2.3 mIoU。当数据量较大时,ViT-G 和 ViT-22B 的性能趋于一致。

图11:ADE20K Fewshot 语义分割实验结果

单目估计任务使用 ViT-22B 作为骨干模型,Dense Prediction Transformer (DPT) 作为单目估计头,或者仅仅使用一个简单的线性解码器作为估计头。实验在 Waymo Open real-world driving 数据集上进行评测。如下图12所示,上半部分 (DPT解码器) 的结果可以观察到,与不同的主干相比,使用 ViT-22B 骨干网络得到了最好的性能。通过将 ViT-22B 骨干与 ViT-e 进行比较,发现扩展架构可以提高性能。使用线性解码器,可再次观察到使用 ViT-22B 骨干模型可获得最佳性能。DPT 和线性解码器之间的差距表明,虽然在 ViT 特征中保留了足够的几何信息,但只有一部分可被普通的线性解码器利用到。

图12:Waymo Open dataset 单目估计实验结果

在研究规模的影响时,除了下游任务性能之外,还有一些重要的方面需要考虑。比如一个主干网络的与人类感知的一致性。

52.7 ViT-22B 与人类感知的一致性

ViT-22B 分类决策与人类分类决策的一致性如何?通过这篇论文 "Partial success in closing the gap between human and machine vision" 的方法,作者评估了在 ImageNet 上以不同分辨率 (224,384,560) 微调的三个 ViT-22B 模型。在所有指标中,ViT-22B-224 具有最高的 OOD 稳健性 (图13 (a)), ViT-22B -384 与人类分类精度最接近(图13(b)), ViT-22B-560 具有最大的错误一致性 (即大多数类似人类的错误模式,图13(d))。ViT-22B 模型在视觉模型中有最高的 shape bias 记录:而大多数模型都有很强的 texture bias (20-30% 的 shape bias + 70-80% 的 texture bias),人类的 shape bias 为 96% + 4%,而 ViT-22B -384 达到了前所未见的 87% shape bias + 13% texture bias。总体而言,ViT-22B显著改善了人类视觉物体识别的对齐。

图13:ViT-22B 与人类感知的一致性

除此之外,作者还对 ViT-22B 的公平性,鲁棒性,可靠性和校准。进行了相关实验,详细细节读者可以参考原始论文。作者发现,随着模型尺寸的增加,会出现有利的特性。

总结

本文提出了 ViT-22B,目前最大的视觉 Transformer 模型,有220亿个参数。作者证明,通过对原始架构进行三点修改,可以实现出色的硬件利用率训练稳定性,从而产生一个在几个基准上实现 SOTA 的模型。作者的评估进一步表明,与现有模型相比,ViT-22B 在形状和纹理偏差方面更符合人类,并在公平性和稳健性方面具有优势。

参考:

基于深度模型Out of Distribution(OOD)基础技术路线研究:https://blog.csdn.net/Aqrose_666/article/details/124592372

参考

  1. ^PaLI: A Jointly-Scaled Multilingual Language-Image Model
  2. ^Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
  3. ^Unifying language learning paradigms
  4. ^Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
  5. ^Scaling Vision with Sparse Mixture of Experts
  6. ^PaLM: Scaling language modeling with pathways
  7. ^Intriguing Properties of Transformer Training Instabilities
  8. ^Root Mean Square Layer Normalization

往期精彩回顾




Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/152355
 
406 次点击