# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Notice: this kernel implements the segment mean operation for a varlen qkv tensor.
# For example, the k tensor has shape [total_seqlen, num_head, head_dim],
# where total_seqlen = seqlen_1 + seqlen_2 + ... + seqlen_n.
# The segment mean Triton kernel averages along the **seqlen** dim,
# producing a result of shape [bsz, num_head, head_dim] that holds
# the mean value of each seqlen segment.
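#
# For reference, the kernel computes the same result as the eager-mode sketch below
# (illustrative only, not executed here; `k` and `cu_seqlen` are example names):
#
#     means = []
#     for b in range(cu_seqlen.shape[0] - 1):
#         start, end = int(cu_seqlen[b]), int(cu_seqlen[b + 1])
#         means.append(k[start:end].mean(axis=0))  # mean over the seqlen dim
#     km = paddle.stack(means)  # [bsz, num_head, head_dim]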

import paddle
import triton.language as tl
from paddle import _C_ops
from paddle.base.framework import OpProtoHolder
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_or_pir_mode

from paddlenlp.ops.triton_ops.triton_utils import (
    get_dtype_str,
    paddle_use_triton,
    rendering_common_template,
)


@paddle_use_triton(key=["num_heads", "head_dim"])
def segmented_mean_reduce_kernel(
    input_ptr,
    output_ptr,
    cu_seqlen_ptr,
    num_batches,
    num_heads,
    head_dim,
    input_stride_seq,
    input_stride_head,
    output_stride_batch,
    output_stride_head,
    BLOCK_SIZE_SEQ: tl.constexpr,
    BLOCK_SIZE_HEAD: tl.constexpr,
    BLOCK_SIZE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_offset = tl.program_id(1) * BLOCK_SIZE_HEAD
    dim_offset = tl.program_id(2) * BLOCK_SIZE_DIM

    # Load the [start, end) range of the current segment from the cumulative lengths.
    seq_start = tl.load(cu_seqlen_ptr + batch_idx)
    seq_end = tl.load(cu_seqlen_ptr + batch_idx + 1)
    seq_len = seq_end - seq_start

    # Absolute head/dim indices covered by this block (block offset + position within the block).
    head_idx = head_offset + tl.arange(0, BLOCK_SIZE_HEAD)
    dim_idx = dim_offset + tl.arange(0, BLOCK_SIZE_DIM)
    mask_head = head_idx < num_heads
    mask_dim = dim_idx < head_dim

    # Accumulate in float32 for numerical stability.
    acc = tl.zeros((BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM), dtype=tl.float32)

    for seq_offset in range(0, seq_len, BLOCK_SIZE_SEQ):
        local_seq_idx = tl.arange(0, BLOCK_SIZE_SEQ)
        mask_seq = local_seq_idx < (seq_len - seq_offset)
        global_seq = seq_start + seq_offset + local_seq_idx
        # shape: [BLOCK_SIZE_SEQ, BLOCK_SIZE_HEAD, BLOCK_SIZE_DIM]
        input_ptrs = (
            input_ptr
            + global_seq[:, None, None] * input_stride_seq
            + head_idx[None, :, None] * input_stride_head
            + dim_idx[None, None, :]
        )

        # Load the input tile (out-of-range positions read as 0.0) and upcast to float32.
        x = tl.load(
            input_ptrs, mask=mask_seq[:, None, None] & mask_head[None, :, None] & mask_dim[None, None, :], other=0.0
        ).to(tl.float32)

        acc += tl.sum(x, axis=0)  # reduce over the seq axis

    mean = acc / seq_len

    # Compute the output addresses for this (batch, head-block, dim-block) tile.
    output_ptrs = (
        output_ptr + batch_idx * output_stride_batch + head_idx[:, None] * output_stride_head + dim_idx[None, :]
    )

    tl.store(output_ptrs, mean.to(input_ptr.dtype.element_ty), mask=mask_head[:, None] & mask_dim[None, :])


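# The Python wrapper below follows the paddlenlp triton_ops pattern: on the first call
# it renders the C++ custom-op template and compiles/registers the kernel, and on every
# call it dispatches through _C_ops._run_custom_op (dynamic/PIR mode) or
# LayerHelper.append_op (static graph).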
def segment_mean(
    x: paddle.Tensor,  # [total_seqlen, num_heads, head_dim]
    cu_seqlen: paddle.Tensor,  # [batch_size + 1], cumulative sequence lengths
):
    """
    Compute the mean of `x` over the seqlen dim for each segment defined by `cu_seqlen`.

    Examples:
        import paddle
        from paddlenlp.ops.triton_ops.segment_mean import segment_mean

        cu_seqlens = [0, 1024, 2048, 4096]
        total_seqlen = 4096
        num_head = 24
        head_dim = 128
        k = paddle.randn([total_seqlen, num_head, head_dim], dtype="float16")
        cu_seqlen = paddle.to_tensor(cu_seqlens, paddle.int32)
        km = segment_mean(k, cu_seqlen)  # km.shape == [3, 24, 128]
    """
    num_batches = cu_seqlen.shape[0] - 1

    num_heads = x.shape[1]
    head_dim = x.shape[2]

    # Compute the strides needed by the kernel (x and the output are contiguous).
    input_stride_seq = num_heads * head_dim
    input_stride_head = head_dim

    output_stride_batch = num_heads * head_dim
    output_stride_head = head_dim
    prepare_attr_for_triton_kernel = """
    const int num_batches = cu_seqlen.shape()[0] - 1;
    const int num_heads = x.shape()[1];
    const int head_dim = x.shape()[2];
    int input_stride_seq = num_heads * head_dim;
    int input_stride_head = head_dim;
    int output_stride_batch = num_heads * head_dim;
    int output_stride_head = head_dim;
    paddle::Tensor output_tensor = paddle::empty({num_batches, num_heads, head_dim}, x.dtype(), x.place());
"""
    op_name = "triton_segment_mean"
    op_name += get_dtype_str(x.dtype)

    # auto-tuning configurations
    segment_mean_configs = []
    segment_mean_configs.append(
        {"BLOCK_SIZE_SEQ": 128, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4}
    )
    segment_mean_configs.append(
        {"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 4, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 4}
    )
    segment_mean_configs.append(
        {"BLOCK_SIZE_SEQ": 512, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 64, "num_stages": 2, "num_warps": 8}
    )
    segment_mean_configs.append(
        {"BLOCK_SIZE_SEQ": 256, "BLOCK_SIZE_HEAD": 8, "BLOCK_SIZE_DIM": 128, "num_stages": 2, "num_warps": 4}
    )

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        Output = paddle.empty([num_batches, num_heads, head_dim], dtype=x.dtype)
        prepare_ptr_for_triton_kernel = """
    // prepare tensor
    CUdeviceptr input_ptrs[3] = {
        get_tensor_ptr(x),
        get_tensor_ptr(output_tensor),
        get_tensor_ptr(cu_seqlen)
    };
"""
        return_tensor_names = "output_tensor"
        template_used = rendering_common_template(
            segment_mean,
            prepare_attr_for_triton_kernel,
            prepare_ptr_for_triton_kernel,
            return_tensor_names,
        )

        # Launch grid: one program per (batch, head-block, dim-block).
        grid = (
            "num_batches",
            "(num_heads + BLOCK_SIZE_HEAD - 1) / BLOCK_SIZE_HEAD",
            "(head_dim + BLOCK_SIZE_DIM - 1) / BLOCK_SIZE_DIM",
        )

        # Compile and register the kernel; num_batches=-1 is a placeholder whose real
        # value is supplied at runtime by prepare_attr_for_triton_kernel.
        segmented_mean_reduce_kernel[(op_name, template_used, grid, segment_mean_configs)](
            input_ptr=x,
            output_ptr=Output,
            cu_seqlen_ptr=cu_seqlen,
            num_batches=-1,
            num_heads=num_heads,
            head_dim=head_dim,
            input_stride_seq=input_stride_seq,
            input_stride_head=input_stride_head,
            output_stride_batch=output_stride_batch,
            output_stride_head=output_stride_head,
        )

    if in_dynamic_or_pir_mode():
        outs = _C_ops._run_custom_op(op_name, x, cu_seqlen)
        return outs[0]
    else:
        helper = LayerHelper(op_name, **locals())
        inputs = {
            "x": x,
            "cu_seqlen_tensor": cu_seqlen,
        }
        output = helper.create_variable_for_type_inference(dtype=x.dtype)

        helper.append_op(type=op_name, inputs=inputs, outputs={"output_tensor": output})
        return output
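

if __name__ == "__main__":
    # Minimal sanity-check sketch, assuming a CUDA device is available and the custom op
    # can be JIT-compiled on this machine. It compares the Triton result against an
    # eager-mode per-segment mean; the shapes below are illustrative only.
    cu = paddle.to_tensor([0, 1024, 2048, 4096], dtype="int32")
    k = paddle.randn([4096, 24, 128], dtype="float16")
    km = segment_mean(k, cu)

    ref = paddle.stack(
        [k[int(cu[i]) : int(cu[i + 1])].astype("float32").mean(axis=0) for i in range(cu.shape[0] - 1)]
    )
    print("max abs diff:", float((km.astype("float32") - ref).abs().max()))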