@@ -28,7 +28,7 @@ namespace flashinfer {
 namespace norm {
 
 template <uint32_t VEC_SIZE, typename T>
-__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
+__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
                               const uint32_t d, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
@@ -43,14 +43,14 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float sum_sq = 0.f;
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> x_vec;
-    x_vec.fill(0);
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      sum_sq += float(x_vec[j]) * float(x_vec[j]);
+      sum_sq += float(input_vec[j]) * float(input_vec[j]);
     }
   }
 
@@ -76,36 +76,36 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> x_vec;
-    vec_t<T, VEC_SIZE> w_vec;
-    vec_t<T, VEC_SIZE> y_vec;
-    x_vec.fill(0);
-    w_vec.fill(0);
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> output_vec;
+    input_vec.fill(0);
+    weight_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
+      output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
-cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
-                    cudaStream_t stream = 0) {
+cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
+                    float eps = 1e-5, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&x, &w, &y, &d, &eps};
+  void* args[] = {&input, &weight, &output, &d, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -114,6 +114,104 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps
   return cudaSuccess;
 }
 
+template <uint32_t VEC_SIZE, typename T>
+__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
+                                      T* __restrict__ weight, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0);
+    vec_t<T, VEC_SIZE> residual_vec;
+    residual_vec.fill(0);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      float x = float(input_vec[j]);
+      x += float(residual_vec[j]);
+      sum_sq += x * x;
+      residual_vec[j] = (T)x;
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> residual_vec;
+    input_vec.fill(0);
+    weight_vec.fill(0);
+    residual_vec.fill(0);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
+                            float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &residual, &weight, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace norm
 
 }  // namespace flashinfer
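
For context, a minimal host-side sketch of how the new FusedAddRMSNorm entry point added in this diff could be invoked. The include path, half-precision dtype, and the batch_size/d values below are illustrative assumptions, not part of this change; the pointers are assumed to already reside in device memory.

// Usage sketch (assumed: "flashinfer/norm.cuh" include path, half dtype,
// batch_size = 32, d = 4096; input/residual/weight are device pointers).
#include <cuda_fp16.h>

#include "flashinfer/norm.cuh"

cudaError_t run_fused_add_rmsnorm(half* input, half* residual, half* weight,
                                  cudaStream_t stream) {
  const uint32_t batch_size = 32;  // one CUDA block per row
  const uint32_t d = 4096;         // hidden dimension
  // The op is in place: `residual` is overwritten with input + residual, and
  // `input` receives the RMS-normalized sum scaled by `weight`, matching the
  // kernel introduced above.
  return flashinfer::norm::FusedAddRMSNorm<half>(input, residual, weight, batch_size, d,
                                                 /*eps=*/1e-5f, stream);
}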