
Commit 0daed1a
feat: support non-contiguous input/output in normalization functions (#921)
Support normalization functions whose input and output tensors are not contiguous. The kernels now take explicit row strides, so only the last dimension of each tensor has to be contiguous.
1 parent d4dc3f9 commit 0daed1a
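
For context, a minimal sketch of what this enables from the Python side, assuming the flashinfer.norm wrappers (rmsnorm etc.) that these C++ bindings back; the exact wrapper name is an assumption, not part of this diff:

    import torch
    import flashinfer

    batch_size, hidden_size = 4, 4096
    # A row-sliced view: non-contiguous as a whole, but contiguous in its
    # last dimension, which is all the patched kernels require.
    buf = torch.randn(batch_size, hidden_size * 2, dtype=torch.float16, device="cuda")
    x = buf[:, :hidden_size]  # x.stride(0) == 2 * hidden_size
    w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

    # Previously CHECK_INPUT(input) rejected such views; now the row
    # stride is forwarded to the kernel and this call succeeds.
    y = flashinfer.norm.rmsnorm(x, w)  # assumed wrapper name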

File tree: 3 files changed (+90 −44 lines)

csrc/norm.cu

+20 −16
@@ -22,8 +22,8 @@ using namespace flashinfer;
 
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps,
              int64_t cuda_stream) {
-  CHECK_INPUT(input);
-  CHECK_INPUT(weight);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
   CHECK_EQ(weight.device(), device);
   CHECK_DIM(2, input);  // input: (batch_size, hidden_size)
@@ -36,9 +36,10 @@ void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double e
 
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
-    cudaError_t status = norm::RMSNorm(
-        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
-        static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
+    cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
+                                       static_cast<c_type*>(weight.data_ptr()),
+                                       static_cast<c_type*>(output.data_ptr()), batch_size,
+                                       hidden_size, input.stride(0), output.stride(0), eps, stream);
     TORCH_CHECK(status == cudaSuccess,
                 "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
     return true;
@@ -47,9 +48,9 @@ void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double e
 
 void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
                        int64_t cuda_stream) {
-  CHECK_INPUT(input);
-  CHECK_INPUT(residual);
-  CHECK_INPUT(weight);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
   CHECK_EQ(residual.device(), device);
   CHECK_EQ(weight.device(), device);
@@ -66,7 +67,8 @@ void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weig
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::FusedAddRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
-        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, stream);
+        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, input.stride(0),
+        residual.stride(0), eps, stream);
     TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " +
                                            std::string(cudaGetErrorString(status)));
     return true;
@@ -75,8 +77,8 @@ void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weig
 
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps,
                    int64_t cuda_stream) {
-  CHECK_INPUT(input);
-  CHECK_INPUT(weight);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
   CHECK_EQ(weight.device(), device);
   CHECK_DIM(2, input);  // input: (batch_size, hidden_size)
@@ -91,7 +93,8 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaRMSNorm(
         static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
-        static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
+        static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, input.stride(0),
+        output.stride(0), eps, stream);
     TORCH_CHECK(status == cudaSuccess,
                 "GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
     return true;
@@ -100,9 +103,9 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
 
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
                              double eps, int64_t cuda_stream) {
-  CHECK_INPUT(input);
-  CHECK_INPUT(residual);
-  CHECK_INPUT(weight);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   auto device = input.device();
   CHECK_EQ(residual.device(), device);
   CHECK_EQ(weight.device(), device);
@@ -119,7 +122,8 @@ void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaFusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
-        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, stream);
+        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, input.stride(0),
+        residual.stride(0), eps, stream);
     TORCH_CHECK(status == cudaSuccess, "GemmaFusedAddRMSNorm failed with error code " +
                                            std::string(cudaGetErrorString(status)));
     return true;
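
The binding-side change is mechanical: relax CHECK_INPUT to CHECK_LAST_DIM_CONTIGUOUS_INPUT and forward stride(0) of each 2-D tensor. A short torch snippet showing which tensor properties that stride argument captures (illustration only, not part of the diff):

    import torch

    batch_size, hidden_size = 4, 1024
    buf = torch.randn(batch_size, hidden_size * 2, device="cuda", dtype=torch.float16)
    x = buf[:, :hidden_size]

    print(x.is_contiguous())  # False: rows are hidden_size * 2 elements apart
    print(x.stride(0))        # 2048, the value now passed as stride_input
    print(x.stride(-1))       # 1: last dim contiguous, so the relaxed check passes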

include/flashinfer/norm.cuh

+34 −20
@@ -29,7 +29,8 @@ namespace norm {
 
 template <uint32_t VEC_SIZE, typename T>
 __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
-                              const uint32_t d, float weight_bias, float eps) {
+                              const uint32_t d, const uint32_t stride_input,
+                              const uint32_t stride_output, float weight_bias, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -46,7 +47,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
     vec_t<T, VEC_SIZE> input_vec;
     input_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -82,22 +83,24 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      output_vec.store(output + bx * stride_output + i * num_threads * VEC_SIZE +
+                       thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
 cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
-                    float eps = 1e-5, cudaStream_t stream = 0) {
+                    uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
+                    cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -106,7 +109,7 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
   float weight_bias = 0.f;
-  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
+  void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -117,8 +120,9 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
 
 template <uint32_t VEC_SIZE, typename T>
 __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                      T* __restrict__ weight, const uint32_t d, float weight_bias,
-                                      float eps) {
+                                      T* __restrict__ weight, const uint32_t d,
+                                      const uint32_t stride_input, const uint32_t stride_residual,
+                                      float weight_bias, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -139,8 +143,9 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     vec_t<float, VEC_SIZE> x_vec;
     x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * stride_residual + i * num_threads * VEC_SIZE +
+                        thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -151,7 +156,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
       x_vec[j] = x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.store(residual + bx * stride_residual + i * num_threads * VEC_SIZE +
+                         thread_id * VEC_SIZE);
       x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
@@ -193,14 +199,16 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
       input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.store(input + bx * stride_input + i * num_threads * VEC_SIZE +
+                      thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
 cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
-                            float eps = 1e-5, cudaStream_t stream = 0) {
+                            uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
+                            cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -209,11 +217,13 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
   float weight_bias = 0.f;
-  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
+  void* args[] = {&input, &residual, &weight, &d,
+                  &stride_input, &stride_residual, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
 
@@ -222,7 +232,8 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
 
 template <typename T>
 cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
-                         float eps = 1e-5, cudaStream_t stream = 0) {
+                         uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
+                         cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -231,7 +242,7 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
   float weight_bias = 1.f;
-  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
+  void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -242,7 +253,8 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
 
 template <typename T>
 cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
-                                 float eps = 1e-5, cudaStream_t stream = 0) {
+                                 uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
+                                 cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
@@ -252,11 +264,13 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 bytes
   const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
   float weight_bias = 1.f;
-  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
+  void* args[] = {&input, &residual, &weight, &d,
+                  &stride_input, &stride_residual, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
 
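
Inside the kernels the change is confined to the row base address: bx * d becomes bx * stride_input (and stride_output / stride_residual for the other tensors); the intra-row offsets are untouched. A tiny Python model of that flat-offset arithmetic, purely illustrative:

    def vec_base_offsets(row, d, stride, vec_size):
        # One entry per VEC_SIZE-wide load in a row; the kernel spreads
        # these across threads, but the addresses are the same.
        return [row * stride + col for col in range(0, d, vec_size)]

    # Contiguous layout: stride == d, row 1 starts right after row 0.
    assert vec_base_offsets(1, d=8, stride=8, vec_size=4) == [8, 12]
    # Sliced layout: stride == 2 * d, row 1 starts 16 elements in.
    assert vec_base_offsets(1, d=8, stride=16, vec_size=4) == [16, 20]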

tests/test_norm.py

+36 −8
@@ -68,8 +68,14 @@ def fused_add_rms_norm(x, residual, weight, eps):
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("specify_out", [True, False])
-def test_norm(batch_size, hidden_size, dtype, specify_out):
-    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+@pytest.mark.parametrize("contiguous", [True, False])
+def test_norm(batch_size, hidden_size, dtype, specify_out, contiguous):
+    if contiguous:
+        x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    else:
+        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
+        x = x[:, :hidden_size]
+
     w = torch.randn(hidden_size).to(0).to(dtype)
 
     y_ref = llama_rms_norm(x, w)
@@ -85,10 +91,16 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
-def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+@pytest.mark.parametrize("contiguous", [True, False])
+def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, contiguous):
     eps = 1e-6
 
-    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    if contiguous:
+        x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    else:
+        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
+        x = x[:, :hidden_size]
+
     residual = torch.randn_like(x)
     weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
 
@@ -108,8 +120,14 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("specify_out", [True, False])
-def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
-    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+@pytest.mark.parametrize("contiguous", [True, False])
+def test_gemma_norm(batch_size, hidden_size, dtype, specify_out, contiguous):
+    if contiguous:
+        x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    else:
+        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
+        x = x[:, :hidden_size]
+
     w = torch.randn(hidden_size).to(0).to(dtype)
 
     y_ref = gemma_rms_norm(x, w)
@@ -125,10 +143,16 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
-def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+@pytest.mark.parametrize("contiguous", [True, False])
+def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype, contiguous):
     eps = 1e-6
 
-    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    if contiguous:
+        x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    else:
+        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
+        x = x[:, :hidden_size]
+
     residual = torch.randn_like(x)
     weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
 
@@ -142,3 +166,7 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
 
     torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_norm(1, 1024, torch.float16, False, True)
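
The same contiguous/non-contiguous setup block is repeated in all four tests; a small helper could centralize it. A sketch (the _make_input helper is hypothetical, not part of this commit):

    import torch

    def _make_input(batch_size, hidden_size, dtype, contiguous):
        if contiguous:
            return torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
        # Allocate double width and slice: the view is non-contiguous with
        # stride(0) == 2 * hidden_size, yet contiguous in its last dimension.
        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        return x[:, :hidden_size]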
