[Triton Kernel] Add varlen segment mean triton kernel #10369

Merged
merged 6 commits on Apr 23, 2025
1 change: 1 addition & 0 deletions paddlenlp/ops/triton_ops/__init__.py
@@ -17,6 +17,7 @@
per_token_group_quant_fp8_api,
per_token_group_quant_fp8_api_masked,
)
from .segment_mean import segment_mean

__all__ = ["per_token_group_quant_fp8_api_masked", "per_token_group_quant_fp8_api"]
except:
205 changes: 205 additions & 0 deletions paddlenlp/ops/triton_ops/segment_mean.py
@@ -0,0 +1,205 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Notice: this kernel is specifically implemented for the segment-mean operation on a varlen qkv tensor.
# For example, the k tensor has shape [total_seqlen, num_head, head_dim],
# where total_seqlen = seqlen_1 + seqlen_2 + ... + seqlen_n.
# The segment-mean Triton kernel therefore takes the mean along the **seqlen** dim,
# producing a result of shape `[bsz, num_head, head_dim]`,
# i.e. the mean value of each seqlen segment.
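# In other words, for each batch index b, head h and channel d:
#     out[b, h, d] = mean(x[cu_seqlen[b] : cu_seqlen[b + 1], h, d])
# e.g. with cu_seqlen = [0, 3, 5], out[0] averages rows 0..2 of x and out[1] averages rows 3..4.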

import paddle
import triton.language as tl
from paddle import _C_ops
from paddle.base.framework import OpProtoHolder
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_or_pir_mode

from paddlenlp.ops.triton_ops.triton_utils import (
get_dtype_str,
paddle_use_triton,
rendering_common_template,
)


@paddle_use_triton(key=["num_heads", "head_dim"])
def segmented_mean_reduce_kernel(
input_ptr,
output_ptr,
cu_seqlen_ptr,
num_batches,
num_heads,
head_dim,
input_stride_seq,
input_stride_head,
output_stride_batch,
output_stride_head,
BLOCK_SIZE_SEQ: tl.constexpr,
BLOCK_SIZE_HEAD: tl.constexpr,
BLOCK_SIZE_DIM: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_offset = tl.program_id(1) * BLOCK_SIZE_HEAD
dim_offset = tl.program_id(2) * BLOCK_SIZE_DIM

# get the sequence range of the current segment
seq_start = tl.load(cu_seqlen_ptr + batch_idx)
seq_end = tl.load(cu_seqlen_ptr + batch_idx + 1)
seq_len = seq_end - seq_start

# actual head and dim indices (positions within this block)
head_idx = head_offset + tl.arange(0, BLOCK_SIZE_HEAD)
dim_idx = dim_offset + tl.arange(0, BLOCK_SIZE_DIM)
mask_head = head_idx < num_heads
mask_dim = dim_idx < head_dim

# initialize the accumulator (float32 precision)
acc = tl.zeros((BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM), dtype=tl.float32)

for seq_offset in range(0, seq_len, BLOCK_SIZE_SEQ):
local_seq_idx = tl.arange(0, BLOCK_SIZE_SEQ)
mask_seq = local_seq_idx < (seq_len - seq_offset)
global_seq = seq_start + seq_offset + local_seq_idx

# shape: [BLOCK_SIZE_SEQ, BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM]
input_ptrs = (
input_ptr
+ global_seq[:, None, None] * input_stride_seq
+ head_idx[None, :, None] * input_stride_head
+ dim_idx[None, None, :]
)

# load the input (masked, in the input dtype, e.g. float16) and cast to float32 for accumulation
x = tl.load(
input_ptrs, mask=mask_seq[:, None, None] & mask_head[None, :, None] & mask_dim[None, None, :], other=0.0
).to(tl.float32)

acc += tl.sum(x, axis=0) # reduce over seq axis

mean = acc / seq_len

# build the output addresses
output_ptrs = (
output_ptr + batch_idx * output_stride_batch + head_idx[:, None] * output_stride_head + dim_idx[None, :]
)

tl.store(output_ptrs, mean.to(input_ptr.dtype.element_ty), mask=mask_head[:, None] & mask_dim[None, :])


def segment_mean(
x: paddle.Tensor, cu_seqlen: paddle.Tensor # [total_seqlen, num_heads, head_dim] # [batch_size + 1]
):
"""
Examples:
import paddle
from paddlenlp.ops.triton_ops.segment_mean import segment_mean

cu_seqlens = [0, 1024, 2048, 4096]
total_seqlen = 4096
num_head = 24
head_dim = 128
k = paddle.randn([total_seqlen, num_head, head_dim], dtype="float16")
cu_seqlen = paddle.to_tensor(cu_seqlens, paddle.int32)
km = segment_mean(k, cu_seqlen)
"""
num_batches = cu_seqlen.shape[0] - 1

num_heads = x.shape[1]
head_dim = x.shape[2]

# compute the required strides
input_stride_seq = num_heads * head_dim
input_stride_head = head_dim

output_stride_batch = num_heads * head_dim
output_stride_head = head_dim
prepare_attr_for_triton_kernel = """
const int num_batches = cu_seqlen.shape()[0] - 1;
const int num_heads = x.shape()[1];
const int head_dim = x.shape()[2];
int input_stride_seq = num_heads * head_dim;
int input_stride_head = head_dim;
int output_stride_batch = num_heads * head_dim;
int output_stride_head = head_dim;
paddle::Tensor output_tensor = paddle::empty({num_batches, num_heads, head_dim}, x.dtype(), x.place());
"""
op_name = "triton_segment_mean"
op_name += get_dtype_str(x.dtype)

# auto-tuning
segment_mean_configs = []
segment_mean_configs.append(
{"BLOCK_SIZE_SEQ": 128, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4}
)
segment_mean_configs.append(
{"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4}
)
segment_mean_configs.append(
{"BLOCK_SIZE_SEQ": 512, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 8}
)
segment_mean_configs.append(
{"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 128, "num_stages": 2, "num_warps": 4}
)

if op_name not in OpProtoHolder.instance().op_proto_map.keys():
Output = paddle.empty([num_batches, num_heads, head_dim], dtype=x.dtype)
prepare_ptr_for_triton_kernel = """
// prepare tensor
CUdeviceptr input_ptrs[3] = {
get_tensor_ptr(x),
get_tensor_ptr(output_tensor),
get_tensor_ptr(cu_seqlen)
};
"""
return_tensor_names = "output_tensor"
template_used = rendering_common_template(
segment_mean,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
return_tensor_names,
)

# determine the kernel launch configuration
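# note: the grid entries are expression strings (evaluated when the op is built), not Python ints;
# one program per batch, with heads and head_dim tiled by ceil-division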
grid = (
"num_batches",
"(num_heads + BLOCK_SIZE_HEAD - 1) / BLOCK_SIZE_HEAD",
"(head_dim + BLOCK_SIZE_DIM - 1) / BLOCK_SIZE_DIM",
)

# launch the kernel
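# this call is only reached when the op is not yet registered: it renders the C++ template,
# compiles the Triton kernel over segment_mean_configs and registers the custom op;
# the actual computation runs through _run_custom_op / append_op below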
segmented_mean_reduce_kernel[(op_name, template_used, grid, segment_mean_configs)](
input_ptr=x,
output_ptr=Output,
cu_seqlen_ptr=cu_seqlen,
num_batches=-1,
num_heads=num_heads,
head_dim=head_dim,
input_stride_seq=input_stride_seq,
input_stride_head=input_stride_head,
output_stride_batch=output_stride_batch,
output_stride_head=output_stride_head,
)

if in_dynamic_or_pir_mode():
outs = _C_ops._run_custom_op(op_name, x, cu_seqlen)
return outs[0]
else:
helper = LayerHelper(op_name, **locals())
inputs = {
"x": x,
"cu_seqlen_tensor": cu_seqlen,
}
output = helper.create_variable_for_type_inference(dtype=x.dtype)

helper.append_op(type=op_name, inputs=inputs, outputs={"output_tensor": output})
return output

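For reference, a minimal way to exercise the new op and sanity-check its output (a sketch, not part of the PR; it assumes a CUDA device with Triton available and this PaddleNLP build importable, and the `ref` computation below is a hypothetical pure-Paddle comparison written only for this example):

import paddle

from paddlenlp.ops.triton_ops.segment_mean import segment_mean

cu_seqlens = [0, 1024, 2048, 4096]
k = paddle.randn([4096, 24, 128]).astype("float16")  # [total_seqlen, num_head, head_dim]
cu_seqlen = paddle.to_tensor(cu_seqlens, dtype="int32")  # [batch_size + 1]

# Triton path: result has shape [3, 24, 128], one mean per (segment, head, channel)
km = segment_mean(k, cu_seqlen)

# hypothetical pure-Paddle reference, accumulated in float32 like the kernel
ref = paddle.stack(
    [k[s:e].astype("float32").mean(axis=0) for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]
)

print(paddle.allclose(km.astype("float32"), ref, atol=1e-2))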