
Commit 5bf36ce

misc: remove duplicate norm cuda kernels (#631)
The gemma-style rmsnorm kernels (introduced in #477) are nearly identical to the original rmsnorm kernels, so both can share a single implementation. This PR removes the duplicated code and unifies the gemma-style and original rmsnorm kernels; the precision improvements from #587 and #592 are preserved.
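For reference, the only difference between the two variants is a constant bias added to the weight before scaling. Below is a minimal PyTorch sketch of the unified formulation, mirroring the reference functions in tests/test_norm.py; the name rms_norm_ref is ours for illustration, not part of the library.

import torch

def rms_norm_ref(x, w, weight_bias=0.0, eps=1e-6):
    # weight_bias = 0.0 reproduces the original (llama-style) RMSNorm,
    # weight_bias = 1.0 reproduces the gemma-style RMSNorm.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (weight_bias + w.float())
    return x.to(orig_dtype)

On the CUDA side this corresponds to the new weight_bias kernel argument: RMSNorm and FusedAddRMSNorm launch the shared kernels with weight_bias = 0.f, while GemmaRMSNorm and GemmaFusedAddRMSNorm pass weight_bias = 1.f.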
1 parent 6ec4db3 commit 5bf36ce

2 files changed: +26 −163 lines


include/flashinfer/norm.cuh: +16 −155
@@ -29,7 +29,7 @@ namespace norm {
 
 template <uint32_t VEC_SIZE, typename T>
 __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
-                              const uint32_t d, float eps) {
+                              const uint32_t d, float weight_bias, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -87,7 +87,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
+      output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -105,7 +105,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &weight, &output, &d, &eps};
+  float weight_bias = 0.f;
+  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -116,7 +117,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
 
 template <uint32_t VEC_SIZE, typename T>
 __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                      T* __restrict__ weight, const uint32_t d, float eps) {
+                                      T* __restrict__ weight, const uint32_t d, float weight_bias,
+                                      float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -187,7 +189,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
+      input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -205,7 +207,8 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = (num_warps + d) * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
+  float weight_bias = 0.f;
+  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
@@ -215,73 +218,6 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   return cudaSuccess;
 }
 
-template <uint32_t VEC_SIZE, typename T>
-__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
-                                   T* __restrict__ output, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      sum_sq += float(input_vec[j]) * float(input_vec[j]);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> output_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
 template <typename T>
 cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                          float eps = 1e-5, cudaStream_t stream = 0) {
@@ -292,92 +228,16 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &weight, &output, &d, &eps};
+  float weight_bias = 1.f;
+  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+    auto kernel = RMSNormKernel<VEC_SIZE, T>;
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
   return cudaSuccess;
 }
 
-template <uint32_t VEC_SIZE, typename T>
-__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                           T* __restrict__ weight, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    vec_t<T, VEC_SIZE> residual_vec;
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = float(input_vec[j]);
-      x += float(residual_vec[j]);
-      sum_sq += x * x;
-      residual_vec[j] = (T)x;
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> residual_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
 template <typename T>
 cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                                  float eps = 1e-5, cudaStream_t stream = 0) {
@@ -387,11 +247,12 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
+  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  float weight_bias = 1.f;
+  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
 
tests/test_norm.py: +10 −8
@@ -21,19 +21,21 @@
 
 
 def llama_rms_norm(x, w, eps=1e-6):
-    def _norm(x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-
-    output = _norm(x.float()).type_as(x)
-    return output * w
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * w.float()
+    x = x.to(orig_dtype)
+    return x
 
 
 def gemma_rms_norm(x, w, eps=1e-6):
     orig_dtype = x.dtype
     x = x.float()
     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
-    x = x * (1.0 + w)
+    x = x * (1.0 + w.float())
     x = x.to(orig_dtype)
     return x
 
@@ -45,7 +47,7 @@ def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
     x = x.float()
     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
-    x = x * (1.0 + w)
+    x = x * (1.0 + w.float())
     x = x.to(orig_dtype)
     return x, residual
 
@@ -58,7 +60,7 @@ def fused_add_rms_norm(x, residual, weight, eps):
 
     variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
-    x = x.to(orig_dtype) * weight
+    x = (x * weight.float()).to(orig_dtype)
     return x, residual
 
Comments (0)