docs/ContribOperators.md
+12 -14
@@ -2236,19 +2236,15 @@ This version of the operator has been available since version 1 of the 'com.micr
2236
2236
#### Attributes
2237
2237
2238
2238
<dl>
2239
-
<dt><tt>is_past_bsnh</tt> : int</dt>
2240
-
<dd>Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).</dd>
2241
2239
<dt><tt>kv_num_heads</tt> : int (required)</dt>
2242
2240
<dd>Number of attention heads for k and v</dd>
2243
2241
<dt><tt>num_heads</tt> : int (required)</dt>
2244
2242
<dd>Number of attention heads for q</dd>
2245
2243
<dt><tt>scale</tt> : float</dt>
2246
2244
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
2247
-
<dt><tt>unidirectional</tt> : int</dt>
2248
-
<dd>Whether every token can only attend to previous tokens. Default value is 1.</dd>
2249
2245
</dl>
2250
2246
2251
-
#### Inputs (3 - 6)
2247
+
#### Inputs
2252
2248
2253
2249
<dl>
2254
2250
<dt><tt>query</tt> : T</dt>
@@ -2258,11 +2254,13 @@ This version of the operator has been available since version 1 of the 'com.micr
2258
2254
<dt><tt>value</tt> : T</dt>
2259
2255
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
2260
2256
<dt><tt>past_key</tt> (optional) : T</dt>
2261
-
<dd>past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
2257
+
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
2262
2258
<dt><tt>past_value</tt> (optional) : T</dt>
2263
-
<dd>past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
2259
+
<dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
2260
+
<dt><tt>seqlens_k</tt> : M</dt>
2261
+
<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
2262
+
<dt><tt>total_sequence_length</tt> : M</dt>
2263
+
<dd>Scalar tensor of total sequence length (past + new).</dd>
2266
2264
</dl>
2267
2265
2268
2266
#### Outputs
@@ -2271,18 +2269,18 @@ This version of the operator has been available since version 1 of the 'com.micr
2271
2269
<dt><tt>output</tt> : T</dt>
2272
2270
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
2273
2271
<dt><tt>present_key</tt> : T</dt>
2274
-
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
2272
+
<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
2275
2273
<dt><tt>present_value</tt> : T</dt>
2276
-
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
2274
+
<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
2277
2275
</dl>
2278
2276
2279
2277
#### Type Constraints
2280
2278
2281
2279
<dl>
2282
2280
<dt><tt>T</tt> : tensor(float16)</dt>
2283
2281
<dd>Constrain input and output to float tensors.</dd>
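
The shape rules described in the updated attributes, inputs, and outputs can be sketched in a few lines of Python. This is only an illustration under stated assumptions, not the operator's implementation: the helper name `gqa_cache_shapes` is hypothetical, `head_size = hidden_size / num_heads` is assumed, and BNSH is read as (batch_size, kv_num_heads, sequence_length, head_size).

```python
import math


def gqa_cache_shapes(batch_size, num_heads, kv_num_heads, hidden_size,
                     past_sequence_length, kv_sequence_length,
                     max_sequence_length=None):
    """Hypothetical helper: derive the cache shapes the doc text describes.

    Assumes head_size = hidden_size // num_heads and BNSH layout
    (batch_size, kv_num_heads, sequence_length, head_size).
    """
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    head_size = hidden_size // num_heads

    # Default scale from the attribute description: 1/sqrt(head_size).
    scale = 1.0 / math.sqrt(head_size)

    total_sequence_length = past_sequence_length + kv_sequence_length

    if max_sequence_length is not None:
        # past and present share one buffer (k-v cache):
        # both tensors have length max_sequence_length.
        past_len = present_len = max_sequence_length
    else:
        # Separate tensors: present grows by kv_sequence_length.
        past_len = past_sequence_length
        present_len = total_sequence_length

    return {
        "head_size": head_size,
        "scale": scale,
        "total_sequence_length": total_sequence_length,
        "past_key": (batch_size, kv_num_heads, past_len, head_size),
        "present_key": (batch_size, kv_num_heads, present_len, head_size),
    }


# Example: one new token appended to a 10-token past, separate tensors.
shapes = gqa_cache_shapes(batch_size=2, num_heads=32, kv_num_heads=8,
                          hidden_size=4096, past_sequence_length=10,
                          kv_sequence_length=1)
print(shapes["past_key"], shapes["present_key"])
```

The value tensors follow the same pattern as the key tensors, so they are omitted from the sketch.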