@@ -328,6 +328,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
+#if defined(__ARM_NEON)
 #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
 #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
 #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
@@ -339,7 +340,7 @@ static float table_f32_f16[1 << 16];
 
 // precomputed tables for expanding 8bits to 8 bytes (shl 4)
 static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
-static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) };
+#endif
 
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
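Note (commentary, not part of the patch): table_b2b_u[b] expands the 8 bits of b into 8 bytes, one per bit, LSB first, where a set bit becomes 0x10 and a clear bit 0x00; the removed table_b2b_i used the inverted encoding (0xF0 when clear, 0x00 when set). A minimal scalar sketch of that expansion, with an illustrative function name:

#include <stdint.h>

// Scalar equivalent of table_b2b_u[b]: bit i of b (LSB first) becomes
// byte i of the result, 0x10 if set and 0x00 if clear ("shl 4").
static uint64_t b2b_u_scalar(uint8_t b) {
    uint64_t out = 0;
    for (int i = 0; i < 8; ++i) {
        if (b & (1u << i)) {
            out |= (uint64_t)0x10 << (8 * i);
        }
    }
    return out;
}

The dot-product hunks below drop four such lookups (stitched together with _mm256_set_epi64x) in favor of a single bytes_from_bits_32 call, which is why the tables can retreat behind the __ARM_NEON guard.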
@@ -490,6 +491,19 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if __AVX2__ || __AVX512F__
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
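Note (commentary, not part of the patch): bytes_from_bits_32 broadcasts byte j/8 of the 32-bit input to byte j via _mm256_shuffle_epi8, ORs in a mask whose byte at position j has every bit set except bit j%8, and compares against all-ones, so output byte j is 0xFF exactly when bit j of the input is set. A plain-C model of the result, with an illustrative name:

#include <stdint.h>
#include <string.h>

// Scalar model of bytes_from_bits_32: out[j] = 0xFF if bit j of the
// little-endian 32-bit value at x is set, 0x00 otherwise.
static void bytes_from_bits_32_scalar(const uint8_t * x, uint8_t out[32]) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    for (int j = 0; j < 32; ++j) {
        out[j] = ((x32 >> j) & 1) ? 0xFF : 0x00;
    }
}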
@@ -3367,9 +3381,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i bxhi = _mm256_set_epi64x(
-            table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]],
-            table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
         bx = _mm256_or_si256(bx, bxhi);
 
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
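Note (commentary, not part of the patch): in the Q5_0 path the fifth bit is folded in with an and-not, so bxhi is 0xF0 where the bit is clear and 0x00 where it is set; OR-ing that into the low nibble yields the signed value q - 16 directly, the same encoding the removed table_b2b_i lookups produced. A scalar model with illustrative names:

#include <stdint.h>

// Scalar model of the q5_0 high-bit fold: combine a low nibble with its
// fifth bit into the signed value q - 16 that the AVX2 path builds in bx.
static int8_t q5_0_restore(uint8_t lo4, int hi_bit) {
    // hi_bit set:   q = lo4 + 16, value = lo4       (OR with 0x00)
    // hi_bit clear: q = lo4,      value = lo4 - 16  (OR with 0xF0, i.e. sign-extend)
    return (int8_t)(lo4 | (hi_bit ? 0x00 : 0xF0));
}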
@@ -3501,9 +3514,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
-        const __m256i bxhi = _mm256_set_epi64x(
-            table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]],
-            table_b2b_u[x[i].qh[1]], table_b2b_u[x[i].qh[0]]);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
         bx = _mm256_or_si256(bx, bxhi);
 
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
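Note (commentary, not part of the patch): the Q5_1 path is simpler because its quants are unsigned; the fifth bit just contributes +16, so the expanded bits are masked with 0x10 (matching the removed table_b2b_u encoding) and the per-block offset is handled separately through x[i].m. Scalar model, illustrative names:

#include <stdint.h>

// Scalar model of the q5_1 high-bit fold: the fifth bit adds 16,
// giving an unsigned quant in [0, 31].
static uint8_t q5_1_restore(uint8_t lo4, int hi_bit) {
    return (uint8_t)(lo4 | (hi_bit ? 0x10 : 0x00));
}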