
Commit ed7cbb1

ikawrakow (Iwan Kawrakow) authored and committed
IQ1_M: 1.75 bpw quantization (ggml-org#6302)
* iq1_m: basics

* iq1_m: basics-2

* iq1_m: CUDA dequantize works

  Very first shot I get PPL = 9.76 for LLaMA-v2-7B.

* iq1_m: separate shifts for each group of 8 in a block

  We get
  PPL(LLaMA-v2-7B ) = 9.2810
  PPL(LLaMA-v2-13B) = 6.8105

  Not bad, but slightly higher than sqrt(PPL(IQ1_S) * PPL(IQ2_XXS)),
  which is the expected outcome given that IQ1_M is halfway between
  IQ1_S and IQ2_XXS in terms of bpw. From this, we would expect
  PPL = 9.14 for LLaMA-v2-7B
  PPL = 6.63 for LLaMA-v2-13B

* iq1_m: go to 3-bit scales

  There is a slight increase in PPL, but the 0.0625 bpw reduction in
  size is totally worth it.

  We now have
  PPL(LLaMA-v2-7B ) = 9.4469 at 1.96 bpw
  PPL(LLaMA-v2-13B) = 6.8717 at 1.93 bpw
  PPL(LLaMA-v2-70B) = 4.8568 at 1.85 bpw

* iq1_m: scalar dot product

* iq1_m: AVX2 dot product

* iq1_m: very slightly faster AVX2 dot product

* iq1_m: ARM_NEON dot product

  Works, but very slow (10.5 t/s)

* iq1_m: Metal - dequantize works, dot product does not

* iq1_m: Metal now works

  About the same performance as iq1_s.

* iq1_m: minor

* iq1_m: checking pure iq1_m quantization

  It is pretty bad: PPL(LLaMA-v2-7B) = 34 if we quantize output.weight
  with Q4_K.

* iq1_m: slightly faster ARM_NEON dot product

  10.5 t/s -> 11.65 t/s

* iq1_m: faster ARM_NEON dot product

  11.65 t/s -> 14.9 t/s

* iq1_m: another minor ARM_NEON dot product improvement

  14.9 -> 15.0 t/s

* iq1_m: small PPL improvement via super-block scale adjustment

  After quantizing block scales redo the super-block scale fit.

  PPL(LLaMA-v2-7B ) = 9.3346
  PPL(LLaMA-v2-13B) = 6.8419
  PPL(LLaMA-v2-70B) = 4.8294
  PPL(Mistral-7B  ) = 8.1624

* iq1_m: adapt to CUDA refactoring

* iq1_m: remove unused variable

  We have progressed to warnings being errors.

* iq1_m: add to backend-ops tests

* iq1_m: fix Windows ARM

* iq1_m: use common definition of iq1m_scale_t

* cuda: assert -> NO_DEVICE_CODE

* iq1_M: PR comments

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 63687c1 · commit ed7cbb1

16 files changed: +1006 −125 lines
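Before the per-file diffs, a quick sanity check on the numbers in the commit message. With the standard super-block size QK_K = 256, the block_iq1_m layout introduced in ggml-common.h below works out to exactly 1.75 bpw, and moving from 4-bit to 3-bit sub-block scales accounts for the 0.0625 bpw saving mentioned above. A minimal C sketch of that arithmetic (not code from this commit):

// Sketch only: size arithmetic behind the commit message, assuming
// QK_K == 256, the usual ggml super-block size.
#include <stdio.h>

int main(void) {
    const int QK_K   = 256;
    const int qs     = QK_K/8;   // 32 bytes: low 8 bits of the grid indices
    const int qh     = QK_K/16;  // 16 bytes: high 3 bits + per-group shift bits
    const int scales = QK_K/32;  //  8 bytes: 3-bit sub-block scales + fp16 super-scale
    printf("bytes/block = %d, bpw = %.4f\n",     // 56 bytes -> 1.7500 bpw
           qs + qh + scales, 8.0*(qs + qh + scales)/QK_K);
    // Dropping sub-block scales from 4 to 3 bits saves 1 bit for each of
    // the 16 groups of 16 weights: 16/256 = 0.0625 bpw, as stated above.
    return 0;
}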

examples/quantize/quantize.cpp (+7 −4)

@@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_S",  LLAMA_FTYPE_MOSTLY_IQ2_S,  " 2.5  bpw quantization",            },
     { "IQ2_M",  LLAMA_FTYPE_MOSTLY_IQ2_M,  " 2.7  bpw quantization",            },
     { "IQ1_S",  LLAMA_FTYPE_MOSTLY_IQ1_S,  " 1.56 bpw quantization",            },
+    { "IQ1_M",  LLAMA_FTYPE_MOSTLY_IQ1_M,  " 1.75 bpw quantization",            },
     { "Q2_K",   LLAMA_FTYPE_MOSTLY_Q2_K,   " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization",            },
@@ -370,10 +371,12 @@ int main(int argc, char ** argv) {
 
     if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
          params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S ||
-         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
-        fprintf(stderr, "\n===============================================================================================\n");
-        fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
-        fprintf(stderr, "===============================================================================================\n\n\n");
+         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S ||
+         params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
+         params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) {
+        fprintf(stderr, "\n==========================================================================================================\n");
+        fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
+        fprintf(stderr, "==========================================================================================================\n\n\n");
         return 1;
     }
 
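For context (not part of the diff): the importance matrix this guard insists on is produced by the imatrix example in the same repository. A typical flow, with illustrative file names, is ./imatrix -m model-f16.gguf -f calibration.txt -o imatrix.dat followed by ./quantize --imatrix imatrix.dat model-f16.gguf model-iq1_m.gguf IQ1_M. At sub-2 bpw the quantization error is too large to be usable without it, which is why the check hard-fails with return 1 instead of merely warning.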

ggml-common.h (+15)

@@ -377,6 +377,20 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
+// 1.8125 bpw
+typedef struct {
+    uint8_t  qs[QK_K/8];      // grid index, low 8 bits
+    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
+    uint8_t  scales[QK_K/32]; // 4-bit block scales
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+    ggml_half f16;
+    uint16_t  u16;
+} iq1m_scale_t;
+
 // Non-linear quants
 #define QK4_NL 32
 typedef struct {
@@ -1050,6 +1064,7 @@ GGML_TABLE_END()
 
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
 #if defined(GGML_COMMON_IMPL_C)
 GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
     0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
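Unlike block_iq1_s, block_iq1_m has no dedicated ggml_half d member: the 16 bits of the shared fp16 super-block scale are packed into the top 4 bits of the four 16-bit words of scales[], and iq1m_scale_t is the type-punning helper the kernels use to reassemble them. A host-side sketch of that decode (assuming a little-endian host, as the CUDA kernel below also does; not code from this commit):

// Sketch: recover the raw fp16 bits of the super-block scale from
// block_iq1_m.scales; assign the result to iq1m_scale_t.u16, read .f16.
#include <stdint.h>
#include <string.h>

static uint16_t iq1m_unpack_scale(const uint8_t scales[8]) {
    uint16_t sc[4];
    memcpy(sc, scales, sizeof(sc)); // memcpy stands in for the kernels' pointer cast
    return (uint16_t)((sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) |
                      ((sc[2] >>  4) & 0x0f00) | (sc[3] & 0xf000));
}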

ggml-cuda.cu (+3 −1)

@@ -615,6 +615,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -643,6 +644,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:
@@ -2560,7 +2562,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             ggml_type a_type = a->type;
             if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
                 a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                a_type == GGML_TYPE_IQ2_S   || a_type == GGML_TYPE_IQ4_XS) {
+                a_type == GGML_TYPE_IQ1_M   || a_type == GGML_TYPE_IQ2_S  || a_type == GGML_TYPE_IQ4_XS) {
                 if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                     return false;
                 }

ggml-cuda/convert.cu (+47 −6)

@@ -373,7 +373,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
         const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
         for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif
 
 }
@@ -395,7 +395,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
         const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
         for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif
 
 }
@@ -416,7 +416,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
         const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
         for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif
 
 }
@@ -444,7 +444,7 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
             y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
         }
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif
 
 }
@@ -470,7 +470,7 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
             y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
         }
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif
 
 }
@@ -496,11 +496,42 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
         y[j] = d * (q[j] + delta);
     }
 #else
-    assert(false);
+    NO_DEVICE_CODE;
+#endif
+
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq1_m * x = (const block_iq1_m *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+#else
+    NO_DEVICE_CODE;
 #endif
 
 }
 
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -658,6 +689,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
@@ -724,6 +761,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_IQ4_XS:
@@ -769,6 +808,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_IQ4_XS:

ggml-cuda/mmvq.cu (+11)

@@ -282,6 +282,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
+static void mul_mat_vec_iq1_m_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -373,6 +381,9 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
