Py学习  »  机器学习算法

ICLR 2024 | DeepZero:首个零阶优化深度学习框架发布!

极市平台 • 2 月前 • 92 次点击  
↑ 点击蓝字 关注极市平台
作者丨机器之心
来源丨机器之心
编辑丨极市平台

极市导读

 

如何扩展零阶优化使其可以训练深度学习模型? >>加入极市CV技术交流群,走在计算机视觉的最前沿

今天介绍一篇密歇根州立大学 (Michigan State University) 和劳伦斯・利弗莫尔国家实验室(Lawrence Livermore National Laboratory)的一篇关于零阶优化深度学习框架的文章 “DeepZero: Scaling up Zeroth-Order Optimization for Deep Model Training”,本文被 ICLR 2024 接收,代码已开源。

论文地址:https://arxiv.org/abs/2310.02025

项目地址:https://www.optml-group.com/posts/deepzero_iclr24

1. 背景

零阶(Zeroth-Order, ZO)优化已成为解决机器学习(Machine Learning)问题的热门技术,特别是在一阶(First-Order, FO)信息难以或无法获得的情况下:

  • 物理学和化学等学科:机器学习模型可能与复杂的模拟器或实验相互作用,其中底层系统是不可求导的。

  • 黑盒学习场景:当深度学习(Deep Learning)模型与第三方 API 集成时,如针对黑盒深度学习模型的对抗性攻击和防御,以及语言模型服务的黑盒提示学习。

  • 硬件限制:用于计算一阶梯度的原理性反向传播(backpropagation)机制在硬件系统上实现深度学习模型时可能不受支持。

然而,目前零阶优化的可扩展性仍然是一个未解决的问题:其使用主要限于相对较小规模的机器学习问题,如样本级的对抗性攻击生成。随着问题维度的增加,传统零阶方法的准确性和效率会下降。这是因为基于零阶有限差分的梯度估计是一阶梯度的有偏估算,且在高维空间中偏差更加明显。这些挑战激发了本文讨论的核心问题:如何扩展零阶优化使其可以训练深度学习模型?

2. 零阶梯度估算:RGE 还是 CGE?

零阶优化器仅通过提交输入和接收相应的函数值与目标函数进行交互。主要有两种梯度估算方法:坐标梯度估算(Coordinate Gradient Estimation, CGE)和随机梯度估算(Random Gradient Estimation, RGE),如下所示:

其中 表示对优化变量 (例如, 神经网络的模型参数)的一阶梯度的估算。

在 (RGE) 中, 表示随机扰动向量, 例如, 从标准高斯分布中抽取; 是扰动大小 (又称平滑参数) ; 是用于获得有限差分的随机方向数。

在 (CGE) 中, 表示标准基向量, 提供了 在对应坐标的偏导数的有限差分估计。

与 CGE 相比,RGE 具有可以减少函数评估次数的灵活性。尽管查询效率高,但 RGE 在从头开始训练深度模型时是否能提供令人满意的准确性仍不确定。为此,我们进行了调查,其中我们使用 RGE 和 CGE 对不同大小的小型卷积神经网络(CNN)在 CIFAR-10 上进行了训练。如下图所示,CGE 可以实现与一阶优化训练相当的测试精度,并显著优于 RGE,同时也比 RGE 具有更高的时间效率。

基于 CGE 在准确性和计算效率方面相对于 RGE 的优势,我们选择 CGE 作为首选的零阶梯度估计器。然而,CGE 的查询复杂性仍然是一个瓶颈,因为它随模型大小增加而扩大。

3. 零阶深度学习框架:DeepZero

据我们所知,之前的工作没有展示出 ZO 优化在训练深度神经网络(DNN)时不会显著降低性能的有效性。为了克服这一障碍,我们开发了 DeepZero,一种原理性零阶优化深度学习框架,可以将零阶优化扩展到从头开始的神经网络训练。

a) 零阶模型修剪(ZO-GraSP):一个随机初始化的密集神经网络往往包含一个高质量的稀疏子网络。然而,大多数有效的修剪方法都包含模型训练作为中间步骤。因此,它们不适合通过零阶优化找到稀疏性。为了解决上述挑战,我们受到了无需训练的修剪方法的启发,称为初始化修剪。在这类方法中,梯度信号保留(GraSP)被选用,它是一种通过随机初始化网络的梯度流识别神经网络的稀疏性先验的方法。

b) 稀疏梯度:为了保留训练密集模型的准确性优势,在 CGE 中我们结合了梯度稀疏性而不是权重稀疏性。这确保了我们在权重空间中训练一个密集模型,而不是训练一个稀疏模型。具体而言,我们利用 ZO-GraSP 确定可以捕获 DNN 可压缩性的逐层修剪比率(Layer-wise Pruning Ratios, LPRs),然后零阶优化可以通过不断迭代更新部分模型参数权重来训练密集模型,其中稀疏梯度比率由 LPRs 确定。

c) 特征重用:由于 CGE 逐元素扰动每个参数,它可以重用紧接扰动层之前的特征,并执行剩余的前向传播操作,而不是从输入层开始。从经验上看,带有特征重用的 CGE 在训练时间上可以实现 2 倍以上的减少。

d) 前传并行化:CGE 支持模型训练的并行化。这种解耦特性使得通过分布式机器扩展前向传播成为可能,从而显著提高零阶训练速度。

4. 实验分析

a) 图像分类

在 CIFAR-10 数据集上,我们将 DeepZero 训练的 ResNet-20 与两种通过一阶优化训练的变体进行比较:

(1)通过一阶优化训练获得的密集 ResNet-20

(2)通过一阶优化训练通过 FO-GraSP 获得的稀疏 ResNet-20

如下图所示,尽管在 80% 至 99% 的稀疏区间中,与(1)相比,使用 DeepZero 训练的模型仍存在准确度差距。这突出了 ZO 优化用于深度模型训练的挑战,其中高稀疏度的实现是被期望的。值得注意的是,在 90% 至 99% 的稀疏区间中,DeepZero 优于(2),展示了 DeepZero 中梯度稀疏性相对于权重稀疏性的优越性

b) 黑箱防御

当模型的所有者不愿意与防御者共享模型细节时,会出现黑盒防御问题。这对于使用一阶优化训练直接增强白盒模型的现有鲁棒性增强算法构成了挑战。为了克服这一挑战,ZO-AE-DS 被提出,在白盒去噪平滑(Denoised Smoothing, DS)防御操作和黑盒图像分类器之间引入了自动编码器(AutoEncoder, AE),以解决 ZO 训练的维度挑战。ZO-AE-DS 的缺点是难以扩展到高分辨率数据集(例如,ImageNet),因为使用 AE 会损害输入到黑盒图像分类器的图像的保真度,并导致较差的防御性能。 相比之下,DeepZero 可以直接学习与黑盒分类器集成的防御操作,无需自动编码器。如下表所示,就认证准确率(Certified Accuracy, CA)而言 DeepZero 在所有输入扰动半径上始终优于 ZO-AE-DS。

c) 与仿真耦合的深度学习

数值方法在提供物理信息模拟方面不可或缺,但它们自身存在挑战:离散化不可避免地产生数值误差。通过与迭代偏微分方程(Partial Differential Equation, PDE)求解器的循环交互训练纠正神经网络的可行性,被称为” 求解器环路”(Solver-in-the-Loop, SOL)。虽然现有工作专注于使用或开发可微模拟器进行模型训练,我们通过利用 DeepZero 扩展了 SOL,使其能够与不可微或黑盒模拟器一起使用。 下表比较了 ZO-SOL(通过 DeepZero 实现)与三种不同的可微方法的测试误差纠正性能:

(1) SRC(低保真模拟无误差纠正);

(2) NON(非交互式训练,使用预生成的低和高保真模拟数据在模拟循环外进行);

(3) FO-SOL(给定可微模拟器时,用于 SOL 的一阶训练)。

每个测试模拟的误差计算为与高保真模拟相比的纠正模拟的平均绝对误(MAE)。结果表明,通过 DeepZero 实现的 ZO-SOL 在只有基于查询的模拟器访问权限的情况下依然优于 SRC 和 NON,并缩小了与 FO-SOL 的性能差距。与 NON 相比,ZO-SOL 的表现突显了在有黑盒模拟器集成时的 ZO-SOL 前景。

5. 总结与讨论

这篇论文介绍了一个深度网络训练中零阶优化深度学习框架 (DeepZero)。具体来说,DeepZero 将坐标梯度估计、零阶模型修剪带来的梯度稀疏性、特征重用以及前传并行化整合到统一的训练流程中。利用这些创新,DeepZero 在包括图像分类任务和各种实际黑箱深度学习场景中表现出了效率和有效性。此外,还探索了 DeepZero 在其他领域的适用性,如涉及不可微物理实体的应用,以及在计算图和反向传播的计算不被支持的设备上进行训练。

作者介绍

张益萌,密歇根州⽴⼤学 OPTML 实验室, 计算机博士在读, 研究兴趣⽅向包括 Generative AI,  Multi-Modality,  Computer Vision,  Safe AI,  Efficient AI。

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/166774
 
92 次点击