Skip to content

Commit d580455

Browse files
Srihari-mcw authored and Nexesenex committed
Extend sgemm.cpp support for Q5_0
1 parent 98e4186 commit d580455

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

ggml/src/llamafile/sgemm.cpp

+57
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,36 @@ class tinyBLAS_Q0_AVX {
942942
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
943943
}
944944

945+
inline __m256i load(const block_q5_0 *b) {
946+
return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
947+
}
948+
949+
inline __m128i load0(const block_q5_0* b) {
950+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
951+
uint32_t x32;
952+
memcpy(&x32, b->qh, sizeof(uint32_t));
953+
__m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
954+
__m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
955+
_mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
956+
_mm_shuffle_epi8(_mm_set1_epi32(x32),
957+
_mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
958+
bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
959+
return _mm_or_si128(qxl, bytesl);
960+
}
961+
962+
inline __m128i load1(const block_q5_0* b) {
963+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
964+
uint32_t x32;
965+
memcpy(&x32, b->qh, sizeof(uint32_t));
966+
__m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
967+
__m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
968+
_mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
969+
_mm_shuffle_epi8(_mm_set1_epi32(x32),
970+
_mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
971+
bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
972+
return _mm_or_si128(qxh, bytesh);
973+
}
974+
945975
inline __m256i load(const block_iq4_nl *b) {
946976
return MM256_SET_M128I(load1(b), load0(b));
947977
}
@@ -973,6 +1003,17 @@ class tinyBLAS_Q0_AVX {
9731003
_mm_srli_epi16(x, 4), 1));
9741004
}
9751005

1006+
// Expand the 32 bits stored at p into 32 bytes.
//
// The four bytes at p are read as one uint32_t. Output byte i is 0x00 when
// bit i of that value is set and 0xF0 when it is clear.
static inline __m256i bittobyte(const uint8_t *p) {
    // Read the 32 flag bits as a single integer without a pointer cast.
    uint32_t bits;
    memcpy(&bits, p, sizeof(bits));

    // The byte shuffle works per 128-bit lane; every 32-bit element holds the
    // same value, so indices 0..3 select source bytes 0..3. Each group of 8
    // output bytes receives one source byte (groups 0,1 in the low lane,
    // groups 2,3 in the high lane).
    const __m256i pick = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
                                           0x0101010101010101, 0x0000000000000000);
    const __m256i spread = _mm256_shuffle_epi8(_mm256_set1_epi32(bits), pick);

    // Byte j of each 64-bit group of the mask is ~(1 << j), so the OR becomes
    // 0xFF exactly where bit j of the replicated source byte is set.
    const __m256i all_but_one = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
    const __m256i is_set = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
                                             _mm256_or_si256(all_but_one, spread));

    // Keep 0xF0 only where the bit was NOT set.
    return _mm256_andnot_si256(is_set, _mm256_set1_epi8((char)0xF0));
}
1016+
9761017
const TA *const A;
9771018
const TB *const B;
9781019
TC *const C;
@@ -1182,6 +1223,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
11821223
#endif
11831224
}
11841225

1226+
case GGML_TYPE_Q5_0: {
1227+
if (Btype != GGML_TYPE_Q8_0)
1228+
return false;
1229+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1230+
tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
1231+
k, (const block_q5_0 *)A, lda,
1232+
(const block_q8_0 *)B, ldb,
1233+
(float *)C, ldc,
1234+
ith, nth};
1235+
tb.matmul(m, n);
1236+
return true;
1237+
#else
1238+
return false;
1239+
#endif
1240+
}
1241+
11851242
case GGML_TYPE_IQ4_NL: {
11861243
if (Btype != GGML_TYPE_Q8_0)
11871244
return false;

0 commit comments

Comments
 (0)