From 27c9625ba8faa17d7c163560eb9c1371ab88ec2d Mon Sep 17 00:00:00 2001
From: l1cacheDell
Date: Wed, 9 Apr 2025 17:22:22 +0800
Subject: [PATCH 1/4] upload triton kernel

---
 paddlenlp/ops/triton_ops/segment_mean.py | 205 +++++++++++++++++++++++
 1 file changed, 205 insertions(+)
 create mode 100644 paddlenlp/ops/triton_ops/segment_mean.py

diff --git a/paddlenlp/ops/triton_ops/segment_mean.py b/paddlenlp/ops/triton_ops/segment_mean.py
new file mode 100644
index 000000000000..6c52d2846467
--- /dev/null
+++ b/paddlenlp/ops/triton_ops/segment_mean.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Notice: this kernel is specifically implemented for the segment mean operation on a varlen qkv tensor.
+# For example, the k tensor is: [total_seqlen, num_head, head_dim],
+# where total_seqlen = seqlen 1 + seqlen 2 + ... + seqlen n.
+# The segment mean triton kernel performs the mean operation along the **seqlen** dim.
+# It finally produces a `[bsz, num_head, head_dim]`-shaped result,
+# holding the mean value of each seqlen segment.
+
+import paddle
+import triton.language as tl
+from paddle import _C_ops
+from paddle.base.framework import OpProtoHolder
+from paddle.base.layer_helper import LayerHelper
+from paddle.framework import in_dynamic_or_pir_mode
+
+from paddlenlp.ops.triton_ops.triton_utils import (
+    get_dtype_str,
+    paddle_use_triton,
+    rendering_common_template,
+)
+
+
+@paddle_use_triton(key=["num_heads", "head_dim"])
+def segmented_mean_reduce_kernel(
+    input_ptr,
+    output_ptr,
+    cu_seqlen_ptr,
+    num_batches,
+    num_heads,
+    head_dim,
+    input_stride_seq,
+    input_stride_head,
+    output_stride_batch,
+    output_stride_head,
+    BLOCK_SIZE_SEQ: tl.constexpr,
+    BLOCK_SIZE_HEAD: tl.constexpr,
+    BLOCK_SIZE_DIM: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    head_offset = tl.program_id(1) * BLOCK_SIZE_HEAD
+    dim_offset = tl.program_id(2) * BLOCK_SIZE_DIM
+
+    # Get the sequence range of the current segment
+    seq_start = tl.load(cu_seqlen_ptr + batch_idx)
+    seq_end = tl.load(cu_seqlen_ptr + batch_idx + 1)
+    seq_len = seq_end - seq_start
+
+    # Actual head and dim indices (relative positions within the block)
+    head_idx = head_offset + tl.arange(0, BLOCK_SIZE_HEAD)
+    dim_idx = dim_offset + tl.arange(0, BLOCK_SIZE_DIM)
+    mask_head = head_idx < num_heads
+    mask_dim = dim_idx < head_dim
+
+    # Initialize the accumulator (float32 precision)
+    acc = tl.zeros((BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM), dtype=tl.float32)
+
+    for seq_offset in range(0, seq_len, BLOCK_SIZE_SEQ):
+        local_seq_idx = tl.arange(0, BLOCK_SIZE_SEQ)
+        mask_seq = local_seq_idx < (seq_len - seq_offset)
+        global_seq = seq_start + seq_offset + local_seq_idx
+        # shape: [BLOCK_SIZE_SEQ, BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM]
+        input_ptrs = (
+            input_ptr
+            + global_seq[:, None, None] * input_stride_seq
+            + head_idx[None, :, None] * input_stride_head
+            + dim_idx[None, None, :]
+        )
+
+        # Load the input; note the input dtype (float16) is kept here to avoid unnecessary conversions
+        x = tl.load(
+            input_ptrs, mask=mask_seq[:, None, None] & mask_head[None, :, None] & mask_dim[None, None, :], other=0.0
+        ).to(tl.float32)
+
+        acc += tl.sum(x, axis=0)  # reduce over seq axis
+
+    mean = acc / tl.maximum(seq_len, 1)
+
+    # Build the output addresses
+    output_ptrs = (
+        output_ptr + batch_idx * output_stride_batch + head_idx[:, None] * output_stride_head + dim_idx[None, :]
+    )
+
+    tl.store(output_ptrs, mean.to(input_ptr.dtype.element_ty), mask=mask_head[:, None] & mask_dim[None, :])
+
+
+def segment_mean(
+    x: paddle.Tensor, cu_seqlen: paddle.Tensor  # [total_seqlen, num_heads, head_dim]  # [batch_size + 1]
+):
+    """
+    Examples:
+        import paddle
+        import paddlemix
+        cu_seqlens = [0, 1024, 2048, 4096]
+        total_seqlen = 4096
+        num_head = 24
+        head_dim = 128
+        k = paddle.randn([total_seqlen, num_head, head_dim], dtype="float16")
+        cu_seqlen = paddle.to_tensor(cu_seqlens, paddle.int32)
+        km = paddlemix.triton_ops.segment_mean(k, cu_seqlen)
+    """
+    num_batches = cu_seqlen.shape[0] - 1
+
+    num_heads = x.shape[1]
+    head_dim = x.shape[2]
+    # Compute the maximum sequence length of each batch
+    max_seqlen = (cu_seqlen[1:] - cu_seqlen[:-1]).max().item()
+    input_stride_seq = num_heads * head_dim
+    input_stride_head = head_dim
+
+    output_stride_batch = num_heads * head_dim
+    output_stride_head = head_dim
+    prepare_attr_for_triton_kernel = """
+    const int num_batches = cu_seqlen.shape()[0] - 1;
+    const int num_heads = x.shape()[1];
+    const int head_dim = x.shape()[2];
+    int input_stride_seq = num_heads * head_dim;
+    int input_stride_head = head_dim;
+    int output_stride_batch = num_heads * head_dim;
+    int output_stride_head = head_dim;
+    paddle::Tensor output_tensor = paddle::empty({num_batches, num_heads, head_dim}, x.dtype(), x.place());
+"""
+    op_name = "triton_segment_mean"
+    op_name += get_dtype_str(x.dtype)
+
+    # auto-tuning
+    segment_mean_configs = []
+    segment_mean_configs.append(
+        {"BLOCK_SIZE_SEQ": 128, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4}
+    )
+    segment_mean_configs.append(
+        {"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4}
+    )
+    segment_mean_configs.append(
+        {"BLOCK_SIZE_SEQ": 512, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 8}
+    )
+    segment_mean_configs.append(
+        {"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 128, "num_stages": 2, "num_warps": 4}
+    )
+
+    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
+        Output = paddle.empty([num_batches, num_heads, head_dim], dtype=x.dtype)
+        prepare_ptr_for_triton_kernel = """
+    // prepare tensor
+    CUdeviceptr input_ptrs[3] = {
+        get_tensor_ptr(x),
+        get_tensor_ptr(output_tensor),
+        get_tensor_ptr(cu_seqlen)
+    };
+"""
+        return_tensor_names = "output_tensor"
+        template_used = rendering_common_template(
+            segment_mean,
+            prepare_attr_for_triton_kernel,
+            prepare_ptr_for_triton_kernel,
+            return_tensor_names,
+        )
+
+        # Determine the kernel launch configuration
+        grid = (
+            "num_batches",
+            "(num_heads + BLOCK_SIZE_HEAD - 1) / BLOCK_SIZE_HEAD",
+            "(head_dim + BLOCK_SIZE_DIM - 1) / BLOCK_SIZE_DIM",
+        )
+
+        # Launch the kernel
+        segmented_mean_reduce_kernel[(op_name, template_used, grid, segment_mean_configs)](
+            input_ptr=x,
+            output_ptr=Output,
+            cu_seqlen_ptr=cu_seqlen,
+            num_batches=-1,
+            num_heads=num_heads,
+            head_dim=head_dim,
+            input_stride_seq=input_stride_seq,
+            input_stride_head=input_stride_head,
+            output_stride_batch=output_stride_batch,
+            output_stride_head=output_stride_head,
+        )
+
+    if in_dynamic_or_pir_mode():
+        outs = _C_ops._run_custom_op(op_name, x, cu_seqlen)
+        return outs[0]
+    else:
+        helper = LayerHelper(op_name, **locals())
+        inputs = {
+            "x": x,
+            "cu_seqlen_tensor": cu_seqlen,
+        }
+        output = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+        helper.append_op(type=op_name, inputs=inputs, outputs={"output_tensor": output})
+        return output

From d49d8e210415f0fc4ff00f243de021ed83f6276b Mon Sep 17 00:00:00 2001
From: l1cacheDell
Date: Mon, 14 Apr 2025 09:48:41 +0800
Subject: [PATCH 2/4] update kernel from review

---
 paddlenlp/ops/triton_ops/segment_mean.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/ops/triton_ops/segment_mean.py b/paddlenlp/ops/triton_ops/segment_mean.py
index 6c52d2846467..eb0885b75ccc 100644
--- a/paddlenlp/ops/triton_ops/segment_mean.py
+++ b/paddlenlp/ops/triton_ops/segment_mean.py
@@ -86,7 +86,7 @@ def segmented_mean_reduce_kernel(
 
         acc += tl.sum(x, axis=0)  # reduce over seq axis
 
-    mean = acc / tl.maximum(seq_len, 1)
+    mean = acc / seq_len
 
     # Build the output addresses
     output_ptrs = (
@@ -116,8 +116,7 @@ def segment_mean(
 
     num_heads = x.shape[1]
    head_dim = x.shape[2]
-    # Compute the maximum sequence length of each batch
-    max_seqlen = (cu_seqlen[1:] - cu_seqlen[:-1]).max().item()
+    # Compute the necessary strides
     input_stride_seq = num_heads * head_dim
     input_stride_head = head_dim

From 08291716fb60d4b09a6ad7098f5d52d8d67ccc41 Mon Sep 17 00:00:00 2001
From: l1cacheDell
Date: Mon, 14 Apr 2025 11:40:20 +0800
Subject: [PATCH 3/4] fix docstring

---
 paddlenlp/ops/triton_ops/segment_mean.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/ops/triton_ops/segment_mean.py b/paddlenlp/ops/triton_ops/segment_mean.py
index eb0885b75ccc..d6ff0d07e10b 100644
--- a/paddlenlp/ops/triton_ops/segment_mean.py
+++ b/paddlenlp/ops/triton_ops/segment_mean.py
@@ -102,14 +102,15 @@ def segment_mean(
     """
     Examples:
         import paddle
-        import paddlemix
+        from paddlenlp.ops.triton_ops.segment_mean import segment_mean
+
         cu_seqlens = [0, 1024, 2048, 4096]
         total_seqlen = 4096
         num_head = 24
         head_dim = 128
         k = paddle.randn([total_seqlen, num_head, head_dim], dtype="float16")
         cu_seqlen = paddle.to_tensor(cu_seqlens, paddle.int32)
-        km = paddlemix.triton_ops.segment_mean(k, cu_seqlen)
+        km = segment_mean(k, cu_seqlen)
     """
     num_batches = cu_seqlen.shape[0] - 1

From b630601eddb63068357e897f133356cafd60396e Mon Sep 17 00:00:00 2001
From: l1cacheDell
Date: Mon, 21 Apr 2025 17:38:40 +0800
Subject: [PATCH 4/4] update __init__

---
 paddlenlp/ops/triton_ops/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddlenlp/ops/triton_ops/__init__.py b/paddlenlp/ops/triton_ops/__init__.py
index 13c89a0f19c5..5b5d4c07233a 100644
--- a/paddlenlp/ops/triton_ops/__init__.py
+++ b/paddlenlp/ops/triton_ops/__init__.py
@@ -17,6 +17,7 @@
         per_token_group_quant_fp8_api,
         per_token_group_quant_fp8_api_masked,
     )
+    from .segment_mean import segment_mean
 
     __all__ += ["per_token_group_quant_fp8_api_masked", "per_token_group_quant_fp8_api"]
 except:
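
Note for reviewers: below is a minimal numerical sanity-check sketch for the new op. It assumes the patches above are applied so that segment_mean is importable from paddlenlp.ops.triton_ops (the export added in PATCH 4/4) and that a CUDA device is available; the helper name reference_segment_mean is purely illustrative and not part of the PR.

import paddle

from paddlenlp.ops.triton_ops import segment_mean


def reference_segment_mean(x, cu_seqlen):
    # Plain-Paddle reference: mean over each [start, end) slice of the seqlen dim.
    segments = []
    for i in range(cu_seqlen.shape[0] - 1):
        start, end = int(cu_seqlen[i]), int(cu_seqlen[i + 1])
        segments.append(x[start:end].mean(axis=0))  # [num_heads, head_dim]
    return paddle.stack(segments, axis=0)  # [num_batches, num_heads, head_dim]


cu_seqlen = paddle.to_tensor([0, 1024, 2048, 4096], dtype="int32")
k = paddle.randn([4096, 24, 128], dtype="float16")

out_triton = segment_mean(k, cu_seqlen)
out_ref = reference_segment_mean(k, cu_seqlen)
# Compare in float32; a loose tolerance accounts for float16 accumulation differences.
print(paddle.allclose(out_triton.astype("float32"), out_ref.astype("float32"), atol=1e-2))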