Skip to content

Commit 893172c

Browse files
authored
triton: Triton rms_norm kernels (#983)
1 parent 77ccda8 commit 893172c

File tree

2 files changed

+168
-0
lines changed

2 files changed

+168
-0
lines changed

flashinfer/triton/kernels/norm.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import triton # type: ignore[import]
2+
import triton.language as tl # type: ignore[import]
3+
4+
from flashinfer.triton.kernels.quant import scale_and_clamp
5+
6+
7+
@triton.jit
def rms_norm_kernel(
    n,
    b,
    x_ptr,
    x_stride,
    x_scale_ptr,
    r_ptr,
    r_stride,
    w_ptr,
    o_ptr,
    o_stride,
    o_scale_ptr,
    EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_IN_SCALE: tl.constexpr,
    HAS_OUT_SCALE: tl.constexpr,
    HAS_OUTPUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
) -> None:
    """RMS-normalise one row per program, optionally fusing a residual add.

    The program with id `i` on axis 0 handles row `i`. Row base pointers are
    `ptr + i * stride`; columns are addressed with consecutive offsets.

    n: number of columns in a row.
    b: number of rows (not read inside the kernel).
    x_ptr / x_stride: input rows and their row stride.
    x_scale_ptr: scalar multiplied into the loaded input (HAS_IN_SCALE).
    r_ptr / r_stride: residual rows, overwritten with `x + r` (HAS_RESIDUAL).
    w_ptr: per-column weights.
    o_ptr / o_stride: output rows; when HAS_OUTPUT is false the result is
        written back over the input rows.
    o_scale_ptr: scalar handed to `scale_and_clamp` (HAS_OUT_SCALE).
    EPS: constant added to the mean square before the reciprocal square root.
    BLOCK_SIZE: number of columns processed per loop iteration.
    """
    row = tl.program_id(axis=0).to(tl.int64)

    # Scalar scales are read once per program, only when enabled.
    in_scale = tl.load(x_scale_ptr) if HAS_IN_SCALE else None
    out_scale = tl.load(o_scale_ptr) if HAS_OUT_SCALE else None

    # Row base pointers. Without a separate output the input row is reused.
    x_base = x_ptr + row * x_stride
    r_base = (r_ptr + row * r_stride) if HAS_RESIDUAL else None
    out_base = (o_ptr + row * o_stride) if HAS_OUTPUT else x_base
    out_dtype = out_base.dtype.element_ty

    # Pass 1: accumulate the per-lane sum of squares in float32. With a
    # residual, the value being normalised is x + r, and that sum is written
    # back to the residual row so pass 2 can re-read it.
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, n, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n

        vals = tl.load(x_base + cols, mask=in_bounds, other=0.0).to(tl.float32)
        if HAS_IN_SCALE:
            vals = vals * in_scale

        if HAS_RESIDUAL:
            res = tl.load(r_base + cols, mask=in_bounds, other=0.0).to(tl.float32)
            vals = vals + res
            tl.store(r_base + cols, vals, mask=in_bounds)

        acc += vals * vals

    # Reciprocal of the root mean square: 1 / sqrt(mean(v^2) + EPS).
    inv_rms = tl.rsqrt(tl.sum(acc) / n + EPS)

    # Pass 2: out[j] = weight[j] * (v[j] * inv_rms), where v is the updated
    # residual row when HAS_RESIDUAL is set and the (scaled) input otherwise.
    for start in range(0, n, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n

        if HAS_RESIDUAL:
            vals = tl.load(r_base + cols, mask=in_bounds).to(tl.float32)
        else:
            vals = tl.load(x_base + cols, mask=in_bounds).to(tl.float32)
            if HAS_IN_SCALE:
                vals = vals * in_scale

        weight = tl.load(w_ptr + cols, mask=in_bounds).to(tl.float32)

        # Both the normalisation and the weight multiply are done in float32.
        normed = weight * (vals * inv_rms)
        if HAS_OUT_SCALE:
            # NOTE(review): presumably quantises to the output dtype; confirm
            # against flashinfer.triton.kernels.quant.scale_and_clamp.
            normed = scale_and_clamp(normed, out_scale, out_dtype)
        tl.store(out_base + cols, normed, mask=in_bounds)

flashinfer/triton/norm.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from collections.abc import Mapping
2+
from typing import Optional
3+
4+
import torch
5+
import triton # type: ignore[import]
6+
7+
from flashinfer.triton.kernels.norm import rms_norm_kernel
8+
9+
10+
def rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    out: torch.Tensor,
    eps: float,
    in_scale: Optional[torch.Tensor] = None,
    out_scale: Optional[torch.Tensor] = None,
) -> None:
    """RMS norm.

    Computes `out[i,j] = x[i,j] * weight[j] / sqrt(eps + sum(x[i]^2) / n)`.

    Launches one `rms_norm_kernel` program per row of `x`.

    Args:
        x: 2-D input of shape `(b, n)`.
        weight: Per-column weights; the kernel reads its first `n` elements
            at consecutive offsets.
        out: Destination tensor with the same shape as `x`.
        eps: Constant added to the mean square before the square root.
        in_scale: Optional scalar tensor multiplied into `x` as it is loaded.
        out_scale: Optional scalar tensor forwarded to the kernel's output
            scaling step (presumably for quantised outputs; confirm against
            `scale_and_clamp`).

    Raises:
        ValueError: If `x` is not 2-D, if `out` has a different shape from
            `x`, or if rows of `x` / `out` are not contiguous.
    """

    b, n = x.shape

    if out.shape != x.shape:
        raise ValueError(
            f"out shape {tuple(out.shape)} does not match x shape {tuple(x.shape)}"
        )

    # Nothing to normalise; n == 0 would also request an empty block below.
    if b == 0 or n == 0:
        return

    # The kernel addresses columns as `row_ptr + offsets`, so elements of a
    # row must be adjacent in memory. Rows of length one trivially are.
    if n > 1 and (x.stride(1) != 1 or out.stride(1) != 1):
        raise ValueError("x and out must have unit stride along the last dim")

    block_size = triton.next_power_of_2(n)
    num_warps = max(8, min(32, block_size // 256))

    rms_norm_kernel[(b,)](
        n=n,
        b=b,
        x_ptr=x,
        x_stride=x.stride(0),
        x_scale_ptr=in_scale,
        r_ptr=None,
        r_stride=0,
        w_ptr=weight,
        o_ptr=out,
        o_stride=out.stride(0),
        o_scale_ptr=out_scale,
        EPS=eps,
        BLOCK_SIZE=block_size,
        HAS_IN_SCALE=in_scale is not None,
        HAS_OUT_SCALE=out_scale is not None,
        HAS_OUTPUT=True,
        HAS_RESIDUAL=False,
        num_warps=num_warps,
    )
48+
49+
50+
def rms_norm_add_residual(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    x_out: Optional[torch.Tensor] = None,
    x_in_scale: Optional[torch.Tensor] = None,
    x_out_scale: Optional[torch.Tensor] = None,
) -> None:
    """In-place RMS norm with fused residual addition.

    Computes `r = r + x`, followed by `x = rmsnorm(r)`.

    `residual` is always updated in place. The normalised result overwrites
    `x` unless `x_out` is given, in which case it is written to `x_out`.

    Args:
        x: 2-D input of shape `(b, n)`.
        residual: Tensor with the same shape and row stride as `x`.
        weight: Per-column weights; the kernel reads its first `n` elements
            at consecutive offsets.
        eps: Constant added to the mean square before the square root.
        x_out: Optional destination with the same shape as `x`.
        x_in_scale: Optional scalar tensor multiplied into `x` as it is
            loaded, before the residual is added.
        x_out_scale: Optional scalar tensor forwarded to the kernel's output
            scaling step (presumably for quantised outputs; confirm against
            `scale_and_clamp`).

    Raises:
        AssertionError: If `residual` differs from `x` in shape or row stride.
        ValueError: If `x` is not 2-D, if `x_out` has a different shape from
            `x`, or if rows of any tensor involved are not contiguous.
    """

    b, n = x.shape

    # Kept as asserts so existing callers see the same exception type.
    assert x.shape == residual.shape
    assert x.stride(0) == residual.stride(0)

    if x_out is not None and x_out.shape != x.shape:
        raise ValueError(
            f"x_out shape {tuple(x_out.shape)} does not match "
            f"x shape {tuple(x.shape)}"
        )

    # Nothing to normalise; n == 0 would also request an empty block below.
    if b == 0 or n == 0:
        return

    # The kernel addresses columns as `row_ptr + offsets`, so elements of a
    # row must be adjacent in memory. Rows of length one trivially are.
    row_tensors = (x, residual) if x_out is None else (x, residual, x_out)
    if n > 1 and any(t.stride(1) != 1 for t in row_tensors):
        raise ValueError(
            "x, residual and x_out must have unit stride along the last dim"
        )

    block_size = triton.next_power_of_2(n)
    num_warps = min(32, triton.cdiv(block_size, 32))

    rms_norm_kernel[(b,)](
        n=n,
        b=b,
        x_ptr=x,
        x_stride=x.stride(0),
        x_scale_ptr=x_in_scale,
        r_ptr=residual,
        r_stride=residual.stride(0),
        w_ptr=weight,
        o_ptr=x_out,
        o_stride=x_out.stride(0) if x_out is not None else 0,
        o_scale_ptr=x_out_scale,
        EPS=eps,
        BLOCK_SIZE=block_size,
        HAS_IN_SCALE=x_in_scale is not None,
        HAS_OUT_SCALE=x_out_scale is not None,
        HAS_OUTPUT=x_out is not None,
        HAS_RESIDUAL=True,
        num_warps=num_warps,
    )

0 commit comments

Comments
 (0)