@@ -29,7 +29,8 @@ namespace norm {

template <uint32_t VEC_SIZE, typename T>
__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
-                              const uint32_t d, float weight_bias, float eps) {
+                              const uint32_t d, const uint32_t stride_input,
+                              const uint32_t stride_output, float weight_bias, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
@@ -46,7 +47,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -82,22 +83,24 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      output_vec.store(output + bx * stride_output + i * num_threads * VEC_SIZE +
+                       thread_id * VEC_SIZE);
    }
  }
}

template <typename T>
cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
-                    float eps = 1e-5, cudaStream_t stream = 0) {
+                    uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
+                    cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -106,7 +109,7 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  float weight_bias = 0.f;
-  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
+  void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -117,8 +120,9 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_

template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                      T* __restrict__ weight, const uint32_t d, float weight_bias,
-                                      float eps) {
+                                      T* __restrict__ weight, const uint32_t d,
+                                      const uint32_t stride_input, const uint32_t stride_residual,
+                                      float weight_bias, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
@@ -139,8 +143,9 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
    vec_t<float, VEC_SIZE> x_vec;
    x_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * stride_residual + i * num_threads * VEC_SIZE +
+                        thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -151,7 +156,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
      x_vec[j] = x;
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.store(residual + bx * stride_residual + i * num_threads * VEC_SIZE +
+                         thread_id * VEC_SIZE);
      x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }
@@ -193,14 +199,16 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
      input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.store(input + bx * stride_input + i * num_threads * VEC_SIZE +
+                      thread_id * VEC_SIZE);
    }
  }
}

template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
-                            float eps = 1e-5, cudaStream_t stream = 0) {
+                            uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
+                            cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -209,11 +217,13 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
  float weight_bias = 0.f;
-  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
+  void* args[] = {&input, &residual, &weight, &d,
+                  &stride_input, &stride_residual, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

@@ -222,7 +232,8 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz

template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
-                         float eps = 1e-5, cudaStream_t stream = 0) {
+                         uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
+                         cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -231,7 +242,7 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  float weight_bias = 1.f;
-  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
+  void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -242,7 +253,8 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui

template <typename T>
cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
-                                 float eps = 1e-5, cudaStream_t stream = 0) {
+                                 uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
+                                 cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -252,11 +264,13 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
  // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 bytes
  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
  float weight_bias = 1.f;
-  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
+  void* args[] = {&input, &residual, &weight, &d,
+                  &stride_input, &stride_residual, &weight_bias, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

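// ----------------------------------------------------------------------------
// Usage sketch (not part of the diff above): how a caller might supply the new
// stride arguments. This is a hedged illustration only; the buffer names, the
// half element type, and the helper run_norms() are hypothetical. It assumes
// the updated header above is included and that the functions are reachable
// through the `norm` namespace shown in the hunk headers (the full
// qualification may differ). For contiguous row-major tensors the row stride
// equals d; a padded or sliced view would pass its actual row pitch instead.
#include <cuda_fp16.h>

cudaError_t run_norms(half* input, half* residual, half* weight, half* output,
                      uint32_t batch_size, uint32_t d, cudaStream_t stream) {
  // Contiguous layout: consecutive rows start d elements apart.
  const uint32_t stride = d;

  // Plain RMSNorm: reads `input`, writes the normalized rows to `output`.
  cudaError_t status = norm::RMSNorm<half>(input, weight, output, batch_size, d,
                                           /*stride_input=*/stride,
                                           /*stride_output=*/stride,
                                           /*eps=*/1e-5f, stream);
  if (status != cudaSuccess) return status;

  // Fused residual add + RMSNorm: per the kernel in the diff, this updates
  // both `input` and `residual` in place.
  return norm::FusedAddRMSNorm<half>(input, residual, weight, batch_size, d,
                                     /*stride_input=*/stride,
                                     /*stride_residual=*/stride,
                                     /*eps=*/1e-5f, stream);
}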