Skip to content

Commit a9f302e

Browse files
ikawrakow and Kawrakow authored
Adding IQ2_TN for use with ternary models (#13)
* iq2_tn: TriLM specific 2.0625 bpw quantization Quantize/dequantize/scale dot product. I get 46 t/s for the TriLM-3.9B with any SIMD! Finally a compiler doing a decent job auto-vectorizing the scalar implementation. * iq2_tn: AVX512 Just reusing the k-quants template gets us to PP-512 = 376 t/s, TG-128 = 47.6 t/s for TriLM-3.9B. * iq2_tn: AVX512 With this tweak we get to PP-512 = 431 t/s. * iq2_tn: AVX512 With this tweak we get TG-128 = 19.58 / 35.18 t/s for 1 / 2 threads. At 4 threads we saturate at 48.41 t/s, and then performance slowly degrades with increasing number of threads. * iq2_tn: AVX2 PP512 = 440 t/s on the Ryzen-5975WX. We should be able to do better. * iq2_tn: initial NEON version * iq2_tn: NEON For TriLM-3.9B running on the M2-Max we get PP-512 = 193.5 t/s, TG-128 = 75.5 t/s. This is in line with what we have for iq2_bn ant 3.3B Bitnet. * iq2_tn: Metal For TriLM-3.9B on a 30-core M2-Max we get PP-512 = 890 t/s, TG-128 = 98.5 t/s. * iq2_tn: CUDA For TriLM-3.9B running on RTX-4080 we get PP-512 = 9936 t/s, TG-128 = 299.2 t/s. * iq2_tn: AVX2 PP improvement We now get PP-512 = 490.73 t/s for TriLM-3.9B on the Ryzen-5975WX. We have PP-512 = 636.61 t/s for Bintnet-3B quantized with iq2_bn. Bintnet-3B is actually 3.4B, TriLM-3.9B is 3.99B, so we would expect 3.43/3.99 * 636 = 546 t/s, so it seems we still have something that is not quite optimal in iq2_tn. * iq2_tn: small NEON improvement For TriLM-3.9B we now get PP-512 = 206.6 t/s and TG-128 = 76.4 t/s. --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent b409c15 commit a9f302e

File tree

18 files changed

+718
-20
lines changed

18 files changed

+718
-20
lines changed

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
2828
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
2929
{ "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", },
3030
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
31+
{ "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", },
3132
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
3233
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
3334
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ extern "C" {
393393
GGML_TYPE_IQ3_K = 38,
394394
GGML_TYPE_IQ4_K = 39,
395395
GGML_TYPE_IQ5_K = 40,
396+
GGML_TYPE_IQ2_TN = 41,
396397
GGML_TYPE_COUNT,
397398
};
398399

@@ -443,6 +444,7 @@ extern "C" {
443444
GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
444445
GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
445446
GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
447+
GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors
446448
};
447449

448450
// available tensor operations:

ggml/src/ggml-common.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ typedef struct {
407407
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
408408

409409
//
410-
// Bitnet - implemented as 1.75 bpw
410+
// Bitnet - implemented as 1.625 bpw
411411
// The block scale is a waste, but it allows us to plug it in without any additional
412412
// changes to ggml.
413413
//
@@ -418,13 +418,21 @@ typedef struct {
418418
} block_iq1_bn;
419419
static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
420420
//
421-
// Bitnet - implemented as 2.25 bpw
421+
// Bitnet - implemented as 2.0 bpw
422422
//
423423
#define QK_IQ2BN 64
424424
typedef struct {
425425
uint8_t qs[QK_IQ2BN/4];
426426
} block_iq2_bn;
427427
static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
428+
//
429+
// TriLM - implemented as 2.0625 bpw
430+
//
431+
typedef struct {
432+
ggml_half d;
433+
uint8_t qs[QK_K/4];
434+
} block_iq2_tn;
435+
static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding");
428436

429437
// Used by IQ1_M quants
430438
typedef union {

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2759,6 +2759,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
27592759
case GGML_TYPE_IQ5_K:
27602760
case GGML_TYPE_IQ1_BN:
27612761
case GGML_TYPE_IQ2_BN:
2762+
case GGML_TYPE_IQ2_TN:
27622763
return true;
27632764
default:
27642765
return false;

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
655655
static constexpr int qi = QI1_BN;
656656
};
657657

658+
// Type traits for IQ2_TN: super-blocks of QK_K values, reusing the q2_K
// ratio (QR2_K) and ints-per-block (QI2_K) constants, since both types pack
// 2-bit quants at the same granularity.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_K;
    static constexpr int qi = QI2_K;
};
664+
658665
template<>
659666
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
660667
static constexpr int qk = QK4_NL;

ggml/src/ggml-cuda/convert.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,27 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
153153
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
154154
}
155155

156+
// Dequantize one block_iq2_tn super-block (QK_K values) per CUDA block.
// Expected launch: grid = number of super-blocks, block = 64 threads
// (see dequantize_row_iq2_tn_cuda). Each thread expands one packed byte
// into 4 outputs spaced 32 apart.
// Fixed: removed the unused `is` local (leftover from the q2_K kernel this
// was derived from — iq2_tn has a single block scale, no per-group scales).
template<typename dst_t>
static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;                       // super-block index
    const block_iq2_tn * x = (const block_iq2_tn *) vx;

    const int64_t tid = threadIdx.x;
    const int64_t n   = tid/32;                         // half of the super-block (0 or 1)
    const int64_t l   = tid - 32*n;                     // lane within that half

    const uint8_t q = x[i].qs[32*n + l];
    dst_t * y = yy + i*QK_K + 128*n;

    const float d = __half2float(x[i].d);
    // Stored 2-bit values q' are in 0..3; d*q' - d maps them to {-d, 0, d, 2d},
    // i.e. the quants are stored with a +1 bias relative to the ternary values.
    y[l+ 0] = d * ((q >> 0) & 3) - d;
    y[l+32] = d * ((q >> 2) & 3) - d;
    y[l+64] = d * ((q >> 4) & 3) - d;
    y[l+96] = d * ((q >> 6) & 3) - d;
}
176+
156177
template<typename dst_t>
157178
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
158179

@@ -646,6 +667,12 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k
646667
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
647668
}
648669

670+
// Host-side launcher for iq2_tn dequantization: one 64-thread CUDA block per
// super-block of QK_K quantized values.
// NOTE(review): assumes k is a multiple of QK_K — any remainder is truncated.
template<typename dst_t>
static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = k / QK_K;
    dequantize_block_iq2_tn<<<num_blocks, 64, 0, stream>>>(vx, y);
}
675+
649676
template<typename dst_t>
650677
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
651678
const int nb = k / QK_K;
@@ -812,6 +839,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
812839
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
813840
case GGML_TYPE_Q2_K:
814841
return dequantize_row_q2_K_cuda;
842+
case GGML_TYPE_IQ2_TN:
843+
return dequantize_row_iq2_tn_cuda;
815844
case GGML_TYPE_Q3_K:
816845
return dequantize_row_q3_K_cuda;
817846
case GGML_TYPE_Q4_K:
@@ -871,6 +900,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
871900
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
872901
case GGML_TYPE_Q2_K:
873902
return dequantize_row_q2_K_cuda;
903+
case GGML_TYPE_IQ2_TN:
904+
return dequantize_row_iq2_tn_cuda;
874905
case GGML_TYPE_Q3_K:
875906
return dequantize_row_q3_K_cuda;
876907
case GGML_TYPE_Q4_K:

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,41 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
469469

470470
}
471471

472+
#define VDR_IQ2_TN_Q8_1_MMVQ 1
#define VDR_IQ2_TN_Q8_1_MMQ  4

// Dot product of a 4-byte slice of one iq2_tn block with the matching q8_1
// sub-blocks. The stored 2-bit quants carry a +1 bias, so sum(u) — computed
// as dp4a(0x01010101, u) — is subtracted each iteration before scaling by the
// iq2_tn block scale d.
// Fixed: removed the unreachable commented-out alternative implementation
// that sat after the return statement (dead code).
static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx;

    const int bq8_offset = QR2_K * (iqs / QI8_1);

    // Gather 4 consecutive packed bytes (16 quants) into one 32-bit lane.
    const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs;
    int v = q16[0] | (q16[1] << 16);

    float sumf = 0;
    for (int i = 0; i < QR2_K; ++i) {
        const int   u  = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
        const float d8 = __low2float(bq8_1[bq8_offset + i].ds);
        // dp4a(v & 0x03030303, u) sums quant*activation per byte;
        // dp4a(0x01010101, u) removes the +1 bias applied to every quant.
        sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
        v >>= 2;    // advance to the next 2-bit plane
    }
    return __half2float(bq2->d) * sumf;
}
506+
472507
} // namespace
473508

474509
void mul_mat_vec_iq2_k_q8_1_cuda(
@@ -499,3 +534,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
499534
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
500535
}
501536

537+
// Public entry point: matrix·vector product for IQ2_TN-quantized weights
// against q8_1-quantized activations, dispatched through the shared
// iqk_mul_mat_vec_q_cuda template on the given stream.
void mul_mat_vec_iq2_tn_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
543+

ggml/src/ggml-cuda/iqk_mmvq.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
1616
const void * vx, const void * vy, float * dst,
1717
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
1818

19+
void mul_mat_vec_iq2_tn_q8_1_cuda(
20+
const void * vx, const void * vy, float * dst,
21+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
22+

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ void ggml_cuda_op_mul_mat_vec_q(
426426
case GGML_TYPE_IQ2_BN:
427427
mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
428428
break;
429+
case GGML_TYPE_IQ2_TN:
430+
mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
431+
break;
429432
case GGML_TYPE_IQ4_NL:
430433
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
431434
break;

0 commit comments

Comments
 (0)