@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common.cuh"
+#include "convert.cuh"
 #include "vecdotq.cuh"
 
 #include <cstdint>
@@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
@@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+void launch_fattn(
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
@@ -667,17 +711,17 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
 
     fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
         (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
+        K_data,
+        V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        V->nb[1], V->nb[2], V->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());
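Note (not part of the diff above): the rescaling nb = nb*bs*sizeof(half)/ts in launch_fattn converts a byte stride measured in the quantized K/V layout into the equivalent stride in the freshly dequantized FP16 buffer. Below is a minimal standalone sketch of that arithmetic; the q8_0 block geometry (bs = 32 values packed into ts = 34 bytes) and the head size of 128 are illustrative assumptions, not values taken from the diff.

// Standalone sketch of the stride rescaling, compilable as plain host C++.
// Assumptions for illustration: q8_0 packs 32 values into 34 bytes
// (one half scale + 32 int8 quants); head size 128 is made up.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const size_t bs       = 32;              // values per quantized block (ggml_blck_size)
    const size_t ts       = 34;              // bytes per quantized block  (ggml_type_size)
    const size_t head_dim = 128;             // illustrative K head size
    const size_t nb11     = head_dim/bs*ts;  // byte stride of one K row in the quantized layout

    // After dequantization each value is a 2-byte half, so the same row spans
    // nb11*bs*sizeof(half)/ts bytes; uint16_t stands in for half on the host.
    const size_t nb11_f16 = nb11*bs*sizeof(uint16_t)/ts;

    printf("quantized row stride: %zu bytes, f16 row stride: %zu bytes\n", nb11, nb11_f16); // 136 -> 256
    return 0;
}

The same bs*sizeof(half)/ts factor is applied to nb12/nb13 and, for V, to nb21/nb22/nb23, which is why the kernel launch can pass the adjusted strides and the converted K_data/V_data pointers without any other changes.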