@@ -8986,6 +8986,14 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
 #endif
 }
 
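+// Helper for a signed int8 dot product on AVX2: _mm256_maddubs_epi16 needs an
+// unsigned first operand, so take |x| and fold the sign of x into y; the result
+// holds pairwise sums of x[j]*y[j] in 16-bit lanes.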
+#ifdef __AVX2__
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+#endif
+
 void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
@@ -8994,6 +9002,59 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
 
     const int nb = n / QK_K;
 
+#if defined __AVX2__
+
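+    // each packed scale nibble holds a 3-bit block scale (m7) and the 9th bit of
+    // a grid index (m8); shuffle_h interleaves the low/high nibbles of the 8
+    // scale bytes, and shuffle_s[k] broadcasts each scale byte to the four
+    // 16-bit dot lanes (8 quants) it applies to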
+    const __m128i m8 = _mm_set1_epi8(0x08);
+    const __m128i m7 = _mm_set1_epi8(0x07);
+    const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
+    const __m128i shuffle_s[4] = {
+        _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
+        _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
+        _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
+        _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
+    };
+
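+    // the 16-bit grid indices are assembled in a vector register but read back
+    // element-wise through an aliased uint16_t pointer for the table lookups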
+    uint64_t aux64;
+
+    __m256i v_gindex;
+    const uint16_t * gindex = (const uint16_t *)&v_gindex;
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * sc = x[i].scales;
+
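+        // per 128-quant chunk: 16 low index bytes from qs and 8 scale bytes from
+        // sc are rebuilt into 16 9-bit grid indices and 16 3-bit scales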
+        __m256i sumi = _mm256_setzero_si256();
+        for (int i128 = 0; i128 < QK_K/128; ++i128) {
+            const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+            memcpy(&aux64, sc, 8); sc += 8;
+            const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
+            const __m256i hbit = _mm256_cvtepi8_epi16(_mm_and_si128(qh, m8));
+            v_gindex = _mm256_or_si256(_mm256_cvtepi8_epi16(ql), _mm256_slli_epi16(hbit, 5));
+            const __m128i scales = _mm_and_si128(qh, m7);
+
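+            // 4 iterations of 32 quants each: look up four 8-quant grid rows,
+            // take the int8 dot with q8, then scale each 16-bit pair sum via
+            // _mm256_madd_epi16 and accumulate into 32-bit lanes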
+            for (int i32 = 0; i32 < 4; ++i32) {
+                const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
+                                                      iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
+                const __m256i dot = mul_add_epi8(q1b, q8b);
+                const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
+                const __m256i p = _mm256_madd_epi16(s16, dot);
+                sumi = _mm256_add_epi32(sumi, p);
+            }
+
+        }
+
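+        // fold in the combined block scale x.d * y.d and accumulate as fp32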
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
+
+    }
+
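+    // horizontal sum of the 8 float lanes gives the final dot product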
+    *s = hsum_float_8(accum);
+
+#else
+
     int db[4];
     uint16_t idx[4];
 
@@ -9030,6 +9091,8 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
 
     *s = sumf;
 
+#endif
+
 }
 
 // ================================ IQ2 quantization =============================================