Py学习  »  机器学习算法

CUDA 优化!让深度学习变得更快!

3D视觉工坊 • 2 月前 • 82 次点击  

点击下方卡片,关注「3D视觉工坊」公众号
选择星标,干货第一时间送达

来源:3D视觉工坊

「3D视觉从入门到精通」知识星球(点开有惊喜) !星球内新增20多门3D视觉系统课程、入门环境配置教程、多场顶会直播、顶会论文最新解读、3D视觉算法源码、求职招聘等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入

图片

本文由 @Simon V(https://github.com/simveit) 授权转载和翻译并发表到本公众号。原始地址为:https://veitner.bearblog.dev/making-rmsnorm-really-fast/

让RMSNorm变得更快

2025年4月18日

RMS Norm是一个在现代LLMs中常用的操作。给定一个向量,它的RMS Norm计算方式为,其中是权重,且。在这篇博文中,我们要计算矩阵中每一行的RMS Norm,其中,给定权重

顺序实现

检查我们的kernel的正确性需要一个基本的顺序实现作为参考。下面是我们使用的简单版本。

template <int numTokens, int hiddenDim>
voidlaunchRmsNormCpu(float *x, float *w, float eps, float *y){
float rms;
for (int token = 0; token < numTokens; token++) {
    rms = 0;
    for (int hidden = 0; hidden < hiddenDim; hidden++) {
      rms += x[token * hiddenDim + hidden] * x[token * hiddenDim + hidden];
    }
    rms = sqrt(rms / hiddenDim + eps);
    for (int hidden = 0; hidden < hiddenDim; hidden++) {
      y[token * hiddenDim + hidden] =
          x[token * hiddenDim + hidden] / rms * w[hidden];
    }
  }
}

如何并行化?

我们的并行化尝试非常简单。每个block处理一个token。如果block中的线程数小于隐藏维度的大小,每个线程就需要处理多个元素。然后我们执行一个简单的归约操作,计算RMS Norm并写入输出。如果你对归约操作不熟悉,请参考我之前关于归约的博文 。

Naive kernel

A naive solution in CUDA is as follows.

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelNaive(float *x, float *w, float eps, float *y){
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint bid = blockIdx.x;
float sum = 0.0f;

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    float x_ = x[bid * hiddenDim + i];
    sum += x_ * x_;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

for (int activeThreads = threadsPerBlock / 2; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

if (tid == 0) {
    rms_ = rsqrtf(squaredPerThread[tid] / hiddenDim + eps);
  }
  __syncthreads();

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    y[bid * hiddenDim + i] = x[bid * hiddenDim + i] * rms_ * w[i];
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormNaive(float *x, float *w, float eps, float *y){
  rmsNormKernelNaive
      <<>>(x, w, eps, y);
}

x 跨内存访问一次,w 跨内存访问一次,y 跨内存访问一次。对于 numTokens = 1 << 18 和 hiddenDim = 1 << 12 的情况,w 的影响可以忽略不计,我们可以按如下方式计算带宽:

constsize_t size = numTokens * hiddenDim * sizeof(float);
size_t numCrossMemoryBound = 2 * size;
float latency = time / numRounds;
float bandwidth = (numCrossMemoryBound / latency) / 1e6;

上述kernel的结果如下:

Latency = 2.84878 ms
Bandwidth = 3015.3 GB/s
% of max = 91.3727 %

使用共享内存

正如我们在上面看到的,我们频繁地访问x中的元素。我们可以使用共享内存来加快内存访问。

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelSmem(float *x, float *w, float eps, float *y){
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float xShared[hiddenDim];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint bid = blockIdx.x;

float sum = 0.0f;

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    int index = bid * hiddenDim + i;
    float x_ = x[index];
    xShared[i] = x_;
    sum += x_ * x_;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

for (int activeThreads = threadsPerBlock / 2; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

if (tid == 0) {
    rms_ = rsqrtf(squaredPerThread[tid] / hiddenDim + eps);
  }
  __syncthreads();

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    float val = xShared[i] * rms_ * w[i];
    y[bid * hiddenDim + i] = val;
  }
}

template <int numTokens,  int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormSmem(float *x, float *w, float eps, float *y){
  rmsNormKernelSmem
      <<>>(x, w, eps, y);
}

上述kernel的结果如下:

Latency = 2.82101 ms
Bandwidth = 3044.99 GB/s
% of max = 92.2723 %

使用warp

类似我们在前缀和操作中应用的技术,我们也可以这样做:

  • 在每个warp中进行归约
  • 使用一个warp归约这个数组以获得最终的归约结果。这个过程的代码如下:
#define WARP_SIZE 32

__device__ floatwarpReduce(float x){
float val = x;
for (int activeThreads = WARP_SIZE >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, activeThreads);
  }
return val;
}

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelWarp(float *x, float *w, float eps, float *y){
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float xShared[hiddenDim];
  __shared__ float sumPerWarp[WARP_SIZE];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint laneId = tid & 31;
constint warpId = tid >> 5;
constint warpsPerBlock = threadsPerBlock >> 5;

constint bid = blockIdx.x;
float sum = 0.0f;

for (int  i = tid; i < hiddenDim; i += threadsPerBlock) {
    float x_ = x[bid * hiddenDim + i];
    xShared[i] = x_;
    sum += x_ * x_;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

float warpSum = warpReduce(squaredPerThread[tid]);
if (laneId == 0) {
    sumPerWarp[warpId] = warpSum;
  }
  __syncthreads();

if (tid < WARP_SIZE) {
    sumPerWarp[tid] = warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0);
    if (tid == 0) {
      rms_ = rsqrtf(sumPerWarp[tid] / hiddenDim + eps);
    }
  }
  __syncthreads();

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    y[bid * hiddenDim + i] = xShared[i] * rms_ * w[i];
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormWarp(float *x, float *w, float eps, float *y){
  rmsNormKernelWarp
      <<>>(x, w, eps, y);
}

上述kernel的结果如下:

Latency = 2.82263 ms
Bandwidth = 3043.23 GB/s
% of max = 92.2192 %

最初我预计这个会更快,但事实并非如此。

向量化加载和存储

如果我们对上述kernel进行性能分析,可以看到内存加载和存储消耗了最多的指令。我们可以使用CUDA的float4数据类型来向量化加载和存储操作来优化这一点。

对于共享内存的方法,代码如下所示:

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelSmemFloat4(float4 *x, float4 *w, float eps,
                                        float4 *y)
{
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float4 xShared[hiddenDim >> 2];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint bid = blockIdx.x;

float sum = 0.0f;

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    int index = bid * (hiddenDim >> 2) + i;
    float4 x_ = x[index];
    xShared[i] = x_;
    sum += (x_.x * x_.x) + (x_.y * x_.y) + (x_.z * x_.z) + (x_.w * x_.w);
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

for ( int activeThreads = threadsPerBlock >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

if (tid == 0) {
    rms_ = rsqrtf(squaredPerThread[tid] / hiddenDim + eps);
  }
  __syncthreads();

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    float4 w_ = w[i];
    float4 x_ = xShared[i];
    float4 val = make_float4(x_.x * rms_ * w_.x, x_.y * rms_ * w_.y,
                             x_.z * rms_ * w_.z, x_.w * rms_ * w_.w);
    y[bid * (hiddenDim >> 2) + i] = val;
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormSmemFloat4(float *x, float *w, float eps, float *y){
  float4 *x_ = reinterpret_cast(x);
  float4 *w_ = reinterpret_cast(w);
  float4 *y_ = reinterpret_cast(y);
  rmsNormKernelSmemFloat4
      <<>>(x_, w_, eps, y_);
}

上述kernel的结果如下:

Latency = 2.80455 ms
Bandwidth = 3062.86 GB/s
% of max = 92.8139 %

类似地,我们也可以对warp kernel进行优化:

#define WARP_SIZE 32

__device__ floatwarpReduce(float x){
float val = x;
for (int activeThreads = WARP_SIZE >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, activeThreads);
  }
return val;
}

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelWarpFloat4(float4 *x, float4 *w, float eps,
                                        float4 *y)
{
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float4 xShared[hiddenDim >> 2];
  __shared__ float sumPerWarp[WARP_SIZE];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint laneId = tid & 31;
constint warpId = tid >> 5;
constint warpsPerBlock = threadsPerBlock >> 5;

constint bid = blockIdx.x;
float sum = 0.0f;

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    int index = bid * (hiddenDim >> 2) + i;
    float4 x_ = x[index];
    xShared[i] = x_;
    sum += (x_.x * x_.x) + (x_.y * x_.y) + (x_.z * x_.z) + (x_.w * x_.w);
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

float warpSum = warpReduce(squaredPerThread[tid]);
if (laneId == 0) {
    sumPerWarp[warpId] = warpSum;
  }
  __syncthreads();

if (tid < WARP_SIZE) {
    sumPerWarp[tid] = warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0);
    if (tid == 0) {
      rms_ = rsqrtf(sumPerWarp[tid] / hiddenDim + eps);
    }
  }
  __syncthreads();

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    float4 w_ = w[i];
    float4 x_ = xShared[i];
    float4 val = make_float4(x_.x * rms_ * w_.x, x_.y * rms_ * w_.y,
                             x_.z * rms_ * w_.z, x_.w * rms_ * w_.w);
    y[bid * (hiddenDim >> 2) + i] = val;
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormWarpFloat4(float *x, float *w, float eps, float *y){
  float4 *x_ = reinterpret_cast(x);
  float4 *w_ = reinterpret_cast(w);
  float4 *y_ = reinterpret_cast(y);

  rmsNormKernelWarpFloat4
      <<>>(x_, w_, eps, y_);
}

上述kernel的结果如下:

Latency = 2.80475 ms
Bandwidth = 3062.63 GB/s
% of max = 92.8071 %

结论

我们看到,如果我们了解Reduction的工作原理,实现高性能的RMSNorm操作kernel并不困难。如果你发现了进一步的优化机会,我很乐意听取你的意见。让我感到惊讶的一点是,使用#pragma unroll 并没有对性能产生积极影响。如果你喜欢这篇博文,我很乐意在LinkedIn(https://www.linkedin.com/in/simon-veitner-174a681b6/)上与你联系,交流关于CUDA或其他机器学习系统的想法。上述结果的所有复现代码都可以在我的Github(https://github.com/simveit/effective_rms_norm)上找到。

本文仅做学术分享,如有侵权,请联系删文。
图片
图片

3D视觉硬件

图片

3D视觉学习圈子

「3D视觉从入门到精通」知识星球(点开有惊喜) !星球内新增20多门3D视觉系统课程、入门环境配置教程、多场顶会直播、顶会论文最新解读、3D视觉算法源码、求职招聘等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入

图片

3D视觉全栈学习课程:www.3dcver.com

image

3D视觉交流群成立啦

图片
点这里👇关注我,记得标星哦~

一键三连「分享」、「点赞」和「在看」

3D视觉科技前沿进展日日相见 ~ 

图片

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/182080
 
82 次点击