@@ -29,7 +29,7 @@ namespace norm {

template <uint32_t VEC_SIZE, typename T>
__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
-                               const uint32_t d, float eps) {
+                               const uint32_t d, float weight_bias, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
@@ -87,7 +87,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-       output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
+       output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -105,7 +105,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
-   void* args[] = {&input, &weight, &output, &d, &eps};
+   float weight_bias = 0.f;
+   void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -116,7 +117,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_

template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                       T* __restrict__ weight, const uint32_t d, float eps) {
+                                       T* __restrict__ weight, const uint32_t d, float weight_bias,
+                                       float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
@@ -187,7 +189,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-       input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
+       input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -205,7 +207,8 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = (num_warps + d) * sizeof(float);
-   void* args[] = {&input, &residual, &weight, &d, &eps};
+   float weight_bias = 0.f;
+   void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
@@ -215,73 +218,6 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
  return cudaSuccess;
}

- template <uint32_t VEC_SIZE, typename T>
- __global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
-                                    T* __restrict__ output, const uint32_t d, float eps) {
-   const uint32_t bx = blockIdx.x;
-   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-   constexpr uint32_t warp_size = 32;
-   const uint32_t num_warps = blockDim.y;
-   const uint32_t thread_id = tx + ty * warp_size;
-   const uint32_t num_threads = num_warps * warp_size;
-   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-   extern __shared__ float smem[];
-
-   float sum_sq = 0.f;
-
-   for (uint32_t i = 0; i < rounds; i++) {
-     vec_t<T, VEC_SIZE> input_vec;
-     input_vec.fill(0.f);
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
- #pragma unroll
-     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-       sum_sq += float(input_vec[j]) * float(input_vec[j]);
-     }
-   }
-
-   // first, warp reduce sum
- #pragma unroll
-   for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-     sum_sq += math::shfl_xor_sync(sum_sq, offset);
-   }
-
-   smem[ty] = sum_sq;
-   __syncthreads();
-   // then, cross warp reduce sum using only the first warp
-   if (ty == 0) {
-     sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
- #pragma unroll
-     for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-       sum_sq += math::shfl_xor_sync(sum_sq, offset);
-     }
-     smem[0] = sum_sq;
-   }
-   __syncthreads();
-
-   float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-   for (uint32_t i = 0; i < rounds; i++) {
-     vec_t<T, VEC_SIZE> input_vec;
-     vec_t<T, VEC_SIZE> weight_vec;
-     vec_t<T, VEC_SIZE> output_vec;
-     input_vec.fill(0.f);
-     weight_vec.fill(0.f);
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
- #pragma unroll
-     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-       output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-     }
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
-   }
- }
-
template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                         float eps = 1e-5, cudaStream_t stream = 0) {
@@ -292,92 +228,16 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
-   void* args[] = {&input, &weight, &output, &d, &eps};
+   float weight_bias = 1.f;
+   void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-     auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+     auto kernel = RMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}

- template <uint32_t VEC_SIZE, typename T>
- __global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                            T* __restrict__ weight, const uint32_t d, float eps) {
-   const uint32_t bx = blockIdx.x;
-   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-   constexpr uint32_t warp_size = 32;
-   const uint32_t num_warps = blockDim.y;
-   const uint32_t thread_id = tx + ty * warp_size;
-   const uint32_t num_threads = num_warps * warp_size;
-   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-   extern __shared__ float smem[];
-
-   float sum_sq = 0.f;
-
-   for (uint32_t i = 0; i < rounds; i++) {
-     vec_t<T, VEC_SIZE> input_vec;
-     input_vec.fill(0.f);
-     vec_t<T, VEC_SIZE> residual_vec;
-     residual_vec.fill(0.f);
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
- #pragma unroll
-     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-       float x = float(input_vec[j]);
-       x += float(residual_vec[j]);
-       sum_sq += x * x;
-       residual_vec[j] = (T)x;
-     }
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
-   }
-
-   // first, warp reduce sum
- #pragma unroll
-   for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-     sum_sq += math::shfl_xor_sync(sum_sq, offset);
-   }
-
-   smem[ty] = sum_sq;
-   __syncthreads();
-   // then, cross warp reduce sum using only the first warp
-   if (ty == 0) {
-     sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
- #pragma unroll
-     for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-       sum_sq += math::shfl_xor_sync(sum_sq, offset);
-     }
-     smem[0] = sum_sq;
-   }
-   __syncthreads();
-
-   float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-   for (uint32_t i = 0; i < rounds; i++) {
-     vec_t<T, VEC_SIZE> input_vec;
-     vec_t<T, VEC_SIZE> weight_vec;
-     vec_t<T, VEC_SIZE> residual_vec;
-     input_vec.fill(0.f);
-     weight_vec.fill(0.f);
-     residual_vec.fill(0.f);
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
- #pragma unroll
-     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-       input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-     }
-     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-     }
-   }
- }
-
template <typename T>
cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                                 float eps = 1e-5, cudaStream_t stream = 0) {
@@ -387,11 +247,12 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
-   const uint32_t smem_size = num_warps * sizeof(float);
-   void* args[] = {&input, &residual, &weight, &d, &eps};
+   const uint32_t smem_size = (num_warps + d) * sizeof(float);
+   float weight_bias = 1.f;
+   void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-     auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
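
For reference, a minimal scalar sketch of the row-wise math the unified kernels compute; the helper name rms_norm_ref and its signature are illustrative only and not part of FlashInfer. With weight_bias = 0.f it matches the plain RMSNorm scaling float(input_vec[j]) * rms_rcp * float(weight_vec[j]); with weight_bias = 1.f it matches the (1.0f + float(weight_vec[j])) scaling of the removed Gemma kernels, which is the value the GemmaRMSNorm and GemmaFusedAddRMSNorm launchers now pass through args.

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference (not part of FlashInfer): normalizes one row of length d.
// weight_bias = 0.f -> standard RMSNorm; weight_bias = 1.f -> Gemma-style (1 + w) scaling.
// Passing the bias as a runtime float lets one kernel instantiation per (VEC_SIZE, T)
// serve both variants.
std::vector<float> rms_norm_ref(const std::vector<float>& x, const std::vector<float>& w,
                                float weight_bias, float eps = 1e-5f) {
  const uint32_t d = static_cast<uint32_t>(x.size());
  float sum_sq = 0.f;
  for (uint32_t i = 0; i < d; ++i) sum_sq += x[i] * x[i];
  // Same epsilon placement as the kernels: rsqrt(sum_sq / (d + eps)).
  const float rms_rcp = 1.f / std::sqrt(sum_sq / (static_cast<float>(d) + eps));
  std::vector<float> out(d);
  for (uint32_t i = 0; i < d; ++i) out[i] = x[i] * rms_rcp * (weight_bias + w[i]);
  return out;
}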