Commit 323ffec

[XPU] support absorb mla, separate the model networking from GPU (#10353)
* Standalone model networking that is not shared with the GPU path; the operator migration is still pending
* update csrc
* fix
* fix
* remove redundant code
* support infer param
* fix save output
* add warm script
* split out the model networking
* fix style
* update setup and remove useless files
* fused block attn
* hack implementation of int_wo MOE
* multi-threaded mmlu script
* mmlu 8K output
* add inference time for benchmark
* fix warm max_dec_len
* Revert "mmlu 8K output". This reverts commit 0c89ea0.
* Revert "multi-threaded mmlu script". This reverts commit fe6bfc5.
* add high speed moe
* change rope cos sin to cpu
* fix save output
1 parent 2211480 commit 323ffec


44 files changed: +5,774 −1,053 lines

.gitignore

+1 −1

@@ -140,4 +140,4 @@ autogen/
 #fp8
 ops/csrc/fp8/deep_gemm/include/cutlass
 ops/csrc/fp8/deep_gemm/include/cute
-
+.ccls-cache
New file (+219 lines)

@@ -0,0 +1,219 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>

#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#include <core/ctx_manager.h>
#include <core/xft_check.h>
#include <core/xft_event.h>
#include <core/xft_params.h>
#include <xft/xdnn_plugin.h>
#include <xft/operation/page_attn.h>
#include <xft/operation/fmha.h>
#include <flash_api.h>  // link xfa
#include "ops.h"

namespace xftkernel = baidu::xpu::xftkernel;

// Maps bfloat16 to float for the TL type; every other type maps to itself.
template <typename T>
struct kl3_pa_TL_trait {
    using TL = T;
};
template <>
struct kl3_pa_TL_trait<bfloat16> {
    using TL = float;
};

std::vector<paddle::Tensor> MlaDeAttn(
    const paddle::Tensor& q,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& decoder_context_len,
    const paddle::Tensor& decoder_batch_map,
    const paddle::Tensor& decoder_context_len_cpu,
    const paddle::Tensor& decoder_batch_map_cpu,
    const paddle::Tensor& dec_batch_tensor,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const float softmax_scale,
    const int block_size,
    const int num_head,
    const int kv_lora_rank,
    const int rope_head_dim,
    const int dim_qk,
    const int dim_v) {
    baidu::xpu::api::plugin::print_times("[TIME BEGIN] MlaDeAttn");
    phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
    auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
    auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
    xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());

    using QType = typename XPUTypeTrait<bfloat16>::Type;
    using CacheType = typename XPUTypeTrait<bfloat16>::Type;
    typedef paddle::bfloat16 qdata_t, cache_t;
    const auto& input_dims = q.dims();
    const int bsz = cum_offsets.dims()[0];
    const int token_num = input_dims[0];
    const int block_batch = block_tables.dims()[0];  // TODO: clarify parameter meaning (block_batch_, PageParam page_param_)
    const int max_block_per_seq = block_tables.dims()[1];
    const int max_seq_len = block_size * max_block_per_seq;
    int dec_batch = dec_batch_tensor.data<int32_t>()[0];
    // Initialize input: q
    auto q_xft = baidu::xpu::xft::xftTensor<QType, 3>(
        reinterpret_cast<QType*>(const_cast<paddle::bfloat16*>(q.data<qdata_t>())),
        std::array<int64_t, 3>{q.shape()[0], q.shape()[1], q.shape()[2]});
    // Initialize input: kv cache
    auto kv_cache_xft = baidu::xpu::xft::xftTensor<CacheType, 4>(
        reinterpret_cast<CacheType*>(const_cast<paddle::bfloat16*>(kv_cache.data<cache_t>())),
        std::array<int64_t, 4>{kv_cache.shape()[0],
                               kv_cache.shape()[1],
                               kv_cache.shape()[2],
                               kv_cache.shape()[3]});
    // Initialize input: block table
    auto block_tables_xft = baidu::xpu::xft::xftTensor<int, 2>(
        reinterpret_cast<int*>(const_cast<int*>(block_tables.data<int>())),
        std::array<int64_t, 2>{block_tables.shape()[0], block_tables.shape()[1]});
    // Initialize output tensor
    auto fmha_out = paddle::full({q.shape()[0], num_head * kv_lora_rank}, -2, q.type(), q.place());
    auto fmha_out_xft = baidu::xpu::xft::xftTensor<QType, 2>(
        reinterpret_cast<QType*>(const_cast<paddle::bfloat16*>(fmha_out.data<qdata_t>())),
        std::array<int64_t, 2>{fmha_out.shape()[0], fmha_out.shape()[1]});

    // decoder
    if (dec_batch > 0) {
        // context_len
        baidu::xpu::api::VectorParam<int32_t> context_len_vp{
            const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
            dec_batch,
            const_cast<int32_t*>(decoder_context_len.data<int32_t>())};
        // real batch
        baidu::xpu::api::VectorParam<int32_t> valid_batch_vp{
            const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
            dec_batch,
            const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};

        // multi_latent_attention
        using TQ = bfloat16;
        using TKVCACHE = bfloat16;
        using TO = TQ;
        using TGEMM = float;
        using TEW = float;
        using TID = int;
        constexpr int quant_mode = 0;
        // xpu_ctx->x_context().set_debug_level(0xa1);
        int ret = baidu::xpu::xfa::multi_latent_attention<
            TQ,
            TKVCACHE,
            TO,
            TGEMM,
            TEW,
            TID,
            quant_mode>(
            xpu_ctx->x_context(),
            fmha_out_xft.data(),
            q_xft.data(),
            kv_cache_xft.data(),
            block_tables_xft.data(),
            context_len_vp,
            valid_batch_vp,
            block_batch,
            max_seq_len,
            num_head,
            kv_lora_rank,
            rope_head_dim,
            nullptr,  // attn_mask
            softmax_scale,  // 0.13523377478122711f, // scale
            block_size,
            max_block_per_seq,
            -1,
            nullptr,
            nullptr,
            nullptr);
    }
    baidu::xpu::api::plugin::print_times("[TIME END] MlaDeAttn");
    return {fmha_out};
}

std::vector<std::vector<int64_t>> MlaDeAttnInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& kv_cache_shape,
    const std::vector<int64_t>& decoder_context_len_shape,
    const std::vector<int64_t>& decoder_batch_map_shape,
    const std::vector<int64_t>& decoder_context_len_cpu_shape,
    const std::vector<int64_t>& decoder_batch_map_cpu_shape,
    const std::vector<int64_t>& dec_batch_tensor_shape,
    const std::vector<int64_t>& padding_offsets_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& block_tables_shape,
    const float softmax_scale,
    const int block_size,
    const int num_head,
    const int kv_lora_rank,
    const int rope_head_dim,
    const int dim_qk,
    const int dim_v) {
    return {{q_shape[0], num_head * kv_lora_rank}};
}

std::vector<paddle::DataType> MlaDeAttnInferDtype(
    const paddle::DataType& q_dtype,
    const paddle::DataType& kv_cache_dtype,
    const paddle::DataType& decoder_context_len_dtype,
    const paddle::DataType& decoder_batch_map_dtype,
    const paddle::DataType& decoder_context_len_cpu_dtype,
    const paddle::DataType& decoder_batch_map_cpu_dtype,
    const paddle::DataType& dec_batch_tensor_dtype,
    const paddle::DataType& padding_offsets_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& block_tables_dtype,
    const float softmax_scale,
    const int block_size,
    const int num_head,
    const int kv_lora_rank,
    const int rope_head_dim,
    const int dim_qk,
    const int dim_v) {
    if (q_dtype == paddle::DataType::FLOAT16) {
        return {paddle::DataType::FLOAT16};
    } else if (q_dtype == paddle::DataType::BFLOAT16) {
        return {paddle::DataType::BFLOAT16};
    } else {
        PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
    }
}

PD_BUILD_OP(absorb_mla_block_mha_decoder_xpu)
    .Inputs({"q",
             "kv_cache",
             "decoder_context_len",
             "decoder_batch_map",
             "decoder_context_len_cpu",
             "decoder_batch_map_cpu",
             "dec_batch_tensor",
             "padding_offsets",
             "cum_offsets",
             "block_tables"})
    .Outputs({"fmha_out"})
    .Attrs({"softmax_scale: float",
            "block_size: int",
            "num_head: int",
            "kv_lora_rank: int",
            "rope_head_dim: int",
            "dim_qk: int",
            "dim_v: int"})
    .SetKernelFn(PD_KERNEL(MlaDeAttn))
    .SetInferShapeFn(PD_INFER_SHAPE(MlaDeAttnInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MlaDeAttnInferDtype));
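The decoder branch above addresses the KV cache through block_tables: each sequence owns at most max_block_per_seq blocks of block_size tokens, so the kernel derives max_seq_len = block_size * max_block_per_seq before handing the table to multi_latent_attention. The standalone C++ sketch below illustrates that paged addressing; it is not part of the commit and assumes the usual paged-attention convention that each row of block_tables maps a sequence's logical block index to a physical block id.

// Standalone illustration (not from this commit): indexing a paged KV cache
// through a block table, using the same parameter names as MlaDeAttn.
#include <cstdio>
#include <vector>

int main() {
    const int block_size = 64;        // tokens per KV-cache block (example value)
    const int max_block_per_seq = 4;  // blocks reserved per sequence (example value)
    const int max_seq_len = block_size * max_block_per_seq;  // as computed in MlaDeAttn

    // One row of block_tables: logical block index -> physical block id (assumed convention).
    std::vector<int> block_table = {7, 2, 9, 5};

    // Locate the last cached token of a sequence that holds context_len tokens.
    const int context_len = 130;
    const int pos = context_len - 1;
    const int logical_block = pos / block_size;    // which block of this sequence
    const int in_block_offset = pos % block_size;  // slot inside that block
    const int physical_block = block_table[logical_block];

    std::printf("max_seq_len=%d; token %d -> physical block %d, offset %d\n",
                max_seq_len, pos, physical_block, in_block_offset);
    return 0;
}

Once built, the registered op absorb_mla_block_mha_decoder_xpu takes the ten input tensors and seven attributes listed in PD_BUILD_OP and returns fmha_out of shape [token_num, num_head * kv_lora_rank], matching MlaDeAttnInferShape.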
