feat: apply sm_scale at logits instead of q in FA2 template #801

Merged: 8 commits, Feb 9, 2025
1 change: 1 addition & 0 deletions csrc/batch_decode_customize_config.jinja
@@ -3,6 +3,7 @@
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/pos_enc.cuh>
+#include <flashinfer/attention/variant_helper.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
1 change: 1 addition & 0 deletions csrc/batch_prefill_customize_config.jinja
@@ -5,6 +5,7 @@
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
+#include <flashinfer/attention/variant_helper.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
1 change: 1 addition & 0 deletions csrc/batch_prefill_sm90_customize_config.jinja
@@ -1,5 +1,6 @@
#pragma once
#include <flashinfer/attention/hopper/attention_updater.cuh>
+#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/cutlass_utils.cuh>
1 change: 1 addition & 0 deletions csrc/single_decode_customize_config.jinja
@@ -2,6 +2,7 @@
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/pos_enc.cuh>
+#include <flashinfer/attention/variant_helper.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
1 change: 1 addition & 0 deletions csrc/single_prefill_customize_config.jinja
@@ -4,6 +4,7 @@
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
+#include <flashinfer/attention/variant_helper.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
1 change: 1 addition & 0 deletions csrc/single_prefill_sm90_customize_config.jinja
@@ -1,5 +1,6 @@
#pragma once
#include <flashinfer/attention/hopper/attention_updater.cuh>
+#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/cutlass_utils.cuh>
2 changes: 1 addition & 1 deletion flashinfer/utils.py
@@ -171,7 +171,7 @@ def _get_cache_alibi_slopes_buf(
key = (f"alibi_slopes_{num_qo_heads}", device)
buf = _cache_buf.get(key)
if buf is None:
buf = (get_alibi_slopes(num_qo_heads) * log2e).to(device)
buf = get_alibi_slopes(num_qo_heads).to(device)
_cache_buf[key] = buf
return buf

14 changes: 4 additions & 10 deletions include/flashinfer/attention/decode.cuh
@@ -91,6 +91,10 @@ __device__ __forceinline__ void compute_qk(const Params& params, AttentionVarian
const uint32_t pos = kv_idx_base + tz * tile_size + j;
s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos,
qo_head_idx, kv_head_idx);
+      if constexpr (variant.use_softmax) {
+        s[j] *= variant.sm_scale_log2;
+      }
+
bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx,
kv_head_idx);
s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf;
@@ -263,11 +267,6 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par
// do not apply rotary embedding to q matrix
q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size);
}
-  // multiple q_vec by sm_scale
-#pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    q_vec[i] = variant.QueryTransform(params, q_vec[i]);
-  }
block.sync();

uint32_t chunk_start = kv_chunk_idx * kv_chunk_size;
@@ -456,11 +455,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params
// do not apply rotary embedding to q matrix
q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size);
}
-#pragma unroll
-  for (uint32_t i = 0; i < vec_size; ++i) {
-    q_vec[i] = variant.QueryTransform(params, q_vec[i]);
-  }
-  block.sync();

// preload k/v tiles
uint32_t stage_idx = 0;
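The kernel change relies on the softmax scale distributing over the dot product: multiplying q by the scale before QK^T and multiplying the logit by the same scale afterwards yield the same value, so QueryTransform can be dropped and sm_scale_log2 applied at the logits. A standalone sketch of that identity (plain C++, illustrative only; the vectors and the assumption that sm_scale_log2 folds log2(e) into sm_scale for the exp2-based softmax are not taken from the kernel):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Illustrative only: scale-at-q and scale-at-logit produce the same result.
float dot(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.f;
  for (size_t i = 0; i < a.size(); ++i) acc += a[i] * b[i];
  return acc;
}

int main() {
  std::vector<float> q = {0.1f, -0.4f, 0.7f};
  std::vector<float> k = {0.3f, 0.2f, -0.5f};
  float sm_scale = 1.f / std::sqrt(3.f);                 // e.g. 1/sqrt(head_dim)
  float sm_scale_log2 = sm_scale * 1.4426950408889634f;  // assumed: sm_scale * log2(e)

  std::vector<float> q_scaled = q;
  for (float& x : q_scaled) x *= sm_scale_log2;          // old path: scale q up front

  float logit_old = dot(q_scaled, k);                    // scale applied at q
  float logit_new = dot(q, k) * sm_scale_log2;           // new path: scale applied at the logit
  assert(std::fabs(logit_old - logit_new) < 1e-6f);
  return 0;
}
```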
64 changes: 64 additions & 0 deletions include/flashinfer/attention/hopper/variant_helper.cuh
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
#define FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H

#include <cuda_runtime.h>

#include <cstdint>

namespace flashinfer {

#define REGISTER_QUERY_TRANSFORM(params, q, ...) \
template <typename MainloopParams, typename T> \
__device__ __forceinline__ T QueryTransform(const MainloopParams& params, void* q_smem) { \
__VA_ARGS__ \
}

#define REGISTER_KEY_TRANSFORM(params, k, ...) \
template <typename MainloopParams, typename T> \
__device__ __forceinline__ T KeyTransform(const MainloopParams& params, void* k_smem) { \
__VA_ARGS__ \
}

#define REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, \
kv_head_idx, ...) \
template <typename MainloopParams, typename T> \
__device__ __forceinline__ T LogitsTransform( \
const MainloopParams& params, T logits, uint32_t batch_idx, uint32_t qo_idx, \
uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) { \
__VA_ARGS__ \
}

#define REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, ...) \
template <typename MainloopParams> \
__device__ __forceinline__ bool LogitsMask(const MainloopParams& params, uint32_t batch_idx, \
uint32_t qo_idx, uint32_t kv_idx, \
uint32_t qo_head_idx, uint32_t kv_head_idx) { \
__VA_ARGS__ \
}

struct AttentionVariantBase {
REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
{ return logits; })

REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
{ return true; })
};

} // namespace flashinfer

#endif // FLASHINFER_ATTENTION_HOPPER_VARIANT_HELPER_H
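To illustrate how the new helper is meant to be used, a hypothetical variant built on these macros could look like the sketch below. The struct name, the head_bias field, and the causal-style mask are invented for illustration and are not part of this PR; a full variant for the Hopper kernels would also define GetAttentionUpdater, as StandardAttention and LogitsSoftCap in variants.cuh do.

```cpp
#include <flashinfer/attention/hopper/variant_helper.cuh>

// Hypothetical variant, for illustration only. REGISTER_LOGITS_TRANSFORM and
// REGISTER_LOGITS_MASK expand to templated member functions whose bodies see
// `params`, `logits`, and the index arguments by name.
struct BiasedCausalAttention : flashinfer::AttentionVariantBase {
  float head_bias = 0.f;  // invented field, not part of this PR

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    // Add a constant bias to every logit (illustrative).
    return logits + head_bias;
  })

  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    // Causal-style mask, purely for illustration.
    return kv_idx <= qo_idx;
  })
};
```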
19 changes: 6 additions & 13 deletions include/flashinfer/attention/hopper/variants.cuh
@@ -20,6 +20,7 @@

#include "../../math.cuh"
#include "attention_updater.cuh"
#include "variant_helper.cuh"

namespace flashinfer {

@@ -36,12 +37,8 @@ struct StandardAttention {
return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE=*/true>(sm_scale_log2);
}

-  template <typename MainloopParams, typename T>
-  __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits,
-                                               uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx,
-                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
-    return logits;
-  }
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
+                            { return logits; })
};

struct LogitsSoftCap {
@@ -57,15 +54,11 @@

template <int NUM_ROWS_PER_THREAD>
__device__ auto GetAttentionUpdater() {
-    return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE=*/false>(0.);
+    return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE=*/true>(post_tanh_scale);
}

-  template <typename MainloopParams, typename T>
-  __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits,
-                                               uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx,
-                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
-    return math::tanh(logits * pre_tanh_scale) * post_tanh_scale;
-  }
+  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx,
+                            { return math::tanh(logits * pre_tanh_scale); })
};

template <bool use_logits_soft_cap>
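The LogitsSoftCap refactor moves post_tanh_scale out of the logits transform and into the OnlineSoftmax updater, on the reading that WITH_SCALE=true makes the updater apply the given scale to the logits, as StandardAttention already does with sm_scale_log2. The kernel therefore still sees post_tanh_scale * tanh(pre_tanh_scale * x) overall. A standalone sketch of that equivalence (plain C++; the scale values and logits are made up):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Illustrative only: applying post_tanh_scale in the transform (old code) or
// letting the softmax updater apply it afterwards (new code) gives the same
// scaled logit, so the softmax result is unchanged.
int main() {
  const float pre_tanh_scale = 0.125f;   // made-up value for the sketch
  const float post_tanh_scale = 11.54f;  // made-up value for the sketch
  std::vector<float> raw_logits = {-3.f, 0.5f, 2.f, 7.f};

  for (float x : raw_logits) {
    float old_path = std::tanh(x * pre_tanh_scale) * post_tanh_scale;  // scale in the transform
    float new_path = std::tanh(x * pre_tanh_scale);                    // transform returns tanh only
    new_path *= post_tanh_scale;                                       // updater applies the scale
    assert(std::fabs(old_path - new_path) < 1e-6f);
  }
  return 0;
}
```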