ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU #11183
base: master
Changes from all commits
970b5ab
fb43d5e
983aa09
f5fddb6
946796f
b6fc9f0
fbddb26
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
@@ -524,6 +524,32 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    return d6 * sumf_d;
}

#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ 8

// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block
template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
    const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {

    float sumf = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        int sumi = 0;

#pragma unroll
        for (int i = 0; i < vdr; ++i) {
            const int vi = (v[i] >> (2*i0)) & 0x03030303;

            sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
        }

        sumf += d8[i0] * sumi;
    }

    return d2 * sumf;
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

Comment on lines +537 to +541:

I think …

Right. I think I tried to use a similar nomenclature as some of the other functions in this file. But I agree, …
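For reference, a rough host-side sketch (hypothetical example values, not code from this PR) of what the unpacking in `vec_dot_tq2_0_q8_1_impl` computes: `(v[i] >> (2*i0)) & 0x03030303` keeps one 2-bit field in each byte of a 32-bit word, and the bytewise subtraction done by `__vsub4(vi, 0x01010101)` maps the stored values {0, 1, 2} back to the ternary weights {-1, 0, +1} before the `dp4a` dot product with the int8 activations.

```cpp
// Host-side sketch of the TQ2_0 unpacking step (made-up example word, not from this PR).
#include <cstdint>
#include <cstdio>

int main() {
    // Example packed word: 16 ternary weights, 2 bits each; every 2-bit field is 0, 1 or 2.
    const uint32_t v = 0x85196892u;

    for (int i0 = 0; i0 < 4; ++i0) {                      // 4 shift positions, like the QR2_0 loop
        const uint32_t vi = (v >> (2*i0)) & 0x03030303u;  // one 2-bit field per byte
        for (int b = 0; b < 4; ++b) {
            const int q = (vi >> (8*b)) & 0xFF;           // stored value in {0, 1, 2}
            printf("i0=%d byte=%d weight=%+d\n", i0, b, q - 1); // {0,1,2} -> {-1,0,+1}, like __vsub4
        }
    }
    return 0;
}
```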
@@ -786,6 +812,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}

static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;

    // iqs 0..7 all need bq8_offset 0, 1, 2, 3
    // iqs 8..15 all need bq8_offset 4, 5, 6, 7
    const int bq8_offset = QR2_0 * (iqs / 8);

    int v[VDR_TQ2_0_Q8_1_MMVQ];
    int u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];
    float d8[QR2_0];

#pragma unroll
    for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(btq2_0->qs, iqs + i);
    }

#pragma unroll
    for (int i = 0; i < QR2_0; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;

        for (int j = 0; j < VDR_TQ2_0_Q8_1_MMVQ; ++j) {
            u[VDR_TQ2_0_Q8_1_MMVQ*i + j] = get_int_b4(bq8i->qs, (iqs % QI8_1) + j);
        }
        d8[i] = __low2float(bq8i->ds);
    }

    return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
}

#define VDR_IQ2_XXS_Q8_1_MMVQ 2
#define VDR_IQ2_XXS_Q8_1_MMQ 2
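As a quick check of the `bq8_offset` comment in `vec_dot_tq2_0_q8_1` above (a standalone snippet assuming `QR2_0 == 4`, not part of the PR): `QR2_0 * (iqs / 8)` yields a base offset of 0 for `iqs` 0..7 and 4 for `iqs` 8..15, and the loop index `i` then adds 0..3 on top, covering q8_1 sub-blocks 0..3 and 4..7 respectively.

```cpp
// Standalone check of the bq8_offset mapping (assumes QR2_0 == 4; not part of the PR).
#include <cstdio>

int main() {
    const int QR2_0 = 4;
    for (int iqs = 0; iqs < 16; ++iqs) {
        printf("iqs=%2d -> bq8_offset=%d\n", iqs, QR2_0 * (iqs / 8));
    }
    return 0;
}
```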
@@ -3375,7 +3375,8 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
-    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    // GGML_TYPE_TQ1_0,
+    GGML_TYPE_TQ2_0,
     GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
     GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
     GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,

Comment on lines -3378 to +3379:

An unintended side effect of un-commenting `GGML_TYPE_TQ2_0` … Some solutions are: … Most of these solutions (apart from hiding the problem) are out of scope of this PR, which focuses on the CUDA implementation of TQ2_0.

The correct fix would be to modify …

Done in b6fc9f0.

@@ -3387,6 +3388,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
+    GGML_TYPE_TQ2_0,
     GGML_TYPE_IQ2_XXS
 };

@@ -3397,7 +3399,8 @@ static const ggml_type other_types[] = {
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
-    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    // GGML_TYPE_TQ1_0,
+    GGML_TYPE_TQ2_0,
     GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
     GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
     GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
From a separate review thread on the CUDA kernels:

This should be faster, but since this kernel is going to be I/O bound anyways I doubt it will make a measurable difference.

Right, the indices calculation shouldn't really be a bottleneck here. Is there a particular reason why `tid` isn't an `int` everywhere in that file when it corresponds to `threadIdx.x`?

If you mean my comment, that was just me being a bit inconsistent and not looking ahead at how the values are being used, sorry. Generally speaking, the issue with `int` vs. `int64_t` is just potential overflows for very large tensors. So for kernels where the performance is not relevant anyways it's a lot of the time preferable to just use `int64_t`.
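To make the overflow point concrete, here is a minimal hypothetical element-wise kernel (a sketch with made-up names, not code from this PR): `blockIdx.x`, `blockDim.x`, and `threadIdx.x` are 32-bit values, so an index computed as `blockIdx.x*blockDim.x + threadIdx.x` is evaluated in 32-bit arithmetic and can wrap for tensors with more than 2^31 elements unless one operand is widened to `int64_t` first.

```cuda
// Hypothetical element-wise kernel (not from this PR) using 64-bit indexing.
#include <cstdint>

static __global__ void scale_f32(const float * x, float * dst, const float scale, const int64_t n) {
    // Widen one operand before multiplying so the whole flat-index computation
    // happens in 64 bits and cannot overflow for very large tensors.
    const int64_t i = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= n) {
        return;
    }
    dst[i] = scale * x[i];
}
```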