Commit 4573740

Authored by yihonglyu, with YUNQIUGUO, rachguo, and tianleiwu
[ORT 1.18.0 Release] Cherry pick 3rd/Final round (#20677)
Co-authored-by: Rachel Guo <[email protected]>
Co-authored-by: rachguo <[email protected]>
Co-authored-by: rachguo <[email protected]>
Co-authored-by: Tianlei Wu <[email protected]>
Co-authored-by: George Wu <[email protected]>
Co-authored-by: Edward Chen <[email protected]>
Co-authored-by: Jian Chen <[email protected]>
1 parent ed349b9 commit 4573740

28 files changed (+892, -1290 lines)

csharp/OnnxRuntime.CSharp.proj (+4)
@@ -50,8 +50,12 @@ CMake creates a target to this project
   <PropertyGroup>
     <!-- If we create multiple nuget packages in one job, major package and dependent packages version should be the same-->
     <!-- CurrentDate and CurrentTime are only used for dev packages-->
+    <CurrentDate Condition=" '$(BuildDate)'!='' ">$(BuildDate)</CurrentDate>
+    <CurrentTime Condition=" '$(BuildTime)'!='' ">$(BuildTime)</CurrentTime>
     <CurrentDate Condition="'$(CurrentDate)'==''">$([System.DateTime]::UtcNow.ToString(yyyyMMdd))</CurrentDate>
     <CurrentTime Condition="'$(CurrentTime)'==''">$([System.DateTime]::UtcNow.ToString(hhmm))</CurrentTime>
+
+
   </PropertyGroup>

   <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />

docs/ContribOperators.md (+33, -13)
@@ -5553,11 +5553,29 @@ This version of the operator has been available since version 1 of the 'com.micr
   When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
   For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).

-  Padding shall be on the right side.
+  The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain
+  paddings at the right side when different layout has different number of non-zeros in block mask.

-  When do_rotary is True, cos_cache and sin_cache are required.
+  An example of block mask with 2 layouts where each layout is 4 x 4 blocks:
+        [[[1, 0, 0, 0],
+          [1, 1, 0, 0],
+          [0, 1, 1, 0],
+          [0, 1, 1, 1]],
+
+         [[1, 0, 0, 0],
+          [1, 1, 0, 0],
+          [1, 1, 1, 0],
+          [1, 0, 1, 1]]]
+
+  The corresponding CSR format:
+        block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
+        block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]
+
+  When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
+  or sin cache can be different from the maximum sequence length used by kv cache.

   Only supports unidirectional attention with cache of past key and value in linear buffers.
+
   For performance, past_key and present_key share same memory buffer, and past_value and present_value too.

 #### Version
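For readers following the new CSR layout, here is a minimal Python sketch (not part of this commit) that reproduces the block_row_indices / block_col_indices values documented above. The helper name and the use of NumPy are illustrative assumptions; the -1 right-padding to the widest layout follows the example in the diff.

```python
import numpy as np

def block_mask_to_csr(block_mask, pad_value=-1):
    """Convert a (num_layout, max_blocks, max_blocks) 0/1 block mask to the
    block_row_indices / block_col_indices layout described above.
    Rows of block_col_indices are right-padded to the widest layout."""
    block_mask = np.asarray(block_mask)
    num_layout, max_blocks, _ = block_mask.shape

    row_indices = np.zeros((num_layout, max_blocks + 1), dtype=np.int32)
    col_lists = []
    for layout in range(num_layout):
        cols = []
        for row in range(max_blocks):
            nonzero_cols = np.nonzero(block_mask[layout, row])[0]
            cols.extend(int(c) for c in nonzero_cols)
            row_indices[layout, row + 1] = len(cols)  # cumulative non-zero count
        col_lists.append(cols)

    max_nnz = max(len(c) for c in col_lists)
    col_indices = np.full((num_layout, max_nnz), pad_value, dtype=np.int32)
    for layout, cols in enumerate(col_lists):
        col_indices[layout, :len(cols)] = cols
    return row_indices, col_indices

mask = [[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 1, 0], [0, 1, 1, 1]],
        [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 0, 1, 1]]]
rows, cols = block_mask_to_csr(mask)
print(rows.tolist())  # [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]
print(cols.tolist())  # [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
```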
@@ -5581,7 +5599,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Number of tokens per sparse block. Choices: 16, 32, 64, 128</dd>
 </dl>

-#### Inputs (8 - 10)
+#### Inputs (9 - 11)

 <dl>
 <dt><tt>query</tt> : T</dt>
@@ -5590,20 +5608,22 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Key with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
 <dt><tt>value</tt> (optional) : T</dt>
 <dd>Value with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
-<dt><tt>past_key</tt> (optional) : T</dt>
-<dd>Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
-<dt><tt>past_value</tt> (optional) : T</dt>
-<dd>Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
-<dt><tt>block_mask</tt> : M</dt>
-<dd>block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
+<dt><tt>past_key</tt> : T</dt>
+<dd>Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
+<dt><tt>past_value</tt> : T</dt>
+<dd>Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
+<dt><tt>block_row_indices</tt> : M</dt>
+<dd>The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1).The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
+<dt><tt>block_col_indices</tt> : M</dt>
+<dd>The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks).The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.</dd>
 <dt><tt>total_sequence_length</tt> : M</dt>
 <dd>Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.</dd>
 <dt><tt>key_total_sequence_lengths</tt> : M</dt>
 <dd>1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.</dd>
 <dt><tt>cos_cache</tt> (optional) : T</dt>
-<dd>Cos cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
+<dd>Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
 <dt><tt>sin_cache</tt> (optional) : T</dt>
-<dd>Sin cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
+<dd>Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
 </dl>

 #### Outputs
@@ -5612,9 +5632,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)</dd>
 <dt><tt>present_key</tt> : T</dt>
-<dd>Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
+<dd>Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
 <dt><tt>present_value</tt> : T</dt>
-<dd>Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
+<dd>Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
 </dl>

 #### Type Constraints
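As a rough illustration of the reordered input list, the sketch below builds a hypothetical SparseAttention node with onnx.helper. Tensor names and attribute values are made up for the example; only attributes referenced in the spec excerpt above are set, and cos_cache/sin_cache remain optional trailing inputs.

```python
from onnx import helper

# Input order after this change: 9 required inputs plus 2 optional rotary caches.
sparse_attention = helper.make_node(
    "SparseAttention",
    inputs=[
        "query", "key", "value",
        "past_key", "past_value",                   # now required; shaped with max_cache_sequence_length
        "block_row_indices", "block_col_indices",   # CSR block mask replaces the dense block_mask input
        "total_sequence_length",                    # moved to input index 7 (kept in CPU memory)
        "key_total_sequence_lengths",
        "cos_cache", "sin_cache",                   # optional; only needed when do_rotary=1
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=8,
    kv_num_heads=4,
    sparse_block_size=64,
    do_rotary=1,
)
```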

docs/OperatorKernels.md (+1, -1)
@@ -906,7 +906,7 @@ Do not modify directly.*
 |SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
-|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_mask:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
+|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
 |TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/contrib_ops/cpu/bert/attention_common.h (+6, -2)
@@ -126,11 +126,15 @@ struct SparseAttentionParameters {
   bool rotary_interleaved;         // whether to use interleaved rotary embedding
   int rotary_dim;                  // rotary embedding dimension
   int sparse_block_size;           // block size for sparse attention
-  int num_sparse_layout;           // number of sparse layout, or the first dimension of block_mask
+  int num_sparse_layout;           // number of sparse layout
+  int stride_col_indices;          // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
+  int stride_row_indices;          // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
   float scale;                     // scaling factor applied prior to softmax
   bool is_packed_qkv;              // whether qkv is packed
   int total_sequence_length;       // maximum total sequence length (past_sequence_length + sequence_length) among keys
-  int max_sequence_length;         // max sequence length allowed
+  int max_sequence_length;         // max sequence length for sparse layout
+  int max_rotary_sequence_length;  // max sequence length for rotary cos/sin cache
+  int max_cache_sequence_length;   // max sequence length for kv cache buffer
   bool past_present_share_buffer;  // whether past_key and present_key share buffer, so is past_value and present_value
 };
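Tying the new fields back to the two-layout, 4 x 4-block example in docs/ContribOperators.md above, the stride fields simply record the padded CSR shapes. The sparse_block_size value below is an assumption for illustration, not taken from the diff.

```python
# Values implied by the documented 2-layout example; sparse_block_size is assumed.
sparse_block_size = 64
num_sparse_layout = 2                                 # layouts in the block mask
max_blocks = 4                                        # max_sequence_length / sparse_block_size
max_sequence_length = max_blocks * sparse_block_size  # 256
stride_row_indices = max_blocks + 1                   # block_row_indices has shape [2, 5]
stride_col_indices = 9                                # max_nnz_blocks: widest layout has 9 non-zero blocks
```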

onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc (+22, -41)
@@ -4,7 +4,6 @@
 #include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
 #include "contrib_ops/cuda/sparse/sparse_attention.h"
 #include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
-#include "contrib_ops/cuda/sparse/block_mask.h"
 #include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
 #include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
 #include "core/platform/env_var_utils.h"
@@ -26,7 +25,7 @@ namespace cuda {
           .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \
           .MayInplace(3, 1)                                            \
           .MayInplace(4, 2)                                            \
-          .InputMemoryType(OrtMemTypeCPUInput, 6),                     \
+          .InputMemoryType(OrtMemTypeCPUInput, 7),                     \
       SparseAttention<T>);

 REGISTER_KERNEL_TYPED(MLFloat16)
@@ -77,15 +76,16 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* value = context->Input<Tensor>(2);
   const Tensor* past_key = context->Input<Tensor>(3);
   const Tensor* past_value = context->Input<Tensor>(4);
-  const Tensor* block_mask = context->Input<Tensor>(5);
-  const Tensor* total_seq_len = context->Input<Tensor>(6);
-  const Tensor* seqlens_k_total = context->Input<Tensor>(7);
-  const Tensor* cos_cache = context->Input<Tensor>(8);
-  const Tensor* sin_cache = context->Input<Tensor>(9);
+  const Tensor* block_row_indices = context->Input<Tensor>(5);
+  const Tensor* block_col_indices = context->Input<Tensor>(6);
+  const Tensor* total_seq_len = context->Input<Tensor>(7);
+  const Tensor* seqlens_k_total = context->Input<Tensor>(8);
+  const Tensor* cos_cache = context->Input<Tensor>(9);
+  const Tensor* sin_cache = context->Input<Tensor>(10);

   SparseAttentionParameters parameters;

-  // Parameters from node attribute
+  // Parameters from node attribute shall be set before calling CheckInputs
   parameters.sparse_block_size = sparse_block_size_;
   parameters.num_heads = num_heads_;
   parameters.kv_num_heads = kv_num_heads_;
@@ -101,7 +101,8 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                                   past_value,
                                                   cos_cache,
                                                   sin_cache,
-                                                  block_mask,
+                                                  block_row_indices,
+                                                  block_col_indices,
                                                   seqlens_k_total,
                                                   total_seq_len));
   // Some limitations of CUDA kernels
@@ -177,7 +178,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
   Tensor* output = context->Output(0, output_shape);

   std::vector<int64_t> present_dims = {
-      parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
+      parameters.batch_size, parameters.kv_num_heads, parameters.max_cache_sequence_length, parameters.head_size};
   TensorShape present_shape(present_dims);
   Tensor* present_key = context->Output(1, present_shape);
   Tensor* present_value = context->Output(2, present_shape);
@@ -188,13 +189,12 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
   data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
   data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
-  data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
-  data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
-  data.block_mask = block_mask->Data<int32_t>();
+  data.past_key = reinterpret_cast<const CudaT*>(past_key->Data<T>());
+  data.past_value = reinterpret_cast<const CudaT*>(past_value->Data<T>());
   data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data<int32_t>();
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
-  data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
-  data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
+  data.present_key = reinterpret_cast<CudaT*>(present_key->MutableData<T>());
+  data.present_value = reinterpret_cast<CudaT*>(present_value->MutableData<T>());

   // Check past and present share buffer.
   parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
@@ -214,45 +214,26 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
   // Currently, we use same block size in kernel.
   // TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask).
   data.kernel_layout.block_size = parameters.sparse_block_size;
-  data.kernel_layout.mask = data.block_mask;
   data.kernel_layout.num_layout = parameters.num_sparse_layout;
-  data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size;
-  data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size;
-
-  // Allocate buffer for CSR col and row indices.
-  onnxruntime::Stream* stream = context->GetComputeStream();
-  int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows;
-  auto csr_col_indices_buffer = GetScratchBuffer<int>(static_cast<size_t>(dense_blocks), stream);
-  auto csr_row_indices_buffer = GetScratchBuffer<int>(
-      static_cast<size_t>(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream);
-
-  data.kernel_layout.csr_col_indices = reinterpret_cast<const int*>(csr_col_indices_buffer.get());
-  data.kernel_layout.csr_row_indices = reinterpret_cast<const int*>(csr_row_indices_buffer.get());
-
-  ConvertMaskToCSR(cuda_stream,
-                   data.kernel_layout.mask,
-                   data.kernel_layout.num_layout,
-                   data.kernel_layout.num_rows,
-                   data.kernel_layout.num_cols,
-                   csr_row_indices_buffer.get(),
-                   csr_col_indices_buffer.get(),
-                   device_prop.maxThreadsPerBlock);
+  data.kernel_layout.csr_col_indices = block_col_indices->Data<int32_t>();
+  data.kernel_layout.csr_row_indices = block_row_indices->Data<int32_t>();

   size_t rotary_buffer_bytes = 0;
   if (do_rotary_) {
     rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads *
                           parameters.sequence_length * parameters.head_size;
     rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
   }
-  auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
+  onnxruntime::Stream* stream = context->GetComputeStream();
+  auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, stream);
   data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());

   size_t transposed_q_bytes = 0;
   if (!parameters.is_packed_qkv) {
     transposed_q_bytes = parameters.batch_size * parameters.sequence_length *
                          parameters.num_heads * parameters.head_size * sizeof(T);
   }
-  auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, context->GetComputeStream());
+  auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, stream);
   if (transposed_q_buffer) {
     data.transposed_q_buffer = reinterpret_cast<CudaT*>(transposed_q_buffer.get());
   }
@@ -263,7 +244,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                       (parameters.num_heads + 2 * parameters.kv_num_heads) *
                                       parameters.head_size * sizeof(T));
   }
-  auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, stream);
   if (unpacked_qkv_buffer) {
     data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
   }
@@ -327,7 +308,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
       }
     }

-    v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, context->GetComputeStream());
+    v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, stream);
     CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned,
                                          sizeof(int32_t) * v2_kernel_buffer_size,
                                          cudaMemcpyHostToDevice, cuda_stream));
