社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

CUDA 优化!让深度学习变得更快!

3D视觉工坊 • 4 月前 • 128 次点击  

点击下方卡片,关注「3D视觉工坊」公众号
选择星标,干货第一时间送达

来源:3D视觉工坊

「3D视觉从入门到精通」知识星球(点开有惊喜) !星球内新增20多门3D视觉系统课程、入门环境配置教程、多场顶会直播、顶会论文最新解读、3D视觉算法源码、求职招聘等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入

图片

本文由 @Simon V(https://github.com/simveit) 授权转载和翻译并发表到本公众号。原始地址为:https://veitner.bearblog.dev/making-rmsnorm-really-fast/

让RMSNorm变得更快

2025年4月18日

RMS Norm是一个在现代LLMs中常用的操作。给定一个向量,它的RMS Norm计算方式为,其中是权重,且。在这篇博文中,我们要计算矩阵中每一行的RMS Norm,其中,给定权重

顺序实现

检查我们的kernel的正确性需要一个基本的顺序实现作为参考。下面是我们使用的简单版本。

template <int numTokens, int hiddenDim>
voidlaunchRmsNormCpu(float *x, float *w, float eps, float *y){
float rms;
for (int token = 0; token < numTokens; token++) {
    rms = 0;
    for (int hidden = 0; hidden < hiddenDim; hidden++) {
      rms += x[token * hiddenDim + hidden] * x[token * hiddenDim + hidden];
    }
    rms = sqrt(rms / hiddenDim + eps);
    for (int hidden = 0; hidden < hiddenDim; hidden++) {
      y[token * hiddenDim + hidden] =
          x[token * hiddenDim + hidden] / rms * w[hidden];
    }
  }
}

如何并行化?

我们的并行化尝试非常简单。每个block处理一个token。如果block中的线程数小于隐藏维度的大小,每个线程就需要处理多个元素。然后我们执行一个简单的归约操作,计算RMS Norm并写入输出。如果你对归约操作不熟悉,请参考我之前关于归约的博文 。

Naive kernel

A naive solution in CUDA is as follows.

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelNaive(float *x, float *w, float eps, float *y){
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint bid = blockIdx.x;
float sum = 0.0f;

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    float x_ = x[bid * hiddenDim + i];
    sum += x_ * x_;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

for (int activeThreads = threadsPerBlock / 2; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

if (tid == 0) {
    rms_ = rsqrtf(squaredPerThread[tid] / hiddenDim + eps);
  }
  __syncthreads();

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    y[bid * hiddenDim + i] = x[bid * hiddenDim + i] * rms_ * w[i];
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormNaive(float *x, float *w, float eps, float *y){
  rmsNormKernelNaive
      <<>>(x, w, eps, y);
}

x 跨内存访问一次,w 跨内存访问一次,y 跨内存访问一次。对于 numTokens = 1 << 18 和 hiddenDim = 1 << 12 的情况,w 的影响可以忽略不计,我们可以按如下方式计算带宽:

constsize_t size = numTokens * hiddenDim * sizeof(float);
size_t numCrossMemoryBound = 2 * size;
float latency = time / numRounds;
float bandwidth = (numCrossMemoryBound / latency) / 1e6;

上述kernel的结果如下:

Latency = 2.84878 ms
Bandwidth = 3015.3 GB/s
% of max = 91.3727 %

使用共享内存

正如我们在上面看到的,我们频繁地访问x中的元素。我们可以使用共享内存来加快内存访问。

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelSmem(float *x, float *w, float eps, float *y){
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float xShared[hiddenDim];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint bid = blockIdx.x;

float sum = 0.0f;

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    int index = bid * hiddenDim + i;
    float x_ = x[index];
    xShared[i] = x_;
    sum += x_ * x_;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

for (int activeThreads = threadsPerBlock / 2; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

if (tid == 0) {
    rms_ = rsqrtf(squaredPerThread[tid] / hiddenDim + eps);
  }
  __syncthreads();

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    float val = xShared[i] * rms_ * w[i];
    y[bid * hiddenDim + i] = val;
  }
}

template <int numTokens,  int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormSmem(float *x, float *w, float eps, float *y){
  rmsNormKernelSmem
      <<>>(x, w, eps, y);
}

上述kernel的结果如下:

Latency = 2.82101 ms
Bandwidth = 3044.99 GB/s
% of max = 92.2723 %

使用warp

类似我们在前缀和操作中应用的技术,我们也可以这样做:

  • 在每个warp中进行归约
  • 使用一个warp归约这个数组以获得最终的归约结果。这个过程的代码如下:
#define WARP_SIZE 32

__device__ floatwarpReduce(float x){
float val = x;
for (int activeThreads = WARP_SIZE >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, activeThreads);
  }
return val;
}

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelWarp(float *x, float *w, float eps, float *y){
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float xShared[hiddenDim];
  __shared__ float sumPerWarp[WARP_SIZE];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint laneId = tid & 31;
constint warpId = tid >> 5;
constint warpsPerBlock = threadsPerBlock >> 5;

constint bid = blockIdx.x;
float sum = 0.0f;

for (int  i = tid; i < hiddenDim; i += threadsPerBlock) {
    float x_ = x[bid * hiddenDim + i];
    xShared[i] = x_;
    sum += x_ * x_;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

float warpSum = warpReduce(squaredPerThread[tid]);
if (laneId == 0) {
    sumPerWarp[warpId] = warpSum;
  }
  __syncthreads();

if (tid < WARP_SIZE) {
    sumPerWarp[tid] = warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0);
    if (tid == 0) {
      rms_ = rsqrtf(sumPerWarp[tid] / hiddenDim + eps);
    }
  }
  __syncthreads();

for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    y[bid * hiddenDim + i] = xShared[i] * rms_ * w[i];
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormWarp(float *x, float *w, float eps, float *y){
  rmsNormKernelWarp
      <<>>(x, w, eps, y);
}

上述kernel的结果如下:

Latency = 2.82263 ms
Bandwidth = 3043.23 GB/s
% of max = 92.2192 %

最初我预计这个会更快,但事实并非如此。

向量化加载和存储

如果我们对上述kernel进行性能分析,可以看到内存加载和存储消耗了最多的指令。我们可以使用CUDA的float4数据类型来向量化加载和存储操作来优化这一点。

对于共享内存的方法,代码如下所示:

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelSmemFloat4(float4 *x, float4 *w, float eps,
                                        float4 *y)
{
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float4 xShared[hiddenDim >> 2];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint bid = blockIdx.x;

float sum = 0.0f;

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    int index = bid * (hiddenDim >> 2) + i;
    float4 x_ = x[index];
    xShared[i] = x_;
    sum += (x_.x * x_.x) + (x_.y * x_.y) + (x_.z * x_.z) + (x_.w * x_.w);
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

for ( int activeThreads = threadsPerBlock >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

if (tid == 0) {
    rms_ = rsqrtf(squaredPerThread[tid] / hiddenDim + eps);
  }
  __syncthreads();

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    float4 w_ = w[i];
    float4 x_ = xShared[i];
    float4 val = make_float4(x_.x * rms_ * w_.x, x_.y * rms_ * w_.y,
                             x_.z * rms_ * w_.z, x_.w * rms_ * w_.w);
    y[bid * (hiddenDim >> 2) + i] = val;
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormSmemFloat4(float *x, float *w, float eps, float *y){
  float4 *x_ = reinterpret_cast(x);
  float4 *w_ = reinterpret_cast(w);
  float4 *y_ = reinterpret_cast(y);
  rmsNormKernelSmemFloat4
      <<>>(x_, w_, eps, y_);
}

上述kernel的结果如下:

Latency = 2.80455 ms
Bandwidth = 3062.86 GB/s
% of max = 92.8139 %

类似地,我们也可以对warp kernel进行优化:

#define WARP_SIZE 32

__device__ floatwarpReduce(float x){
float val = x;
for (int activeThreads = WARP_SIZE >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, activeThreads);
  }
return val;
}

template <int hiddenDim, int threadsPerBlock>
__global__ voidrmsNormKernelWarpFloat4(float4 *x, float4 *w, float eps,
                                        float4 *y)
{
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float4 xShared[hiddenDim >> 2];
  __shared__ float sumPerWarp[WARP_SIZE];
  __shared__ float rms_;

constint tid = threadIdx.x;
constint laneId = tid & 31;
constint warpId = tid >> 5;
constint warpsPerBlock = threadsPerBlock >> 5;

constint bid = blockIdx.x;
float sum = 0.0f;

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    int index = bid * (hiddenDim >> 2) + i;
    float4 x_ = x[index];
    xShared[i] = x_;
    sum += (x_.x * x_.x) + (x_.y * x_.y) + (x_.z * x_.z) + (x_.w * x_.w);
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

float warpSum = warpReduce(squaredPerThread[tid]);
if (laneId == 0) {
    sumPerWarp[warpId] = warpSum;
  }
  __syncthreads();

if (tid < WARP_SIZE) {
    sumPerWarp[tid] = warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0);
    if (tid == 0) {
      rms_ = rsqrtf(sumPerWarp[tid] / hiddenDim + eps);
    }
  }
  __syncthreads();

for (int i = tid; i < hiddenDim >> 2; i += threadsPerBlock) {
    float4 w_ = w[i];
    float4 x_ = xShared[i];
    float4 val = make_float4(x_.x * rms_ * w_.x, x_.y * rms_ * w_.y,
                             x_.z * rms_ * w_.z, x_.w * rms_ * w_.w);
    y[bid * (hiddenDim >> 2) + i] = val;
  }
}

template <int numTokens, int hiddenDim, int threadsPerBlock>
voidlaunchRmsNormWarpFloat4(float *x, float *w, float eps, float *y){
  float4 *x_ = reinterpret_cast(x);
  float4 *w_ = reinterpret_cast(w);
  float4 *y_ = reinterpret_cast(y);

  rmsNormKernelWarpFloat4
      <<>>(x_, w_, eps, y_);
}

上述kernel的结果如下:

Latency = 2.80475 ms
Bandwidth = 3062.63 GB/s
% of max = 92.8071 %

结论

我们看到,如果我们了解Reduction的工作原理,实现高性能的RMSNorm操作kernel并不困难。如果你发现了进一步的优化机会,我很乐意听取你的意见。让我感到惊讶的一点是,使用#pragma unroll 并没有对性能产生积极影响。如果你喜欢这篇博文,我很乐意在LinkedIn(https://www.linkedin.com/in/simon-veitner-174a681b6/)上与你联系,交流关于CUDA或其他机器学习系统的想法。上述结果的所有复现代码都可以在我的Github(https://github.com/simveit/effective_rms_norm)上找到。

本文仅做学术分享,如有侵权,请联系删文。
图片
图片

3D视觉硬件

图片

3D视觉学习圈子

「3D视觉从入门到精通」知识星球(点开有惊喜) !星球内新增20多门3D视觉系统课程、入门环境配置教程、多场顶会直播、顶会论文最新解读、3D视觉算法源码、求职招聘等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入

图片

3D视觉全栈学习课程:www.3dcver.com

image

3D视觉交流群成立啦

图片
点这里👇关注我,记得标星哦~

一键三连「分享」、「点赞」和「在看」

3D视觉科技前沿进展日日相见 ~ 

图片

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/182080
 
128 次点击