Commit b781513

feat: support fused add rmsnorm (#419)
ref sgl-project/sglang#907 cc @yzh119
1 parent 1c9ffb3 commit b781513

7 files changed: +246 -79 lines changed
include/flashinfer/norm.cuh (+115 -17)

@@ -28,7 +28,7 @@ namespace flashinfer {
 namespace norm {
 
 template <uint32_t VEC_SIZE, typename T>
-__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
+__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
                               const uint32_t d, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
@@ -43,14 +43,14 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float sum_sq = 0.f;
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> x_vec;
-    x_vec.fill(0);
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      sum_sq += float(x_vec[j]) * float(x_vec[j]);
+      sum_sq += float(input_vec[j]) * float(input_vec[j]);
     }
   }
 
@@ -76,36 +76,36 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> x_vec;
-    vec_t<T, VEC_SIZE> w_vec;
-    vec_t<T, VEC_SIZE> y_vec;
-    x_vec.fill(0);
-    w_vec.fill(0);
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> output_vec;
+    input_vec.fill(0);
+    weight_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
+      output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
-cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
-                    cudaStream_t stream = 0) {
+cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
+                    float eps = 1e-5, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&x, &w, &y, &d, &eps};
+  void* args[] = {&input, &weight, &output, &d, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -114,6 +114,104 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps
   return cudaSuccess;
 }
 
+template <uint32_t VEC_SIZE, typename T>
+__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
+                                      T* __restrict__ weight, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0);
+    vec_t<T, VEC_SIZE> residual_vec;
+    residual_vec.fill(0);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      float x = float(input_vec[j]);
+      x += float(residual_vec[j]);
+      sum_sq += x * x;
+      residual_vec[j] = (T)x;
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> residual_vec;
+    input_vec.fill(0);
+    weight_vec.fill(0);
+    residual_vec.fill(0);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
+                            float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &residual, &weight, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace norm
 
 }  // namespace flashinfer
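For readers skimming the kernel above: the fused op makes two passes over each row. The first pass stores input + residual back into the residual buffer while accumulating the sum of squares in float32; the second pass overwrites the input buffer with the RMS-normalized, weight-scaled result. The PyTorch sketch below is not part of this commit; it is a minimal reference of that contract, assuming 2-D (batch_size, hidden_size) tensors, and the helper name fused_add_rmsnorm_ref is illustrative.

import torch

def fused_add_rmsnorm_ref(input: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6) -> None:
    # Pass 1: residual <- input + residual, accumulating the sum of squares
    # in float32 (mirrors the kernel's float sum_sq accumulator).
    x = input.float() + residual.float()
    residual.copy_(x.to(residual.dtype))
    # One scalar per row: rms_rcp = rsqrt(mean(x^2) + eps).
    rms_rcp = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Pass 2: input <- (input + residual) * rms_rcp * weight, re-reading the
    # updated residual buffer (already rounded to its storage dtype), as the kernel does.
    input.copy_((residual.float() * rms_rcp * weight.float()).to(input.dtype))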

python/csrc/flashinfer_ops.cu (+1)

@@ -42,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("chain_speculative_sampling", &chain_speculative_sampling,
         "Speculative sampling from sequence of probabilities");
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
+  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
   m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
   m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
         "Apply Llama 3.1 style RoPE in-place");

python/csrc/flashinfer_ops.h (+4 -1)

@@ -78,7 +78,10 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
                                          torch::Tensor uniform_samples, torch::Tensor target_probs,
                                          bool deterministic);
 
-torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
+torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
+
+void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
+                       double eps);
 
 void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                         torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

python/csrc/norm.cu (+46 -16)

@@ -20,26 +20,56 @@
 
 using namespace flashinfer;
 
-torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) {
-  CHECK_INPUT(x);
-  CHECK_INPUT(w);
-  auto device = x.device();
-  CHECK_EQ(w.device(), device);
-  CHECK_DIM(2, x);  // x: (batch_size, hidden_size)
-  CHECK_DIM(1, w);  // w: (hidden_size)
-  CHECK_EQ(x.size(1), w.size(0));
-  unsigned int batch_size = x.size(0);
-  unsigned int hidden_size = x.size(1);
+torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);  // weight: (hidden_size)
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
 
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  auto y = torch::empty_like(x);
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] {
-    cudaError_t status = norm::RMSNorm(
-        static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(w.data_ptr()),
-        static_cast<c_type*>(y.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
+  auto output = torch::empty_like(input);
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
+                                       static_cast<c_type*>(weight.data_ptr()),
+                                       static_cast<c_type*>(output.data_ptr()), batch_size,
+                                       hidden_size, eps, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
     return true;
   });
-  return y;
+  return output;
+}
+
+void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
+                       double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(residual);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(residual.device(), device);
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
+  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);    // weight: (hidden_size)
+  CHECK_EQ(input.size(0), residual.size(0));
+  CHECK_EQ(input.size(1), residual.size(1));
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()),
+                                               static_cast<c_type*>(residual.data_ptr()),
+                                               static_cast<c_type*>(weight.data_ptr()), batch_size,
+                                               hidden_size, eps, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " +
+                                           std::string(cudaGetErrorString(status)));
+    return true;
+  });
 }
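The binding above only validates shapes and dispatches the dtype; the launch configuration is derived inside FusedAddRMSNorm in include/flashinfer/norm.cuh from the hidden size it receives. As a worked illustration (values are assumptions, not from the commit): for the half-precision path with hidden_size = 4096, the numbers come out as in the sketch below.

from math import gcd

# Illustrative launch-parameter computation for T = half (2 bytes), hidden_size = 4096,
# mirroring FusedAddRMSNorm in include/flashinfer/norm.cuh.
sizeof_T, d = 2, 4096
vec_size = gcd(16 // sizeof_T, d)              # 8 elements per 16-byte vectorized load
block_size = min(1024, d // vec_size)          # 512 threads per block
num_warps = (block_size + 31) // 32            # 16 warps -> blockDim = (32, 16)
smem_size = num_warps * 4                      # 64 bytes for the per-warp partial sums
rounds = -(-d // (vec_size * num_warps * 32))  # 1 pass over the row per thread block
print(vec_size, block_size, num_warps, smem_size, rounds)  # grid size equals batch_size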

python/flashinfer/__init__.py (+19 -36)

@@ -14,44 +14,27 @@
 limitations under the License.
 """
 
-from .decode import (
-    single_decode_with_kv_cache,
-    BatchDecodeWithPagedKVCacheWrapper,
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
-)
-from .prefill import (
-    single_prefill_with_kv_cache,
-    single_prefill_with_kv_cache_return_lse,
-    BatchPrefillWithRaggedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-)
-from .sparse import BlockSparseAttentionWrapper
-from .cascade import (
-    merge_state,
-    merge_state_in_place,
-    merge_states,
-    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
-    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
-)
-from .page import append_paged_kv_cache
-from .sampling import (
-    sampling_from_probs,
-    top_p_sampling_from_probs,
-    top_k_sampling_from_probs,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-    top_k_renorm_prob,
-    chain_speculative_sampling,
-)
-from .norm import rmsnorm
-from .rope import (
-    apply_rope_inplace,
-    apply_llama31_rope_inplace,
-    apply_rope,
-    apply_llama31_rope,
-)
+from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
+                      BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
+                      merge_state, merge_state_in_place, merge_states)
+from .decode import (BatchDecodeWithPagedKVCacheWrapper,
+                     CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
+                     single_decode_with_kv_cache)
 from .group_gemm import SegmentGEMMWrapper
+from .norm import fused_add_rmsnorm, rmsnorm
+from .page import append_paged_kv_cache
+from .prefill import (BatchPrefillWithPagedKVCacheWrapper,
+                      BatchPrefillWithRaggedKVCacheWrapper,
+                      single_prefill_with_kv_cache,
+                      single_prefill_with_kv_cache_return_lse)
 from .quantization import packbits, segment_packbits
+from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope,
+                   apply_rope_inplace)
+from .sampling import (chain_speculative_sampling, sampling_from_probs,
+                       top_k_renorm_prob, top_k_sampling_from_probs,
+                       top_k_top_p_sampling_from_probs, top_p_renorm_prob,
+                       top_p_sampling_from_probs)
+from .sparse import BlockSparseAttentionWrapper
 
 try:
     from ._build_meta import __version__

python/flashinfer/norm.py (+27 -6)

@@ -20,8 +20,8 @@
 try:
     from . import _kernels
 except ImportError as e:
-    import os
     import logging
+    import os
 
     if os.environ.get("BUILD_DOC", "0") == "1":
         _kernels = None
@@ -30,21 +30,42 @@
         raise e
 
 
-def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+def rmsnorm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
     r"""Root mean square normalization.
 
     Parameters
     ----------
-    x: torch.Tensor
+    input: torch.Tensor
         Input tensor, shape (batch_size, hidden_size).
-    w: torch.Tensor
+    weight: torch.Tensor
         Weight tensor, shape (hidden_size,).
     eps: float
         Epsilon for numerical stability.
 
     Returns
     -------
-    y: torch.Tensor
+    output: torch.Tensor
         Normalized tensor, shape (batch_size, hidden_size).
     """
-    return _kernels.rmsnorm(x, w, eps)
+    return _kernels.rmsnorm(input, weight, eps)
+
+
+def fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+):
+    r"""Fused add root mean square normalization.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    """
+    _kernels.fused_add_rmsnorm(input, residual, weight, eps)
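A short usage sketch of the new Python API (the tensor shapes, dtype, and sizes below are illustrative, not from the commit). Note the in-place contract: fused_add_rmsnorm returns nothing; after the call, residual holds input + residual and input holds the normalized output.

import torch
import flashinfer

batch_size, hidden_size = 4, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
res = torch.randn_like(x)
w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

y = flashinfer.rmsnorm(x, w)             # out-of-place: returns a new tensor

flashinfer.fused_add_rmsnorm(x, res, w)  # in-place: res <- x + res, x <- rmsnorm(x + res) * w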
