
Commit 0547781

feat: apply sm_scale at logits instead of q in FA2 template (#801)
This PR is the second of three steps in addressing the bf16 kernel correctness issue for the DeepSeek model.
1 parent 32388d0 commit 0547781

15 files changed: +259 -372 lines
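At a glance, the change moves the softmax scaling from the query vector to the attention logits. Below is a minimal host-side sketch of the two schemes; the names sm_scale_log2, QueryTransform, and LogitsTransform mirror the kernel code further down, while the toy functions themselves are illustrative only.

#include <cstdint>

// Toy contrast of the old and new scaling placement (not the real kernels).
struct ScaleSketch {
  float sm_scale_log2;  // sm_scale * log2(e), see decode.cuh below

  // Old: the scale was baked into q up front (QueryTransform, "multiple q_vec
  // by sm_scale"), so any rounding of the scaled query affects every k.
  float old_scheme(const float* q, const float* k, uint32_t head_dim) const {
    float s = 0.f;
    for (uint32_t i = 0; i < head_dim; ++i) s += (q[i] * sm_scale_log2) * k[i];
    return s;
  }

  // New: q stays unscaled; the fp32 logit is scaled once, after the variant's
  // LogitsTransform hook has seen the raw q·k value.
  float new_scheme(const float* q, const float* k, uint32_t head_dim) const {
    float s = 0.f;
    for (uint32_t i = 0; i < head_dim; ++i) s += q[i] * k[i];
    return s * sm_scale_log2;
  }
};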

csrc/batch_decode_customize_config.jinja (+1)

@@ -3,6 +3,7 @@
 #include <flashinfer/math.cuh>
 #include <flashinfer/layout.cuh>
 #include <flashinfer/pos_enc.cuh>
+#include <flashinfer/attention/variant_helper.cuh>
 
 #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
 #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

csrc/batch_prefill_customize_config.jinja (+1)

@@ -5,6 +5,7 @@
 #include <flashinfer/utils.cuh>
 #include <flashinfer/pos_enc.cuh>
 #include <flashinfer/fastdiv.cuh>
+#include <flashinfer/attention/variant_helper.cuh>
 
 #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
 #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

csrc/batch_prefill_sm90_customize_config.jinja (+1)

@@ -1,5 +1,6 @@
 #pragma once
 #include <flashinfer/attention/hopper/attention_updater.cuh>
+#include <flashinfer/attention/hopper/variant_helper.cuh>
 #include <flashinfer/math.cuh>
 #include <flashinfer/layout.cuh>
 #include <flashinfer/cutlass_utils.cuh>

csrc/single_decode_customize_config.jinja (+1)

@@ -2,6 +2,7 @@
 #include <flashinfer/math.cuh>
 #include <flashinfer/layout.cuh>
 #include <flashinfer/pos_enc.cuh>
+#include <flashinfer/attention/variant_helper.cuh>
 
 #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
 #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

csrc/single_prefill_customize_config.jinja (+1)

@@ -4,6 +4,7 @@
 #include <flashinfer/utils.cuh>
 #include <flashinfer/pos_enc.cuh>
 #include <flashinfer/fastdiv.cuh>
+#include <flashinfer/attention/variant_helper.cuh>
 
 #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
 #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

csrc/single_prefill_sm90_customize_config.jinja (+1)

@@ -1,5 +1,6 @@
 #pragma once
 #include <flashinfer/attention/hopper/attention_updater.cuh>
+#include <flashinfer/attention/hopper/variant_helper.cuh>
 #include <flashinfer/math.cuh>
 #include <flashinfer/layout.cuh>
 #include <flashinfer/cutlass_utils.cuh>

flashinfer/utils.py (+1 -1)

@@ -171,7 +171,7 @@ def _get_cache_alibi_slopes_buf(
     key = (f"alibi_slopes_{num_qo_heads}", device)
     buf = _cache_buf.get(key)
     if buf is None:
-        buf = (get_alibi_slopes(num_qo_heads) * log2e).to(device)
+        buf = get_alibi_slopes(num_qo_heads).to(device)
         _cache_buf[key] = buf
     return buf
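The log2e factor dropped here is the base-2 conversion constant that the kernels fold into sm_scale_log2 (see decode.cuh below), relying on the identity exp(x) = exp2(x * log2(e)). A standalone host-side check of that identity, with arbitrary values:

#include <cmath>
#include <cstdio>

// exp(s * sm_scale) == exp2(s * sm_scale * log2(e)); folding log2(e) into the
// scale lets the online softmax use the cheaper exp2. Values are made up.
int main() {
  const float log2e = 1.4426950408889634f;  // log2(e)
  const float sm_scale = 0.125f;            // e.g. 1/sqrt(64)
  const float s = 3.7f;
  std::printf("exp : %f\n", std::exp(s * sm_scale));
  std::printf("exp2: %f\n", std::exp2(s * sm_scale * log2e));
  return 0;
}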

include/flashinfer/attention/decode.cuh (+4 -10)

@@ -91,6 +91,10 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian
       const uint32_t pos = kv_idx_base + tz * tile_size + j;
       s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos,
                                      qo_head_idx, kv_head_idx);
+      if constexpr (variant.use_softmax) {
+        s[j] *= variant.sm_scale_log2;
+      }
+
       bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx,
                                      kv_head_idx);
       s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf;
@@ -263,11 +267,6 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par
     // do not apply rotary embedding to q matrix
     q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size);
   }
-  // multiple q_vec by sm_scale
-#pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    q_vec[i] = variant.QueryTransform(params, q_vec[i]);
-  }
   block.sync();
 
   uint32_t chunk_start = kv_chunk_idx * kv_chunk_size;
@@ -456,11 +455,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params
     // do not apply rotary embedding to q matrix
     q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size);
   }
-#pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    q_vec[i] = variant.QueryTransform(params, q_vec[i]);
-  }
-  block.sync();
 
   // preload k/v tiles
   uint32_t stage_idx = 0;
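The new per-element ordering in compute_qk is: variant LogitsTransform on the raw q·k value, then the softmax scale (only when the variant uses softmax), then masking. A scalar mock of that pipeline, using a hypothetical stand-in rather than the real AttentionVariant interface:

#include <cstdint>
#include <limits>

// Stand-in variant; only the members exercised below are modeled.
struct MockVariant {
  static constexpr bool use_softmax = true;
  float sm_scale_log2;                                           // sm_scale * log2(e)
  float LogitsTransform(float logits) const { return logits; }   // identity here
  bool LogitsMask(uint32_t /*kv_idx*/) const { return true; }    // no masking here
};

// Mirrors the updated compute_qk ordering for a single logit s.
inline float process_logit(const MockVariant& variant, float s, uint32_t kv_idx,
                           bool in_bounds) {
  s = variant.LogitsTransform(s);
  if constexpr (MockVariant::use_softmax) {
    s *= variant.sm_scale_log2;  // scale applied at the logits, not at q
  }
  const bool mask = variant.LogitsMask(kv_idx);
  return (in_bounds && mask) ? s : -std::numeric_limits<float>::infinity();
}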
include/flashinfer/attention/hopper/variant_helper.cuh (+64)

@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
+#define FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
+
+#include <cuda_runtime.h>
+
+#include <cstdint>
+
+namespace flashinfer {
+
+#define REGISTER_QUERY_TRANSFORM(params, q, ...)                                             \
+  template <typename MainloopParams, typename T>                                             \
+  __device__ __forceinline__ T QueryTransform(const MainloopParams& params, void* q_smem) {  \
+    __VA_ARGS__                                                                               \
+  }
+
+#define REGISTER_KEY_TRANSFORM(params, k, ...)                                               \
+  template <typename MainloopParams, typename T>                                             \
+  __device__ __forceinline__ T KeyTransform(const MainloopParams& params, void* k_smem) {    \
+    __VA_ARGS__                                                                               \
+  }
+
+#define REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx,    \
+                                  kv_head_idx, ...)                                          \
+  template <typename MainloopParams, typename T>                                             \
+  __device__ __forceinline__ T LogitsTransform(                                              \
+      const MainloopParams& params, T logits, uint32_t batch_idx, uint32_t qo_idx,           \
+      uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) {                         \
+    __VA_ARGS__                                                                               \
+  }
+
+#define REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, ...) \
+  template <typename MainloopParams>                                                           \
+  __device__ __forceinline__ bool LogitsMask(const MainloopParams& params, uint32_t batch_idx, \
+                                             uint32_t qo_idx, uint32_t kv_idx,                 \
+                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {     \
+    __VA_ARGS__                                                                                 \
+  }
+
+struct AttentionVariantBase {
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
+                            { return logits; })
+
+  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
+                       { return true; })
+};
+
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
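These macros stamp out the LogitsTransform/LogitsMask member templates that the mainloop expects, so a variant only supplies the body. A sketch of a custom variant built on them; the struct name, its inv_temperature member, and the mask condition are hypothetical, only the macro usage mirrors variants.cuh below:

#include <flashinfer/attention/hopper/variant_helper.cuh>

// Hypothetical variant: temperature-scaled logits plus a simple mask.
// Everything except the macro names is illustrative.
struct TemperatureScaledAttention {
  float inv_temperature;

  // The brace-wrapped body becomes the body of the generated LogitsTransform.
  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
                            { return logits * inv_temperature; })

  // Returning false marks the (qo_idx, kv_idx) pair as masked out.
  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
                       { return kv_idx <= qo_idx; })
};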

include/flashinfer/attention/hopper/variants.cuh (+6 -13)

@@ -20,6 +20,7 @@
 
 #include "../../math.cuh"
 #include "attention_updater.cuh"
+#include "variant_helper.cuh"
 
 namespace flashinfer {
 
@@ -36,12 +37,8 @@ struct StandardAttention {
     return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE=*/true>(sm_scale_log2);
   }
 
-  template <typename MainloopParams, typename T>
-  __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits,
-                                               uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx,
-                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
-    return logits;
-  }
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
+                            { return logits; })
 };
 
 struct LogitsSoftCap {
@@ -57,15 +54,11 @@ struct LogitsSoftCap {
 
   template <int NUM_ROWS_PER_THREAD>
   __device__ auto GetAttentionUpdater() {
-    return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE=*/false>(0.);
+    return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE=*/true>(post_tanh_scale);
  }
 
-  template <typename MainloopParams, typename T>
-  __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits,
-                                               uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx,
-                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
-    return math::tanh(logits * pre_tanh_scale) * post_tanh_scale;
-  }
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
+                            { return math::tanh(logits * pre_tanh_scale); })
 };
 
 template <bool use_logits_soft_cap>
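The LogitsSoftCap change keeps the same math but moves post_tanh_scale out of the transform and into the online-softmax updater's scale (WITH_SCALE=true). A host-side check that the two factorizations agree, using arbitrary values; pre_tanh_scale and post_tanh_scale follow the struct's fields:

#include <cmath>
#include <cstdio>

int main() {
  const float logits = 2.5f, pre_tanh_scale = 0.3f, post_tanh_scale = 1.7f;

  // Old: the transform returned the fully scaled value; the updater was unscaled.
  const float old_path = std::tanh(logits * pre_tanh_scale) * post_tanh_scale;

  // New: the transform stops at tanh; the updater multiplies by post_tanh_scale.
  const float transformed = std::tanh(logits * pre_tanh_scale);
  const float new_path = transformed * post_tanh_scale;

  std::printf("old=%f new=%f\n", old_path, new_path);  // equal up to fp rounding
  return 0;
}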
