@@ -1,3 +1,5 @@
+#pragma once
+
 #include "common.cuh"
 #include "vecdotq.cuh"
 
@@ -510,6 +512,48 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
+template <int D>
+constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<half, D> :
+        nullptr;
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<float, D> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+        type_V == GGML_TYPE_F16  ? dequantize_1_f16<half> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+        type_V == GGML_TYPE_F16  ? dequantize_1_f16<float> :
+        nullptr;
+}
+
 template <int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -565,13 +609,12 @@ static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16;
 typedef half  f16;
 typedef float f32;
 
-#define FATTN_VEC_CASE(type_VKQ, D, type_suffix_K, type_suffix_V) \
-    if (Q->ne[0] == (D) && K->type == ggml_type_##type_suffix_K && V->type == ggml_type_##type_suffix_V) { \
+#define FATTN_VEC_CASE(type_VKQ, D, type_K, type_V) \
+    if (Q->ne[0] == (D) && K->type == type_K && V->type == type_V) { \
         constexpr int nwarps = (D)/WARP_SIZE; \
-        constexpr bool Q_q8_1 = ggml_type_##type_suffix_K != GGML_TYPE_F16; \
         fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ< \
             (D), cols_per_block, parallel_blocks, \
-            vec_dot_fattn_vec_KQ_##type_suffix_K<type_VKQ, (D)>, Q_q8_1, dequantize_1_##type_suffix_V<type_VKQ>>; \
+            type_K, type_V>; \
         launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \
         return; \
     } \
@@ -582,14 +625,18 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "By default only f16 KV cache is supported.\n");
         fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
         GGML_ASSERT(false);
-    } else {
+    } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
         fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
         fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
         fprintf(stderr, "  - K == f16,  V == f16, 16.00 BPV\n");
         fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ASSERT(false);
+    } else {
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Only f16 is supported.\n");
+        GGML_ASSERT(false);
     }
 }
 
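Below is a minimal, self-contained sketch of the constexpr dispatch pattern that the new get_vec_dot_KQ_* / get_dequantize_1_* helpers rely on: a chain of ternaries maps an enum value to a function pointer, and because the selector is constexpr, the choice is resolved entirely at compile time when the enum arrives as a template parameter. The names here (tag_t, op_t, pick_op, apply) are hypothetical and the CUDA __device__ qualifiers are omitted so the sketch builds with a plain host compiler; it illustrates the technique and is not code from this commit.

// Illustration only: constexpr selection of a function pointer, resolved at
// compile time when the selector argument is a template parameter.
#include <cstdio>

enum tag_t { TAG_ADD, TAG_MUL };

typedef float (*op_t)(float, float);

float op_add(float x, float y) { return x + y; }
float op_mul(float x, float y) { return x * y; }

// Chained ternaries, as in get_vec_dot_KQ_f16: unknown tags map to nullptr.
constexpr op_t pick_op(tag_t t) {
    return t == TAG_ADD ? op_add :
           t == TAG_MUL ? op_mul :
           nullptr;
}

// The tag is a template parameter, so the selected pointer is a compile-time
// constant inside the body (mirroring how the kernels are templated on the
// K/V cache types and pick their helpers internally).
template <tag_t T>
float apply(float x, float y) {
    constexpr op_t op = pick_op(T);
    return op(x, y);
}

int main() {
    printf("%g\n", apply<TAG_ADD>(2.0f, 3.0f)); // prints 5
    printf("%g\n", apply<TAG_MUL>(2.0f, 3.0f)); // prints 6
    return 0;
}

This is also why FATTN_VEC_CASE becomes simpler in the hunk above: instead of token-pasting vec_dot_fattn_vec_KQ_* and dequantize_1_* names, it forwards type_K and type_V as template arguments and lets the kernel select its own helpers.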