2025年4月,飞桨框架迎来重大更新,发布3.0正式版,为开发者带来了一系列新的功能体验升级。其中神经网络编译器 CINN(Compiler Infrastructure for Neural Networks)作为飞桨框架3.0版本中重要的新特性之一,为开发者在深度学习模型性能优化上提供了“低成本高回报”的新选择。我们在 PaddleX 开发套件里选取了超过60个模型进行测试,使用 CINN 编译器后超60%模型有显著性能提升,提升范围集中在10%~40%之间,重点模型相比 PyTorch 开启编译优化后的版本也有一定性能优势。特别是在科学计算场景,使用飞桨CINN编译器性能提升效果更加显著,在 Modulus 系列模型上,飞桨使用CINN编译器相比 Pytorch 求解速度平均提升115%。本文我们将分享 CINN 编译器开发落地过程中遇到的一些技术问题,以及我们对应的思考和解决方案,希望能对读者有所启发,也欢迎感兴趣的读者与我们进行沟通探讨,共同进步。深度学习编译器并不算一个大家所熟知的领域,所以我们先对深度学习编译器的概念及其性能优化的原理进行简单介绍。
什么是深度学习编译器?深度学习编译器是一种专门为深度学习模型优化和部署而设计的工具,用来提高模型的计算效率、降低内存占用、加速训练推理过程。其功能是将高层次的深度学习模型转换为低层次的、高效的、底层硬件可执行的代码,简单来说,就是帮用户自动生成高效的硬件计算 Kernel,这里我们用一个具体的例子来理解:以常见的 RMSNorm 计算为例,其计算公式为:
class RMSNorm(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.variance_epsilon = 1e-6
self.size = 768
self.weight = paddle.create_parameter(
shape=[self.size], dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0))
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = paddle.rsqrt(variance + self.variance_epsilon) * x
return x * self.weight
这段代码在运行时会分别执行 mean、pow、rsqrt、+、*等对应的计算 Kernel,由于存在大量的访存读写操作, 导致性能较差。如果使用深度学习编译器进行优化,编译器就可以将这些计算操作合并,自动生成一个融合的计算 Kernel(如CUDA Kernel),省去原先大量的访存读写,性能得到显著提升(在 A100 GPU 环境中,上述子图使用编译器优化可取得 3 倍左右的性能提升)。当然,这个融合的计算 Kernel 也可以由开发人员来实现,只不过这项工作有一定的门槛和投入成本:需要开发者熟悉硬件优化开发(如CUDA),还要花费时间调试、验证代码的性能和精度。相比之下,使用深度学习编译器时,开发者仅需要在上述代码中添加一行代码即可。以上述代码为例,我们仅需在 forward 函数上添加@paddle.jit.to_static(backend="CINN")装饰器代码,运行时将自动使用 CINN 编译器进行优化。
class RMSNorm(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.variance_epsilon = 1e-6
self.size = 768
self.weight = paddle.create_parameter(
shape=[self.size], dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0))
@paddle.jit.to_static(backend="CINN")
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = paddle.rsqrt(variance + self.variance_epsilon) * x
return x * self.weight
注:nightly 版本中已将 CINN 置为默认 backend, 使用的装饰器代码可简化为@paddle.jit.to_static深度学习编译器能够通过一系列优化策略,自动化在硬件上完成计算优化,以达到充分利用硬件资源,提升模型性能的目的。其中的优化策略主要可以分为以下两个层面:- 图级优化(Graph-level Optimization):通过分析整个神经网络的计算图,进行算子融合(Operator Fusion)、常量折叠、死代码消除等优化,避免不必要的计算和数据拷贝,减少中间结果的访存开销。
- 算子级优化(Operator-level Optimization):针对特定硬件自动搜索最优的算子实现参数,如循环分块大小、并行策略等,并根据硬件特性和指令(如向量化访存、Tensor Core 等)生成高效的底层代码。
https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/paddle_v3_features/cinn_cn.html
如何保证编译器的正确性
模型结果的正确性是深度学习框架的基石,如何保证编译器生成的 Kernel 代码正确,且精度误差满足模型要求,是深度学习编译器开发中的关键问题。以下是我们在 CINN 编译器开发过程中,针对此问题的一些思考过程和解决方案。神经网络模型中的算子可分为计算密集型(如卷积、矩阵乘法)和访存密集型(如元素级操作、数据重排)两大类。深度学习编译器虽可优化两类算子,但合并优化会显著增加问题复杂度,导致状态空间过大而难以求解。鉴于此,我们有必要对问题的规模加以缩减,以在可控范围内实现性能提升。具体而言,计算密集型算子已有高性能计算库(如BLAS、cuDNN)支持,编译器在此基础上提升性能的难度较大。而访存密集型算子在高性能库中的优化相对不足,且在深度学习模型中占比较高,对整体性能影响显著。因此,我们决定优先优化访存密集型算子,并进一步聚焦于其中的归约(Reduce)操作。- 元素级操作:对张量每个元素执行相同数学运算(如加法、乘法、指数、对数),因需逐个处理大量数据,内存访问频繁,是典型的访存密集型操作。
- 数据重排和索引操作:包括 reshape、transpose、slice 等,用于调整数据形状、布局及局部拷贝。虽然不涉及复杂数学计算,但数据搬移和内存访问模式的不规则性可能导致性能下降。
- 归约操作(Reduce):对张量的一个或多个维度上的元素进行累积计算(如求和、最大值、最小值)。这类操作常见且对模型性能影响显著,但优化难度较高,常涉及内存数据同步问题。
- 其他操作:例如填充、排序、散列等,虽然不如前三类常见,但在特定应用中可能至关重要。
鉴于归约操作(Reduce)的性能调优和精度保证难度较大,我们将工作重点集中于此,
通过限定问题范围,追求专而精,避免大而全,在系统层面降低了引入问题的概率。在此目标下,我们设计了以 Reduce 为核心的算子融合及 Kernel 优化体系,在大量模型上取得了显著性能提升,同时有效保障了正确性。当然,限定问题范围并不意味着万事大吉。在具体方案设计和落地过程中,仍会遇到诸多问题和困难。但通过提前聚焦和限定范围,我们已有效避免了大量潜在问题。接下来,我们将分享一个实际开发中遇到的精度问题及其解决过程。在保障计算结果数值稳定性方面,我们投入了大量的精力开展设计工作,涵盖减少处理流程中的随机性因素、优化 Reduce 计算策略等多个维度。在本节内容中,我们将聚焦于 Norm 类计算过程中遇到的一个典型精度问题展开详细介绍。在模型精度验证环节,我们留意到,诸如 MobileNet 等采用了 BatchNorm 层的模型,在经过编译器优化处理之后,出现了 loss 值为 nan 的异常情况。经过深入排查与定位,最终确定问题根源在于 BatchNorm 层中的方差计算存在错误。在深度学习领域,方差计算是归一化(Normalization)算子的基础,涉及 LayerNorm、BatchNorm 等算子,其数值稳定性直接影响模型训练效果。传统方差计算公式为:μ = E(X),Var(X) = E[(X - μ)²]由该公式可以得出一种朴素的 Two-Pass 算法,即先遍历一次输入得到均值E(X),再遍历一次输入得到方差。该方法具有极高的精度,然而由于其需要遍历两次输入,具有更大的访存成本。对于 LayerNorm、BatchNorm 等归一化维度较大的算子,Two-Pass 算法可能至多增加50%的用时,严重影响性能。为了解决 Two-Pass 算法访存两次的问题,我们可以对方差公式进行变换,得到如下形式:由此可以在一次遍历中同时计算E(X)和E(X²),节约了一趟访存成本;这就是所谓的One-Pass算法。One-Pass 算法虽然性能较高,但实际计算中会由于大数相减产生较大的精度损失,甚至因为精度波动过大,出现 E(X²) - E(X)² 结果为负数的情况,导致后续计算开平方得到nan值。大数相减导致的精度损失,本质上源于计算机浮点表示法“有效数字”有限的特性。当两个相近的大数相减时,其差值的有效数字位数相对于大数本身也会很小,因此,差值计算的结果精度远低于大数本身。另外,计算 E(X²) 和 E(X) 的过程都涉及归约操作,归约操作本身也有不可避免的误差;当 E(X²) 和 E(X) 的归约误差恰好往不同方向偏移,达到一定阈值时,就可能出现 E(X²) < E(X)² 这样违反数学理论的情况。为了解决上述问题,Welford 算法提供了一种只需单趟遍历且保持数值稳定性的方差计算方法。其基本递推公式如下:Mₖ = Mₖ₋₁ + (xₖ - Mₖ₋₁) / kSₖ = Sₖ₋₁ + (xₖ - Mₖ₋₁) * (xₖ - Mₖ)Welford 算法能保持数据稳定的关键在于,其通过维护累计均值Mₖ,在每一步减去差值,避免大数相减,确保中间变量数值稳定。下图展示了在 float 数据类型下,One-Pass 与 Welford 的方差计算的误差量,以 double 类型的 Two-Pass 为金标准。输入数据为带有一定整体偏移量(bias)的正态分布数据。我们在测试中发现,One-Pass 的误差主要和数据集的 bias 有关,而与数据集的大小关系不大。Welfold 算法可以很好地保持方差计算的精度,在任何情况下误差都小于1e-6,确保了数值稳定性。One-Pass与Welford的方差计算误差统计Welford 算法从定义上可以被视为一个归约运算,我们在实现时也希望复用 Reduce 算子的处理流程。然而,Welford 算法的复杂性也对实现带来了挑战:Welford 算法需考虑串行与并行归约:Welford 算法的基本递推公式只给出了串行归约的定义。然而,对于 GPU 并行计算的场景,必须同时考虑串行归约与并行归约,才能实现高性能的计算。因此,需要针对并行计算改造递推公式,实现并行版本的归约操作。针对此问题,我们根据 GPU 进行归约计算的特点,将 Welford 算法的计算过程分为3步:- 第一步:串行归约,由每个线程归约一部分数据,根据 Welford 基本递推公式计算即可。
- 第二步:并行归约,将各个线程的结果汇总起来,需要使用专用的并行归约公式,如下所示:
假设两个线程已经分别归约了 m 和 n 个元素,分别得到了 Mₘ、Sₘ 和 Mₙ、Sₙ,则中间状态合并的公式为:Mₘ₊ₙ = (Mₘ * m + Mₙ * n) / (m + n)Sₘ₊ₙ = Sₘ + Sₙ + (Mₘ - Mₙ)² * m * n / (m + n)- 第三步:输出结果,即计算方差的值,根据公式 Sₙ / (N - 1) 计算即可。
在实现 Welford 算法后,我们发现其性能与 One-Pass 有较大的差距。经过分析发现,这主要是因为递推公式中的除法在硬件上具有较高的执行开销。上述递推公式中有2处除法(在GPU上,除法是通过先对除数求倒数,再乘以被除数来实现,因此这里主要关注除数),它们分别为:其中,k从1开始递增,直到处理完当前线程负责的所有数据;m和n的大小则无固定范围,因为每个线程归约的元素个数不确定。考虑到串行归约在整个计算过程中占大部分用时,我们重点对k的倒数计算进行优化。我们尝试了如下三种方法:- 对 1 / k 进行打表:即在全局内存上分配一个数组 rcp,其中 rcp[k] = 1 / k。该方法的优点是无需进行运算,效率极高;缺点是数组大小需要随 k 而增长,在不知道数据量的情况下,难以事先分配一个足够大的数组。
- 每个Warp共享计算结果:即每32个线程首先分别计算 [1 / k, 1 / (k + 1), ..., 1 / (k + 31)] 的值,然后通过 Shared Memory 进行共享,在接下来的32次归约中都不需要再重新计算。该方法的优点是实现了动态处理任意大小的 k 值,缺点是增加了 Shared Memory 占用和同步原语开销。
- 使用近似倒数指令:CUDA 提供了一条 rcp.approx.f32 指令,比一般的除法快很多;由于k是正整数,经过验证,该指令在k<内的相对误差不超过1e-7。该方法的优点是没有额外的存储和同步开销,缺点有一定精度损失;考虑到方差计算的结果一般用于归一化算子,会被加上一个1e-5的eps再开根号,1e-7级别的误差的影响实际上非常有限。
我们以 shape=[128, 256, 56, 56] 的 BatchNorm 为例,对比了几种算法与除法优化的用时,如下图所示:经过评估,我们选择了近似倒数指令的优化,因为其实现难度与额外开销均较小,且精度误差在可接受范围内。最终,我们实现了精度误差小于1e-6且性能损失相对 One-Pass 不超过1%的 Welford 算法。深度学习编译器生成的 Kernel 代码之所以高效,是由于添加了大量的性能优化策略,如果缺少这些优化策略,那生成的 Kernel 可能会“慢的要死”。下面我们以两个具体的性能优化策略作为例子,来展现编译器优化策略的实现过程,这些优化点在手工开发高性能 Kernel 时同样具有参考价值。在 GPU 编程中,大规模数据归约(Reduce)操作的并行化是一大挑战。GPU 虽天生支持多线程并行,但归约操作本质上是串行的,导致算法与硬件结构存在冲突。一个简单方案是将数据分块,每个线程处理一块,并通过同步指令汇总。然而,GPU 的同步能力有限,传统方法仅支持Block 内(最多1024线程)同步,无法跨 Block 同步。即便 Block 间可通过全局内存传递数据,但由于无法确定 Block 完成状态,数据可靠性存疑。因此,为确保同步正确,传统方法只能用一个 Block 进行归约,仅利用 GPU 计算能力的一小部分(通常只用到数十到上百个流处理器单元SM中的一个),极大限制了归约操作的并行性。为解决并行度问题,层次化(Hierarchical)归约方法应用而生。该方法需调用两次 Reduce Kernel。第一次将数据分成多个部分(通常为SM数量的2到8倍),分配更多 Block 进行归约,充分发挥 GPU 的并行能力。第一次归约后,剩余数据量通常较小(KB级别),此时再调用一个 Reduce Kernel进行归约,由于数据量小,对并行度要求降低。层次化归约本质上是利用了 GPU 在两次 Kernel 调用实例之间强制同步的能力,实现了多个 Block 的同步,从而大幅提升并行度。下图展示了传统的归约和层次化归约的区别。然而,层次化归约虽然适用于手写 Kernel,但对编译器场景并不友好,原因主要有三点:- 算子融合处理:编译器进行算子融合时,Reduce 算子前后通常融合多个算子。若对 Reduce 算子分层,前后融合算子可能被分割,存在数据依赖时会给编译器融合策略带来巨大挑战。
- 动态维度适配:动态维度场景下,编译器编译时无法预知数据维度,难以确定是否应用层次化归约。对小规模数据,层次化归约可能适得其反,错误使用会导致性能下降。
- 性能提升受限:从性能角度看,调用两个 Kernel 会增加硬件调度开销。而编译器优势在于算子融合能力,若拆分为两个 Kernel,则无法充分展现其性能优势。
受算子库启发,我们发现可以通过原子计数器(Atomic Counter)将层次化归约的两个 Kernel 合并为一个,在不增加 Kernel 数量的情况下,既能利用层次化归约的高并行度,又保持编译器的简洁性。这种方法基本上遵循层次化归约的步骤,但无需通过两个 Kernel 同步。具体为:设定初始值为0的原子计数器,每个 Block 先对其负责的数据归约,并将结果写入全局内存的中间结果区域,然后计数器执行加1操作。由于原子操作的特性,计数器加1时会获取计数器的当前值,当某 Block 发现计数器的值等于 Block 总数时,表明它是最后一个完成归约的 Block。此时所有 Block 结果已写入全局内存,该 Block 可以完成最终的归约操作。
我们将这种通过原子计数器进行全局归约的方法称为“Grid Reduce”,因其归约层次高于 Block 级,在整个 Grid 级归约。Grid Reduce在架构兼容性和实现灵活性方面具有显著优势,能够在对CINN系统进行最小改动的同时满足性能需求,具体体现在以下几个方面:- Grid Reduce 无需拆分融合算子,可视为 CINN 后端独立优化行为,无需修改现有融合策略,避免策略变动影响架构稳定性。
- Grid Reduce 与传统 Block Reduce 相似,仅在 Block Reduce 基础上增加一层,而非重构算法,可在现有处理流程中扩展,可以在很大程度上复用 Block Reduce 逻辑。
- Grid Reduce 能灵活处理动态维度,作为 Block Reduce 后的额外层次,可根据需要启用或关闭,处理小规模数据时关闭即可避免对性能的影响。
Grid Reduce对Reduce相关融合算子的性能提升显著,下图展示了一些典型算子的性能提升情况。可以发现,Grid Reduce在全归约场景和Reduce维较大、非Reduce维较小的场景有较大的性能提升。Grid Reduce 与 传统 Reduce 性能对比虽然Grid Reduce给CINN带来了显著的性能提升,但使用一段时间后,我们发现现有方案的一些局限性:- 融合操作受限:基于原子计数器的方案引入强约束,仅最后一个完成归约的 Block 能执行收尾操作,无法将数据广播到所有 Block。因其他 Block 无法得知归约结束时间,也就无法获取最终结果,导致当前方案仅支持 Reduce 后接 Elementwise 融合,不支持 Reduce 后接 Broadcast 融合,限制了适用范围。
- 存在额外开销:原子计数器在每次执行 Kernel 前,需通过 cudaMemset 调用重置为 0,而 cudaMemset 本质是 Kernel 调用,这实际上仍涉及两次 Kernel 调用,增加了开销。
- 配置调优困难:Grid Reduce 在运行时配置上有问题。通过分配多个 Block 实现并行归约,增加 Block 数虽能提高并行性,但也会增加同步成本和显存使用,无法无限增加。而且,Block 数的选择受输入数据维度大小限制,需在特定区间(通常是某个数的倍数)内选择。因此,找到最合适的 Block 数是影响 Grid Reduce 性能的关键。
为解决前两个问题,我们使用了 Cooperative Groups 提供的原生全局 Block 同步功能。通过使用 cooperative_groups::this_grid().sync()方法,可以直接对所有 Block 进行同步,替代原子计数器。这不仅消除了原子计数器分配空间的需求,还允许每个 Block 在同步后访问最终结果,从而可以支持 Reduce 后接 Broadcast 操作。然而,Cooperative Groups 对 Block 数量有限制,要求所有 Block 同时在线,不能超过SM数与每个SM可同时运行Block数的乘积,这对 Block 分配策略提出了新的要求。针对问题3和 Cooperative Groups 的 Block 数量限制,我们提出了“Thread-Block 联合分配”的运行时配置优化策略,通过减少每个 Block 的线程数以增加可分配的 Block 数,实现更高的硬件资源利用率。以多 Batch 的 Reduce 操作 [64, 131072] => [64] 为例,假设使用V100 GPU(80个SM)。因 Batch 数(64)的限制,Block 数必须是64的整数倍。若每个Block用满1024个线程,则 Block 数不能超过SM个数(80),因此只能分配64个 Block,SM利用率为64/80=80%,未达完全利用。我们发现,通过减少每个 Block 的线程数,可以提升Block数上限。例如,将每个 Block 设置为512个线程,Block 数上限可变为SM数的2倍(160);若设置为256个线程,Block 数上限进一步增至SM数的4倍(320),由于320是64的倍数,此时恰好可以完全分配64×5=320个 Block,实现100%的SM利用率。基于此思路,我们引入了 Block 大小缩放系数F,尝试将 Block 大小缩放1~4倍,以选择SM利用率最高的 Block 数。理论上,该方法可确保任意 Batch 大小的 Reduce 操作实现80%以上的SM利用率,最大化 Grid Reduce 的性能。
在CINN神经网络编译器的调度优化进程中,内存访问安全与计算效率的平衡始终是核心挑战。为确保内存访问的绝对安全性,普遍采用int64类型进行全流程索引计算,然而这种保守策略在GPU硬件架构存在性能瓶颈——int64整数运算的延迟较int32类型显著增加,直接的制约最终生成计算内核的执行效率。因此我们希望安全且正确的将部分 int64 运算转换为 int32,以优化最终生成 kernel 的性能。是否发生越界是能否使用 int32 进行计算的核心,因此最朴素的想法是我们在进行每一次运算时进行类型及越界检查。如果操作数为 int64 类型,并且检查操作数及结果均未越界,则安全转换为 int32 运算。以i + 5为例,其中i范围为[0,32),类型转换流程图如下:上述方法存在着明显问题:每一次计算时都需要进行2次数据类型检查,且需要做3次越界检查。针对这个问题我们可以规定出访存索引部分。前期所有计算均采用 int64 保存,最后通过一次性遍历索引部分表达式进行类型转换,由于处理的都是下标索引部分(均为整数运算),因此无需类型检查。至于能否直接获取是否越界的信息,来避免多次越界检查呢?我们可以观察生成 Kernel 中的循环范围,下面通过两段伪代码展示:z = Add(x, y)加法运算 Kernel 如下,我们明显可以看出,最大访存范围即输出元素总数
for(int i = 0; i 32; ++i){
z[i] = x[i] + y[i];
}
z = Slice(x, start=5, len=8)切片运算 Kernel 如下,我们明显可以看出,最大的访存范围小于输入元素总数
for(int i = 0; i 8; ++i){
z[i] = x[i + INT32_MAX]
}
这里可以观察到 Kernel 内的访存索引范围依赖于输入输出 Tensor。因此这里我们可以直接一次性的判断输入输出 Tensor 的元素总数量(element_size)是否超过 INT32_MAX,来判断 Kernel 中访存索引是否存在越界情况。对未越界情况,表达式内所有元素均可以转化为 Int32 表示。经过我们越界检查升级优化后,我们的类型转换流程图如下:由于动态 shape 的引入,输入输出 Tensor shape 中动态维度可能使用符号表示(S0、S1),这些可变的符号导致编译优化时无法获取准确的元素总数用来判断是否发生越界,为了解决这个问题,我们可以简单的升级原有的化简逻辑:在存在动态 shape 的情况下,分别编译生成 int64、int32 两个版本的代码,并使用含有动态符号的元素总数作为分支判断的条件。在运行时选择执行相应的分支即可,最终的类型转换流程图如下:
通过在大量模型上进行验证,最终版本的 int64 至 int32 类型转换优化在部分计算 Kernel 上性能提升效果显著,模型层面整体性能平均提升超过 5%。结语
本文主要介绍了 CINN 编译器在模型性能加速方面取得的一些成果,以及我们在设计与开发落地过程中碰到的几个典型问题,和相应的应对思路与解决办法。在这一过程中,我们可谓“摸爬滚打”,积累了不少经验,也踩了不少“坑”。由于篇幅有限,难以将所有经历一一呈现,这里先分享几个我们认为对读者可能有借鉴意义的问题解决过程。总结起来,解决问题的核心思路其实都较为朴素,属于稍加思索就能想到的方向。然而,在具体实施层面,却蕴含着大量细致入微的设计考量。深度学习编译器的设计和优化是一个复杂而长期的过程,需要不断迭代和创新。通过聚焦关键算子(如Reduce)的性能调优和正确性保证,结合实际的开发经验和挑战,我们逐步构建起一个高效、可靠、灵活的编译体系。未来,随着深度学习技术的不断发展和硬件平台的持续演进,深度学习编译器将扮演更加重要的角色。我们期待在这个充满挑战和机遇的领域中不断探索和前行。倘若读者对这些内容感兴趣,欢迎随时与我们交流探讨!
直播课程:6月16日(周一)19:00,技术解析加代码实战,一线研发大佬为大家详细解析飞桨新一代框架3.0的 CINN 编译器,快速上手深度学习模型低成本性能优化“利器”!测评征集:飞桨框架3.0正式版现已全面开放,诚邀广大用户体验使用!在技术网站发布本人真实的测评报告/使用tips/实际场景应用实例等经验帖,并提交到官方(详情请见直播课程及官方社群),通过验收的高质量测评可获得最高千元激励金!立即体验:访问飞桨官网:CINN 神经网络编译器(下方链接),开启您的性能优化之旅!https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/paddle_v3_features/cinn_cn.html