@@ -82,6 +82,51 @@ bool validate_flash_attention_args(
   return true;
 }

+bool validate_cache_quant_params_args(
+    const Tensor& t,
+    const Tensor& t_zero_points,
+    const Tensor& t_scales) {
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_scales.dim(),
+      "Quantized tensor and scales must have the same number of dimensions");
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_zero_points.dim(),
+      "Quantized tensor and zero points must have the same number of dimensions");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_scales.scalar_type() == ScalarType::Float),
+      "Scales tensor must be of float type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_zero_points.scalar_type() == ScalarType::Char),
+      "Zero points tensor must be of int8_t type");
+
+  // Sizes must match on every dimension except the last.
+  for (int64_t i = 0; i < t.dim() - 1; i++) {
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_scales.size(i)),
+        "Quantized tensor and scales have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_scales: %zd",
+        i,
+        t.size(i),
+        t_scales.size(i));
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_zero_points.size(i)),
+        "Quantized tensor and zero points have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_zero_points: %zd",
+        i,
+        t.size(i),
+        t_zero_points.size(i));
+  }
+
+  return true;
+}
+
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
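For context, a minimal standalone sketch of the shape contract the new validator enforces; the helper name and example shapes below are illustrative, not part of this change. Scales and zero points must have the same rank as the int8 tensor and match it on every dimension except the last, which the validator deliberately does not compare (allowing, e.g., per-row parameters of size 1).

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative restatement of the dimension checks in
// validate_cache_quant_params_args: ranks must match and all sizes except
// the last dimension must agree.
static bool quant_param_shapes_ok(
    const std::vector<int64_t>& t,
    const std::vector<int64_t>& params) {
  if (t.size() != params.size()) {
    return false;
  }
  for (size_t i = 0; i + 1 < t.size(); ++i) {
    if (t[i] != params[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  // Hypothetical int8 cache of shape [batch, seq, heads, head_dim] with
  // per-row scales/zero points of shape [batch, seq, heads, 1].
  std::printf("%d\n", quant_param_shapes_ok({1, 128, 8, 64}, {1, 128, 8, 1})); // 1
  std::printf("%d\n", quant_param_shapes_ok({1, 128, 8, 64}, {1, 64, 8, 1}));  // 0
  return 0;
}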
@@ -233,7 +278,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
           output,
@@ -243,7 +294,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
           output,
@@ -253,28 +310,19 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     }
   });
   return output;
 }

-/*
-  Input params
-  @param[in] q_projected Projected query with query weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] k_projected Projected query with key weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] v_projected Projected query with value weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous k_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous v_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  ....
-  @param[in] start_pos: sequence position
-*/
-Tensor& custom_sdpa_out(
+Tensor& custom_sdpa_out_impl(
     RuntimeContext& ctx,
     const Tensor& q,
     const Tensor& k,
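For readers skimming the three flash-attention call sites above: aside from the newly appended nullopt arguments, they differ only in the query block size chosen from q_seq_len. A rough standalone restatement of that selection (the helper name is illustrative, not the kernel's actual code):

#include <cstdint>

// Larger query sequence lengths get a larger query block (256/64/32),
// while the key/value block size stays at 512 in all three branches.
static int query_block_size(int64_t q_seq_len) {
  if (q_seq_len >= 768) {
    return 256;
  }
  if (q_seq_len >= 192) {
    return 64;
  }
  return 32;
}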
@@ -285,7 +333,13 @@ Tensor& custom_sdpa_out(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& q_zero_points = nullopt,
+    const optional<Tensor>& q_scales = nullopt,
+    const optional<Tensor>& k_zero_points = nullopt,
+    const optional<Tensor>& k_scales = nullopt,
+    const optional<Tensor>& v_zero_points = nullopt,
+    const optional<Tensor>& v_scales = nullopt) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -300,6 +354,40 @@ Tensor& custom_sdpa_out(
       output,
       "Invalid arguments");

+  bool is_seq_at_dim_1{true};
+  if (q.scalar_type() == ScalarType::Char) {
+    is_seq_at_dim_1 = false;
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        q_scales.has_value() && q_zero_points.has_value() &&
+            k_scales.has_value() && k_zero_points.has_value() &&
+            v_scales.has_value() && v_zero_points.has_value(),
+        InvalidArgument,
+        output,
+        "If q is quantized, k and v must be quantized as well");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            q, q_zero_points.value(), q_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized query");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            k, k_zero_points.value(), k_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized key");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            v, v_zero_points.value(), v_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized value");
+  }
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

   const int64_t seq_len = q.size(1);
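A compressed sketch of the gating just added, as a standalone illustration (QuantArgs is a hypothetical placeholder type, not anything in this file): the int8 path is selected purely by q's dtype, and once selected, all six scale/zero-point optionals must be present and pass the validator before the kernel proceeds.

#include <optional>

// Standalone illustration of the argument gating above: if the query is
// quantized (int8), every scale/zero-point optional for q, k and v must be
// populated; otherwise the call is rejected as InvalidArgument.
struct QuantArgs {
  std::optional<int> q_scales, q_zero_points;
  std::optional<int> k_scales, k_zero_points;
  std::optional<int> v_scales, v_zero_points;
};

static bool quant_args_complete(const QuantArgs& a) {
  return a.q_scales && a.q_zero_points && a.k_scales && a.k_zero_points &&
      a.v_scales && a.v_zero_points;
}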
@@ -315,53 +403,103 @@ Tensor& custom_sdpa_out(

   // TODO(task): replace the template param selection logic
   // with whatever appropriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-    // TODO we need to re-evaluate this for ARM CPUs
-    // And there can be many so instead of templatizing
-    // we might consider another appraoch
-    if (q_seq_len >= 768) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else if (q_seq_len >= 192) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else {
-      sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    }
-  });
+  ET_SWITCH_FLOAT_TYPES(
+      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another approach
+        if (q_seq_len >= 768) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        } else if (q_seq_len >= 192) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        } else {
+          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        }
+      });
   return output;
 }
+
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected query with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected query with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  return custom_sdpa_out_impl(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+}
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
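One design note on the wrapper above, with a hypothetical minimal analogue (impl and public_api are illustrative names): because the six quantization parameters of custom_sdpa_out_impl default to nullopt, the public custom_sdpa_out entry point keeps its original signature and its float-only behavior, so existing callers are unaffected.

#include <optional>

// Hypothetical analogue of the wrapper pattern: the public overload forwards
// to an _impl overload whose trailing optionals default to std::nullopt, so
// the pre-existing (non-quantized) call path is untouched.
static int impl(int x, std::optional<int> scale = std::nullopt) {
  return scale.has_value() ? x * scale.value() : x;
}

static int public_api(int x) {
  return impl(x); // quantization-style extras left at their defaults
}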