Skip to content

Commit 345ff57

Browse files
authored
[Triton Kernel] Add varlen segment mean triton kernel (#10369)
* upload triton kernel * update kernel from review * fix docstring * update __init__
1 parent 53273ad commit 345ff57

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

paddlenlp/ops/triton_ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
per_token_group_quant_fp8_api,
1818
per_token_group_quant_fp8_api_masked,
1919
)
20+
from .segment_mean import segment_mean
2021

2122
__all__ = +["per_token_group_quant_fp8_api_masked", "per_token_group_quant_fp8_api"]
2223
except:
+205
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Note: this kernel implements the segment-mean operation specifically for variable-length (varlen) q/k/v tensors.
16+
# For example, the k tensor has shape [total_seqlen, num_head, head_dim],
17+
# where total_seqlen = seqlen_1 + seqlen_2 + ... + seqlen_n.
18+
# The segment-mean Triton kernel averages along the **seqlen** dimension, separately for each segment.
19+
# The result has shape `[bsz, num_head, head_dim]`,
20+
# where each row holds the mean of the corresponding seqlen segment.
21+
22+
import paddle
23+
import triton.language as tl
24+
from paddle import _C_ops
25+
from paddle.base.framework import OpProtoHolder
26+
from paddle.base.layer_helper import LayerHelper
27+
from paddle.framework import in_dynamic_or_pir_mode
28+
29+
from paddlenlp.ops.triton_ops.triton_utils import (
30+
get_dtype_str,
31+
paddle_use_triton,
32+
rendering_common_template,
33+
)
34+
35+
36+
# Triton kernel: per-segment mean over the sequence axis of a packed
# [total_seqlen, num_heads, head_dim] tensor.
#
# Each program handles one (segment, head tile, head_dim tile) triple:
#   program axis 0 -> segment index (reads cu_seqlen_ptr[i] and cu_seqlen_ptr[i + 1])
#   program axis 1 -> tile of BLOCK_SIZE_HEAD heads
#   program axis 2 -> tile of BLOCK_SIZE_DIM channels
#
# The innermost (head_dim) axis of both input and output is addressed with an
# implicit stride of 1. `num_batches` is part of the signature but is not read
# inside the kernel body.
#
# NOTE: for an empty segment (seq_len == 0) the loop body never runs and the
# final division is 0 / 0, so the stored value is not finite (presumably NaN).
@paddle_use_triton(key=["num_heads", "head_dim"])
def segmented_mean_reduce_kernel(
    input_ptr,
    output_ptr,
    cu_seqlen_ptr,
    num_batches,
    num_heads,
    head_dim,
    input_stride_seq,
    input_stride_head,
    output_stride_batch,
    output_stride_head,
    BLOCK_SIZE_SEQ: tl.constexpr,
    BLOCK_SIZE_HEAD: tl.constexpr,
    BLOCK_SIZE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_offset = tl.program_id(1) * BLOCK_SIZE_HEAD
    dim_offset = tl.program_id(2) * BLOCK_SIZE_DIM

    # Row range [seq_start, seq_end) of the current segment in the packed input.
    seq_start = tl.load(cu_seqlen_ptr + batch_idx)
    seq_end = tl.load(cu_seqlen_ptr + batch_idx + 1)
    seq_len = seq_end - seq_start

    # Absolute head / channel indices covered by this tile, plus bounds masks
    # for tiles that run past num_heads / head_dim.
    head_idx = head_offset + tl.arange(0, BLOCK_SIZE_HEAD)
    dim_idx = dim_offset + tl.arange(0, BLOCK_SIZE_DIM)
    mask_head = head_idx < num_heads
    mask_dim = dim_idx < head_dim

    # Loop-invariant pieces, hoisted out of the sequence loop.
    local_seq_idx = tl.arange(0, BLOCK_SIZE_SEQ)
    mask_tile = mask_head[None, :, None] & mask_dim[None, None, :]

    # Accumulate in float32 regardless of the input dtype.
    acc = tl.zeros((BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM), dtype=tl.float32)

    for seq_offset in range(0, seq_len, BLOCK_SIZE_SEQ):
        # Rows of this chunk that still belong to the segment.
        mask_seq = local_seq_idx < (seq_len - seq_offset)
        # Widen the row index to int64 before multiplying by the row stride:
        # row * (num_heads * head_dim) can exceed the int32 range for large
        # packed inputs and would otherwise wrap around.
        global_seq = (seq_start + seq_offset + local_seq_idx).to(tl.int64)
        # shape: [BLOCK_SIZE_SEQ, BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM]
        input_ptrs = (
            input_ptr
            + global_seq[:, None, None] * input_stride_seq
            + head_idx[None, :, None] * input_stride_head
            + dim_idx[None, None, :]
        )

        # Load in the input's own dtype (masked lanes read 0.0), then upcast
        # to float32 for the accumulation.
        x = tl.load(input_ptrs, mask=mask_seq[:, None, None] & mask_tile, other=0.0).to(tl.float32)

        acc += tl.sum(x, axis=0)  # reduce over the sequence axis

    mean = acc / seq_len

    # Output addresses for this (segment, head tile, channel tile).
    output_ptrs = (
        output_ptr + batch_idx * output_stride_batch + head_idx[:, None] * output_stride_head + dim_idx[None, :]
    )

    # Cast back to the input element type before storing.
    tl.store(output_ptrs, mean.to(input_ptr.dtype.element_ty), mask=mask_head[:, None] & mask_dim[None, :])
97+
98+
99+
def segment_mean(
    x: paddle.Tensor,  # [total_seqlen, num_heads, head_dim]
    cu_seqlen: paddle.Tensor,  # [batch_size + 1]
):
    """Compute the mean of each variable-length segment of a packed tensor.

    ``x`` packs several sequences back to back along its first axis and
    ``cu_seqlen`` holds the cumulative boundaries of those sequences. The
    result contains, for every segment, the mean over its rows.

    Strides are derived from ``x.shape`` only, so ``x`` is treated as a densely
    packed row-major tensor.

    Args:
        x: Packed input of shape ``[total_seqlen, num_heads, head_dim]``.
        cu_seqlen: Cumulative sequence offsets of shape ``[batch_size + 1]``;
            segment ``i`` spans rows ``cu_seqlen[i]`` to ``cu_seqlen[i + 1]``
            (end exclusive).

    Returns:
        A tensor of shape ``[batch_size, num_heads, head_dim]`` with the same
        dtype as ``x``.

    Examples:
        import paddle
        from paddlenlp.ops.triton_ops.segment_mean import segment_mean

        cu_seqlens = [0, 1024, 2048, 4096]
        total_seqlen = 4096
        num_head = 24
        head_dim = 128
        k = paddle.randn([total_seqlen, num_head, head_dim], dtype="float16")
        cu_seqlen = paddle.to_tensor(cu_seqlens, paddle.int32)
        km = segment_mean(k, cu_seqlen)
    """
    num_batches = cu_seqlen.shape[0] - 1

    num_heads = x.shape[1]
    head_dim = x.shape[2]

    # Element strides of the (assumed contiguous) input and output.
    input_stride_seq = num_heads * head_dim
    input_stride_head = head_dim

    output_stride_batch = num_heads * head_dim
    output_stride_head = head_dim

    # C++ prologue rendered into the generated custom op: it recomputes the
    # shape-derived scalars at run time and allocates the output tensor.
    prepare_attr_for_triton_kernel = """
    const int num_batches = cu_seqlen.shape()[0] - 1;
    const int num_heads = x.shape()[1];
    const int head_dim = x.shape()[2];
    int input_stride_seq = num_heads * head_dim;
    int input_stride_head = head_dim;
    int output_stride_batch = num_heads * head_dim;
    int output_stride_head = head_dim;
    paddle::Tensor output_tensor = paddle::empty({num_batches, num_heads, head_dim}, x.dtype(), x.place());
    """

    # One custom op is registered per input dtype.
    op_name = "triton_segment_mean"
    op_name += get_dtype_str(x.dtype)

    # Candidate launch configurations for auto-tuning.
    segment_mean_configs = [
        {"BLOCK_SIZE_SEQ": 128, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4},
        {"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4},
        {"BLOCK_SIZE_SEQ": 512, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 8},
        {"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 128, "num_stages": 2, "num_warps": 4},
    ]

    # Build and register the op only the first time this dtype is seen.
    if op_name not in OpProtoHolder.instance().op_proto_map:
        # Placeholder output used only for this kernel-building call; the
        # tensor returned to the caller is allocated by the C++ prologue.
        output_placeholder = paddle.empty([num_batches, num_heads, head_dim], dtype=x.dtype)
        prepare_ptr_for_triton_kernel = """
        // prepare tensor
        CUdeviceptr input_ptrs[3] = {
            get_tensor_ptr(x),
            get_tensor_ptr(output_tensor),
            get_tensor_ptr(cu_seqlen)
        };
        """
        return_tensor_names = "output_tensor"
        template_used = rendering_common_template(
            segment_mean,
            prepare_attr_for_triton_kernel,
            prepare_ptr_for_triton_kernel,
            return_tensor_names,
        )

        # Launch grid, written as C++ expressions evaluated in the generated op:
        # one program per (segment, head tile, head_dim tile).
        grid = (
            "num_batches",
            "(num_heads + BLOCK_SIZE_HEAD - 1) / BLOCK_SIZE_HEAD",
            "(head_dim + BLOCK_SIZE_DIM - 1) / BLOCK_SIZE_DIM",
        )

        # NOTE(review): num_batches is passed as -1 while the real value is
        # computed in the C++ prologue; presumably a placeholder so the kernel
        # is not specialized on the batch count. Confirm against paddle_use_triton.
        segmented_mean_reduce_kernel[(op_name, template_used, grid, segment_mean_configs)](
            input_ptr=x,
            output_ptr=output_placeholder,
            cu_seqlen_ptr=cu_seqlen,
            num_batches=-1,
            num_heads=num_heads,
            head_dim=head_dim,
            input_stride_seq=input_stride_seq,
            input_stride_head=input_stride_head,
            output_stride_batch=output_stride_batch,
            output_stride_head=output_stride_head,
        )

    if in_dynamic_or_pir_mode():
        outs = _C_ops._run_custom_op(op_name, x, cu_seqlen)
        return outs[0]
    else:
        helper = LayerHelper(op_name, **locals())
        # Input names follow the wrapper's parameter names, which are also the
        # tensor names used by the C++ snippets above (`x`, `cu_seqlen`). The
        # second key was previously "cu_seqlen_tensor", which matches neither.
        inputs = {
            "x": x,
            "cu_seqlen": cu_seqlen,
        }
        output = helper.create_variable_for_type_inference(dtype=x.dtype)

        helper.append_op(type=op_name, inputs=inputs, outputs={"output_tensor": output})
        return output

0 commit comments

Comments
 (0)