Commit 84d9277

split fattn compile via extern templates
1 parent f4003cf · commit 84d9277

9 files changed: +1609 −1327 lines
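
The mechanism behind the split is plain C++ `extern template`: a shared header declares the kernel template plus explicit-instantiation *declarations* for the supported parameter combinations, and each combination gets its own small .cu file holding the matching explicit-instantiation *definition*. Every instance then compiles as a separate translation unit, which parallelizes the FlashAttention build and keeps edits to the dispatch code from re-instantiating every kernel. A minimal sketch of the pattern, with purely illustrative names rather than the commit's actual files or signatures:

```cuda
// fattn-sketch.cuh -- hypothetical shared header illustrating the pattern.
#pragma once

template <int D> // D == head size
__global__ void flash_attn_sketch(const float * Q, const float * K,
                                  const float * V, float * dst) {
    // ... attention math shared by all head sizes ...
}

// Host-side launcher template; this is what the instance files instantiate.
template <int D>
void launch_flash_attn_sketch(const float * Q, const float * K,
                              const float * V, float * dst, int n_blocks) {
    flash_attn_sketch<D><<<n_blocks, D>>>(Q, K, V, dst);
}

// Explicit instantiation declarations: translation units that include this
// header must not instantiate these themselves -- a dedicated .cu file does.
extern template void launch_flash_attn_sketch< 64>(const float *, const float *, const float *, float *, int);
extern template void launch_flash_attn_sketch<128>(const float *, const float *, const float *, float *, int);
```

The Makefile changes below then wire those per-instance object files into the build and the clean rule.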

Makefile (+13 −1)
```diff
@@ -421,6 +421,15 @@ ifdef LLAMA_CUBLAS
     LLAMA_CUDA := 1
 endif
 
+OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
+ifdef LLAMA_CUDA_FA_ALL_QUANTS
+    OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
+else
+    OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
+    OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
+    OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
+endif # LLAMA_CUDA_FA_ALL_QUANTS
+
 ifdef LLAMA_CUDA
     ifneq ('', '$(wildcard /opt/cuda)')
         CUDA_PATH ?= /opt/cuda
@@ -431,6 +440,7 @@ ifdef LLAMA_CUDA
     MK_LDFLAGS   += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
     OBJS         += ggml-cuda.o
     OBJS         += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+    OBJS         += $(OBJS_CUDA_TEMP_INST)
     MK_NVCCFLAGS += -use_fast_math
 ifdef LLAMA_FATAL_WARNINGS
     MK_NVCCFLAGS += -Werror all-warnings
@@ -508,7 +518,7 @@ define NVCC_COMPILE
 endef # NVCC_COMPILE
 endif # JETSON_EOL_MODULE_DETECT
 
-ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
     $(NVCC_COMPILE)
 
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
@@ -587,6 +597,7 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
 endif # LLAMA_CUDA_NO_PEER_COPY
     OBJS         += ggml-cuda.o
     OBJS         += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+    OBJS         += $(OBJS_CUDA_TEMP_INST)
 
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
     $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
@@ -751,6 +762,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
 clean:
     rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
     rm -vrf ggml-cuda/*.o
+    rm -vrf ggml-cuda/template-instances/*.o
     find examples pocs -type f -name "*.o" -delete
 
 #
```
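
The new `OBJS_CUDA_TEMP_INST` wildcards pull in the per-instantiation sources under `ggml-cuda/template-instances/`, and the narrower default set (q4_0/q4_0, q8_0/q8_0, f16/f16) is what keeps the build small when `LLAMA_CUDA_FA_ALL_QUANTS` is not set. Continuing the hypothetical sketch from above, each of those files would be little more than one explicit instantiation definition:

```cuda
// template-instances/flash-attn-sketch-d128.cu -- hypothetical instance file,
// picked up by the $(wildcard ggml-cuda/template-instances/...) rules above.
#include "../fattn-sketch.cuh"

// Explicit instantiation definition: the one translation unit where the
// D == 128 launcher (and the kernel it references) is actually compiled.
template void launch_flash_attn_sketch<128>(const float *, const float *, const float *, float *, int);
```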

ggml-cuda/fattn-common.cuh (+52 −5)
```diff
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "common.cuh"
 #include "vecdotq.cuh"
 
@@ -510,6 +512,48 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
+template <int D>
+constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<half, D>  :
+        nullptr;
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<float, D>  :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+        type_V == GGML_TYPE_F16  ? dequantize_1_f16<half>  :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+        type_V == GGML_TYPE_F16  ? dequantize_1_f16<float>  :
+        nullptr;
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
```
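
These constexpr getters pair with kernels that now take the K and V `ggml_type` as non-type template parameters (see the updated `FATTN_VEC_CASE` below): inside such a kernel, the vec-dot and dequantize function pointers and the `Q_q8_1` flag all resolve at compile time. A rough sketch of that usage with a simplified kernel signature (the real `flash_attn_vec_ext_*` kernels take more parameters):

```cuda
// Simplified illustration of the compile-time selection the getters enable;
// not the actual flash_attn_vec_ext_f16 signature.
template <int D, ggml_type type_K, ggml_type type_V>
__global__ void flash_attn_vec_sketch(const char * Q, const char * K, const char * V, float * dst) {
    // type_K / type_V are template parameters, so these are constant expressions
    // and the unused branches of the ternary chains never reach the generated code.
    constexpr vec_dot_KQ_f16_t   vec_dot_KQ     = get_vec_dot_KQ_f16<D>(type_K);
    constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
    constexpr bool               Q_q8_1         = type_K != GGML_TYPE_F16; // replaces the old macro-computed flag

    // ... use vec_dot_KQ, dequantize_1_v and Q_q8_1 in the attention loop ...
    (void) vec_dot_KQ; (void) dequantize_1_v; (void) Q_q8_1;
    (void) Q; (void) K; (void) V; (void) dst;
}
```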
```diff
@@ -565,13 +609,12 @@ static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16;
 typedef half  f16;
 typedef float f32;
 
-#define FATTN_VEC_CASE(type_VKQ, D, type_suffix_K, type_suffix_V)                                          \
-    if (Q->ne[0] == (D) && K->type == ggml_type_##type_suffix_K && V->type == ggml_type_##type_suffix_V) { \
+#define FATTN_VEC_CASE(type_VKQ, D, type_K, type_V)                                         \
+    if (Q->ne[0] == (D) && K->type == type_K && V->type == type_V) {                        \
         constexpr int nwarps = (D)/WARP_SIZE;                                               \
-        constexpr bool Q_q8_1 = ggml_type_##type_suffix_K != GGML_TYPE_F16;                 \
         fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ<                        \
             (D), cols_per_block, parallel_blocks,                                           \
-            vec_dot_fattn_vec_KQ_##type_suffix_K<type_VKQ, (D)>, Q_q8_1, dequantize_1_##type_suffix_V<type_VKQ>>; \
+            type_K, type_V>;                                                                \
         launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \
         return;                                                                             \
     }                                                                                       \
```
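
Since the macro is now keyed on `ggml_type` values rather than pasted type-name suffixes, a dispatch routine can list the supported cases directly against the enum; an illustrative (not verbatim) call sequence:

```cuda
// Hypothetical excerpt of a dispatch function: each case either launches the
// matching kernel instantiation and returns, or falls through to the next one.
FATTN_VEC_CASE(f16, 128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASE(f16, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASE(f16, 128, GGML_TYPE_F16,  GGML_TYPE_F16)

on_no_fattn_vec_case(Q->ne[0]); // error path shown in the next hunk
```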
```diff
@@ -582,14 +625,18 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "By default only f16 KV cache is supported.\n");
         fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
         GGML_ASSERT(false);
-    } else {
+    } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
         fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
         fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
         fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
         fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ASSERT(false);
+    } else {
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Only f16 is supported.\n");
+        GGML_ASSERT(false);
     }
 }
 
```
