You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
5575
+
or sin cache can be different from the maximum sequence length used by kv cache.
5559
5576
5560
5577
Only supports unidirectional attention with cache of past key and value in linear buffers.
5578
+
5561
5579
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
5562
5580
5563
5581
#### Version
@@ -5581,7 +5599,7 @@ This version of the operator has been available since version 1 of the 'com.micr
5581
5599
<dd>Number of tokens per sparse block. Choices: 16, 32, 64, 128</dd>
5582
5600
</dl>
5583
5601
5584
-
#### Inputs (8 - 10)
5602
+
#### Inputs (9 - 11)
5585
5603
5586
5604
<dl>
5587
5605
<dt><tt>query</tt> : T</dt>
@@ -5590,20 +5608,22 @@ This version of the operator has been available since version 1 of the 'com.micr
5590
5608
<dd>Key with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
5591
5609
<dt><tt>value</tt> (optional) : T</dt>
5592
5610
<dd>Value with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
5593
-
<dt><tt>past_key</tt> (optional) : T</dt>
5594
-
<dd>Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
5595
-
<dt><tt>past_value</tt> (optional) : T</dt>
5596
-
<dd>Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
5597
-
<dt><tt>block_mask</tt> : M</dt>
5598
-
<dd>block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
5611
+
<dt><tt>past_key</tt> : T</dt>
5612
+
<dd>Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
5613
+
<dt><tt>past_value</tt> : T</dt>
5614
+
<dd>Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
5615
+
<dt><tt>block_row_indices</tt> : M</dt>
5616
+
<dd>The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1).The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
5617
+
<dt><tt>block_col_indices</tt> : M</dt>
5618
+
<dd>The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks).The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.</dd>
5599
5619
<dt><tt>total_sequence_length</tt> : M</dt>
5600
5620
<dd>Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.</dd>
5601
5621
<dt><tt>key_total_sequence_lengths</tt> : M</dt>
5602
5622
<dd>1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.</dd>
5603
5623
<dt><tt>cos_cache</tt> (optional) : T</dt>
5604
-
<dd>Cos cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
5624
+
<dd>Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
5605
5625
<dt><tt>sin_cache</tt> (optional) : T</dt>
5606
-
<dd>Sin cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
5626
+
<dd>Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
5607
5627
</dl>
5608
5628
5609
5629
#### Outputs
@@ -5612,9 +5632,9 @@ This version of the operator has been available since version 1 of the 'com.micr
5612
5632
<dt><tt>output</tt> : T</dt>
5613
5633
<dd>3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)</dd>
5614
5634
<dt><tt>present_key</tt> : T</dt>
5615
-
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
5635
+
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
5616
5636
<dt><tt>present_value</tt> : T</dt>
5617
-
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
5637
+
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
0 commit comments