
Commit 0c5b95f

Cherry-pick LLaMA GQA mask to rel-1.16.2 (round 4) (#18350)
Cherry-pick LLaMA GQA attention mask and script changes to the 1.16.2 release branch.

Co-authored-by: aciddelgado <[email protected]>
Co-authored-by: Yufeng Li <[email protected]>
Co-authored-by: kunal-vaishnavi <[email protected]>
1 parent 8f06330 commit 0c5b95f

22 files changed, +1306 −463 lines changed

docs/ContribOperators.md

+12 −14
@@ -2236,19 +2236,15 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Attributes
 
 <dl>
-<dt><tt>is_past_bsnh</tt> : int</dt>
-<dd>Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).</dd>
 <dt><tt>kv_num_heads</tt> : int (required)</dt>
 <dd>Number of attention heads for k and v</dd>
 <dt><tt>num_heads</tt> : int (required)</dt>
 <dd>Number of attention heads for q</dd>
 <dt><tt>scale</tt> : float</dt>
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
-<dt><tt>unidirectional</tt> : int</dt>
-<dd>Whether every token can only attend to previous tokens. Default value is 1.</dd>
 </dl>
 
-#### Inputs (3 - 6)
+#### Inputs
 
 <dl>
 <dt><tt>query</tt> : T</dt>
@@ -2258,11 +2254,13 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>value</tt> : T</dt>
 <dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
 <dt><tt>past_key</tt> (optional) : T</dt>
-<dd>past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
+<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
 <dt><tt>past_value</tt> (optional) : T</dt>
-<dd>past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
-<dt><tt>past_sequence_length</tt> (optional) : M</dt>
-<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
+<dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
+<dt><tt>seqlens_k</tt> : M</dt>
+<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
+<dt><tt>total_sequence_length</tt> : M</dt>
+<dd>Scalar tensor of total sequence length (past + new).</dd>
 </dl>
 
 #### Outputs
@@ -2271,18 +2269,18 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
 <dt><tt>present_key</tt> : T</dt>
-<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
+<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
 <dt><tt>present_value</tt> : T</dt>
-<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
+<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
 </dl>
 
 #### Type Constraints
 
 <dl>
 <dt><tt>T</tt> : tensor(float16)</dt>
 <dd>Constrain input and output to float tensors.</dd>
-<dt><tt>M</tt> : tensor(int32), tensor(int64)</dt>
-<dd>Constrain past sequence length to int tensor.</dd>
+<dt><tt>M</tt> : tensor(int32)</dt>
+<dd>Constrain mask to int tensor.</dd>
 </dl>
 
 
@@ -4766,7 +4764,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 
 ### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>
 
-RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
+RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
 that are multiplied to query and key before the inner product of query and key is taken.
 
 #### Version
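As a usage illustration (not part of this commit), the revised GroupQueryAttention signature could be wired up with onnx.helper roughly as follows; the tensor names, shapes, and head counts below are illustrative assumptions, not values taken from the spec.

```python
# Illustrative only: builds a com.microsoft.GroupQueryAttention node with the
# new seqlens_k / total_sequence_length inputs. Names and sizes are made up.
from onnx import helper

num_heads, kv_num_heads = 32, 32  # assumed values for illustration

gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query",                  # (batch_size, sequence_length, hidden_size), float16
        "key",                    # (batch_size, sequence_length, kv_hidden_size), float16
        "value",                  # (batch_size, sequence_length, kv_hidden_size), float16
        "past_key",               # BNSH: (batch_size, kv_num_heads, past_len, head_size)
        "past_value",             # BNSH: (batch_size, kv_num_heads, past_len, head_size)
        "seqlens_k",              # int32 (batch_size,): past sequence lengths (token generation)
        "total_sequence_length",  # int32 scalar: past + new
    ],
    outputs=["attn_out", "present_key", "present_value"],
    name="GroupQueryAttention_0",
    domain="com.microsoft",
    num_heads=num_heads,
    kv_num_heads=kv_num_heads,
)
```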

docs/OperatorKernels.md

+1 −1
@@ -843,7 +843,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br/> **T** = tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|

onnxruntime/contrib_ops/cpu/bert/attention_common.h

+6 −5
@@ -86,18 +86,19 @@ struct PackedAttentionParameters {
 // Parameters deduced from node attributes and inputs/outputs.
 struct GroupQueryAttentionParameters {
   int batch_size;
-  int sequence_length;
-  int past_sequence_length;     // actual sequence length of past_key and past_value
-  int kv_sequence_length;       // sequence length of key and value (or new_k and new_v when past is present)
-  int present_sequence_length;  // past_sequence_length + kv_sequence_length
-  int max_sequence_length;      // allocated length of past_key and past_value
+  int sequence_length;          // sequence length of input query, key, value
+  int seqlen_past_kv_cache;     // sequence length of past kv tensor
+  int seqlen_present_kv_cache;  // sequence length of present kv tensor
   int hidden_size;
   int num_heads;
   int head_size;
   int kv_hidden_size;
   int kv_num_heads;
   int num_splits;          // number of splits for splitkv
   bool is_unidirectional;  // causal
+  bool kv_share_buffer;
+  bool is_prompt;          // determines if seqlens_k is past or kv sequence length tensor
+  bool left_padding;       // copies last token to last index if true
   float scale;
   AttentionQkvFormat qkv_format;
   AttentionQkvFormat past_kv_format;
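To relate the renamed fields to the operator's new inputs, here is a hypothetical host-side sketch (not ONNX Runtime code; the helper name and the is_prompt heuristic are assumptions made for illustration) of what a decoding loop would feed the kernel:

```python
# Hypothetical sketch of how seqlens_k / total_sequence_length might be built
# per step; this mirrors the struct fields above but is not ORT's implementation.
import numpy as np

def make_gqa_step_inputs(past_lens, new_tokens):
    """past_lens: per-batch past KV lengths; new_tokens: tokens fed this step."""
    seqlens_k = np.asarray(past_lens, dtype=np.int32)            # operator input `seqlens_k`
    total_sequence_length = np.int32(int(seqlens_k.max()) + new_tokens)
    is_prompt = bool(np.all(seqlens_k == 0))                     # assumption: zeros on the prompt step
    return seqlens_k, total_sequence_length, is_prompt

# Prompt step: no past, 8 new tokens -> seqlens_k = [0, 0], total = 8, is_prompt = True
print(make_gqa_step_inputs([0, 0], 8))
# Token-generation step: 8 past tokens, 1 new token -> total = 9, is_prompt = False
print(make_gqa_step_inputs([8, 8], 1))
```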

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

+1
@@ -401,6 +401,7 @@ Status EfficientAttention(
                            ? data.scratch
                            : nullptr;
   p.stream = stream;
+  p.has_custom_right_padding = false;
   run_memory_efficient_attention(p);
   DUMP_TENSOR("efficient attention output", data.output,
               parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h

+132 −1
@@ -16,6 +16,133 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
+template <typename AttentionKernel, int kQueriesPerBlock>
+struct RightPaddingBatchHook {
+  using scalar_t = typename AttentionKernel::scalar_t;
+  using accum_t = typename AttentionKernel::accum_t;
+  using lse_scalar_t = typename AttentionKernel::lse_scalar_t;
+  using output_t = typename AttentionKernel::output_t;
+  using output_accum_t = typename AttentionKernel::output_accum_t;
+
+  static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout;
+  static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias;
+  static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock;
+  static constexpr bool kIsAligned = AttentionKernel::kIsAligned;
+  static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration;
+  static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE;  // block size of backward
+  static constexpr bool kPreloadV = AttentionKernel::kPreloadV;
+  static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF;
+  static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer;
+
+  template <typename Params>
+  static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) {
+    auto batch_id = blockIdx.z;
+    auto head_id = blockIdx.y;
+    auto query_start = blockIdx.x * kQueriesPerBlock;
+
+    auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;
+
+    // Advance to current batch - in case of different sequence lengths
+    if (p.seqlen_k_ptr) {
+      p.num_keys = p.seqlen_k_ptr[batch_id];
+    }
+
+    if (query_start >= p.num_queries) {
+      return false;
+    }
+
+    // Advance to the current batch / head / query_start
+    p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH;
+    p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH;
+    p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH;
+    p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value;
+
+    if (kSupportsBias && p.attn_bias_ptr != nullptr) {
+      p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH);
+    }
+    if (p.output_accum_ptr != nullptr) {
+      p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) +
+                            int64_t(query_start) * (p.head_dim_value * p.num_heads) +
+                            head_id * p.head_dim_value;
+    } else {
+      // Accumulate directly in the destination buffer (eg for f32)
+      p.output_accum_ptr = (accum_t*)(p.output_ptr);
+    }
+
+    if (p.logsumexp_ptr != nullptr) {
+      // lse[batch_id, head_id, query_start]
+      p.logsumexp_ptr +=
+          batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start;
+    }
+
+    // Custom masking
+    if (p.causal_diagonal_ptr) {
+      p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
+    }
+    if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
+      p.causal_diagonal_offset += p.num_keys - p.num_queries;
+    }
+    if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
+        p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
+      // the bottom row of the current block is query_start + kQueriesPerBlock
+      // the last active key is then query_start + causal_diagonal_offset +
+      // kQueriesPerBlock so num_keys is the min between actual num_keys and
+      // this to avoid extra computations
+      p.num_keys = cutlass::fast_min(
+          int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock),
+          p.num_keys);
+    }
+
+    p.num_queries -= query_start;
+    p.num_batches = 0;  // no longer used after
+
+    // If num_queries == 1, and there is only one key head we're wasting
+    // 15/16th of tensor core compute In that case :
+    // - we only launch kernels for head_id % kQueriesPerBlock == 0
+    // - we iterate over heads instead of queries (strideM = strideH)
+    if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) {
+      if (head_id % kQueriesPerBlock != 0)
+        return false;
+      p.q_strideM = p.q_strideH;
+      p.num_queries = p.num_heads;
+      p.num_heads = 1;  // unused but here for intent
+      // remove causal since n_query = 1
+      // otherwise, offset would change with head !
+      p.custom_mask_type = AttentionKernel::NoCustomMask;
+      p.o_strideM = p.head_dim_value;
+    }
+
+    // Make sure the compiler knows these variables are the same on all
+    // the threads of the warp.
+    p.query_ptr = warp_uniform(p.query_ptr);
+    p.key_ptr = warp_uniform(p.key_ptr);
+    p.value_ptr = warp_uniform(p.value_ptr);
+    if (kSupportsBias) {
+      p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr);
+    }
+    p.output_ptr = warp_uniform(p.output_ptr);
+    p.output_accum_ptr = warp_uniform(p.output_accum_ptr);
+    p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr);
+    p.num_queries = warp_uniform(p.num_queries);
+    p.num_keys = warp_uniform(p.num_keys);
+    p.num_heads = warp_uniform(p.num_heads);
+    p.head_dim = warp_uniform(p.head_dim);
+    p.head_dim_value = warp_uniform(p.head_dim_value);
+    p.o_strideM = warp_uniform(p.o_strideM);
+    p.custom_mask_type = warp_uniform(p.custom_mask_type);
+    return true;
+  }
+};
+
+template <typename AK, int kQueriesPerBlock>
+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
+    attention_kernel_batched_impl_right_padding(typename AK::Params p) {
+  if (!RightPaddingBatchHook<AK, kQueriesPerBlock>::AdvanceToBlockForGQA(p)) {
+    return;
+  }
+  AK::attention_kernel(p);
+}
+
 template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
 void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
   using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
@@ -92,7 +219,11 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     }
   }
 
-  constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
+  auto kernel_fn = attention_kernel_batched_impl<Attention>;
+  if (params.has_custom_right_padding) {
+    kernel_fn = attention_kernel_batched_impl_right_padding<Attention, queries_per_block>;
+  }
+
   int smem_bytes = sizeof(typename Attention::SharedStorage);
   if (smem_bytes > 0xc000) {
     ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!");
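For intuition, the effect of RightPaddingBatchHook's per-batch `p.num_keys = p.seqlen_k_ptr[batch_id]` truncation can be described with a small NumPy reference (a sketch of the masking semantics only, not the CUDA kernel):

```python
# NumPy reference of right-padded attention: each batch entry attends to at
# most seqlens_k[b] keys, so trailing (right-padded) KV slots are ignored.
import numpy as np

def attention_with_right_padding(q, k, v, seqlens_k):
    # q: (B, Sq, D); k, v: (B, Sk_max, D); seqlens_k: (B,) valid key counts (>= 1)
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("bqd,bkd->bqk", q, k) * scale
    key_idx = np.arange(k.shape[1])[None, None, :]                  # (1, 1, Sk_max)
    scores = np.where(key_idx < seqlens_k[:, None, None], scores, -np.inf)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))     # softmax over valid keys
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("bqk,bkd->bqd", probs, v)

q = np.random.rand(2, 1, 4)
k = np.random.rand(2, 6, 4)
v = np.random.rand(2, 6, 4)
out = attention_with_right_padding(q, k, v, np.array([3, 6]))       # batch 0 ignores keys 3..5
```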

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h

+2
@@ -43,6 +43,8 @@ struct MemoryEfficientAttentionParams {
   static bool need_workspace(size_t v_head_size, bool is_float) {
     return (v_head_size > 128 && !is_float);
   }
+
+  bool has_custom_right_padding = false;
 };
 
 void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);

0 commit comments
