
Commit 1607a5e

chaxu01 and slaren authored
backend cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels (#9921)
* backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels

Co-authored-by: Diego Devesa <[email protected]>
1 parent ae8de6d commit 1607a5e
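
The "online flow" here means Q4_0 weights are converted at load time rather than requiring pre-converted Q4_0_4_4/Q4_0_4_8/Q4_0_8_8 model files: when a tensor is placed in the new aarch64 CPU buffer type, the backend picks the interleaved layout best suited to the running CPU and repacks the data on the fly. A minimal sketch of that flow, using the functions added in this commit (the surrounding loader function is hypothetical, not code from the patch):

#include <stdint.h>
#include "ggml-cpu-aarch64.h"

// Hypothetical loader-side sketch; only the two ggml_aarch64_* calls and the
// tensor->extra convention come from this patch.
static void load_q4_0_weights(struct ggml_tensor * t, const void * data, size_t size) {
    // pick Q4_0_8_8 / Q4_0_4_8 / Q4_0_4_4 based on SVE/i8mm/NEON support,
    // or fall back to the tensor's own type if no optimized kernel applies
    enum ggml_type repack_type = ggml_aarch64_get_optimal_repack_type(t);

    // record the effective type so mul_mat can dispatch on it later;
    // ggml_compute_forward_mul_mat reads this back from t->extra
    t->extra = (void *)(intptr_t)repack_type;

    // memcpy if no conversion is needed, otherwise interleave 4 or 8
    // rows of Q4_0 blocks into the packed layout
    ggml_aarch64_repack_tensor(t, repack_type, data, size);
}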

File tree

9 files changed: +271 -20 lines changed

Makefile (+4)

@@ -940,6 +940,10 @@ ggml/src/ggml-cuda/%.o: \
 	$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -mtgpu -c -o $@ $<
 endif # GGML_MUSA
 
+ifndef GGML_NO_CPU_AARCH64
+	MK_CPPFLAGS += -DGGML_USE_CPU_AARCH64
+endif
+
 ifdef GGML_METAL
 	MK_CPPFLAGS += -DGGML_USE_METAL
 	MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit

ggml/CMakeLists.txt (+1)

@@ -92,6 +92,7 @@ else()
 endif()
 
 option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
+option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 
 option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
 option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
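
Note that the new option defaults to ON, so the runtime conversion is enabled in standard CMake builds; it can be disabled at configure time with -DGGML_CPU_AARCH64=OFF, and the Makefile build above exposes the same switch through GGML_NO_CPU_AARCH64.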

ggml/include/ggml-cpu.h (+3)

@@ -169,6 +169,9 @@ extern "C" {
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
 #endif
 
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
+GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
+
 #ifdef __cplusplus
 }
 #endif
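
These two entry points are the public surface of the feature: one exposes the new buffer type, the other lets compute code test whether a tensor's buffer uses it. A hedged sketch of how a caller might use them (the helper function is a placeholder, not code from this patch; it mirrors the buffer->buft access the patch itself uses, which needs the internal backend header):

#include <stdio.h>
#include <stdint.h>
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend-impl.h" // for struct ggml_backend_buffer (buffer->buft)

// hypothetical helper: report how a tensor was repacked
static void print_effective_type(const struct ggml_tensor * tensor) {
    if (tensor->buffer && ggml_backend_cpu_buft_is_aarch64(tensor->buffer->buft)) {
        // tensor->type still reads GGML_TYPE_Q4_0; the repacked layout is
        // recorded in tensor->extra (see the ggml-cpu.c hunk below)
        enum ggml_type effective = (enum ggml_type)(intptr_t)tensor->extra;
        printf("repacked as %s\n", ggml_type_name(effective));
    }
}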

ggml/src/ggml-cpu/CMakeLists.txt (+5)

@@ -236,6 +236,11 @@ else()
     message(STATUS "Unknown architecture")
 endif()
 
+if (GGML_CPU_AARCH64)
+    message(STATUS "Using runtime weight conversion of Q4_0 to Q4_0_x_x to enable optimized GEMM/GEMV kernels")
+    add_compile_definitions(GGML_USE_CPU_AARCH64)
+endif()
+
 target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")

ggml/src/ggml-cpu/ggml-cpu-aarch64.c (+144)

@@ -3385,3 +3385,147 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
 }
+
+// FIXME: this code is duplicated from ggml-aarch64.c
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 2; i++) {
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 4; i++) {
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 8);
+
+    block_q4_0x8 * dst = (block_q4_0x8 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[8];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+// Prepare for optimized kernels if applicable
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
+    if (cur->type == repack_type) {
+        memcpy(cur->data, data, data_size);
+        return;
+    }
+
+    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
+
+    switch (repack_type) {
+        case GGML_TYPE_Q4_0_8_8:
+            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_8:
+            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_4:
+            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
+    }
+}
+
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
+        // TODO: enable for AVX2 - currently disabled due to bad gemv performance
+        if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
+            return GGML_TYPE_Q4_0_8_8;
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            return GGML_TYPE_Q4_0_4_8;
+        }
+        if (ggml_cpu_has_neon()) {
+            return GGML_TYPE_Q4_0_4_4;
+        }
+    }
+
+    return cur->type;
+}
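
The index arithmetic in make_block_q4_0x4 is easier to see with concrete numbers. Below is a small standalone program (illustration only, not part of the patch) that prints, for blck_size_interleave == 4 and QK4_0 == 32, which source row and byte each destination byte comes from; the formulas are copied verbatim from the function above:

#include <stdio.h>

// Standalone illustration of the interleaving in make_block_q4_0x4 with
// blck_size_interleave == 4 and QK4_0 == 32 (so 64 quant bytes per x4 block).
int main(void) {
    const int QK4_0 = 32;
    const unsigned int blck_size_interleave = 4;

    for (int i = 0; i < QK4_0 * 2; i++) {
        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
        src_offset += (i % blck_size_interleave);

        printf("out.qs[%2d] <- in[%d].qs[%2d]\n", i, src_id, src_offset);
    }
    // The output shows the four rows contributing 4 consecutive bytes in
    // turn: in[0].qs[0..3], in[1].qs[0..3], in[2].qs[0..3], in[3].qs[0..3],
    // then in[0].qs[4..7], and so on. This puts matching byte groups of the
    // 4 rows next to each other, which is what the GEMM/GEMV kernels load.
    return 0;
}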

ggml/src/ggml-cpu/ggml-cpu-aarch64.h (+3)

@@ -21,6 +21,9 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
+
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-cpu/ggml-cpu.c (+13 -10)

@@ -7330,6 +7330,7 @@ static void ggml_compute_forward_group_norm(
 static void ggml_compute_forward_mul_mat_one_chunk(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst,
+    const enum ggml_type type,
     const int64_t num_rows_per_vec_dot,
     const int64_t ir0_start,
     const int64_t ir0_end,
@@ -7341,8 +7342,6 @@
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
-
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
@@ -7430,7 +7429,11 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const enum ggml_type type = src0->type;
+    enum ggml_type type = src0->type;
+
+    if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        type = (enum ggml_type)(intptr_t)src0->extra;
+    }
 
     enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
@@ -7469,15 +7472,15 @@
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
+                                     nb01/ggml_type_size(type),
                                      (const char *)src1->data + i12*nb12 + i13*nb13,
                                      nb11/ggml_type_size(src1->type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      src1->type,
                                      dst->type))
                     goto UseGgmlGemm1;
@@ -7530,15 +7533,15 @@ UseGgmlGemm1:;
 
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
+                                     nb01/ggml_type_size(type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      vec_dot_type,
                                      dst->type))
                     goto UseGgmlGemm2;
@@ -7623,7 +7626,7 @@ UseGgmlGemm2:;
         const int64_t ir1_start = dr1 * ith1;
         const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
             break;
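
The consumer side above trusts src0->extra to carry the repacked type whenever the tensor lives in the aarch64 buffer type. A plausible sketch of the producer side (hedged: the buffer-type implementation added by this commit lives in code not shown on this page and may differ in detail) is an init_tensor hook that records the chosen layout before any data is uploaded:

#include <stdint.h>
#include "ggml-backend.h"
#include "ggml-cpu-aarch64.h"

// hypothetical init_tensor hook for the aarch64 buffer type: stash the
// optimal repack type in tensor->extra so mul_mat can dispatch on it
static void cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *)(intptr_t)ggml_aarch64_get_optimal_repack_type(tensor);
    (void) buffer; // unused
}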
