@@ -8525,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
     const __m128i m4 = _mm_set1_epi8(0xf);
     const __m128i m1 = _mm_set1_epi8(1);
-    const __m128i m511 = _mm_set1_epi16(511);
-    const __m128i m127 = _mm_set1_epi16(127);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m256i mone = _mm256_set1_epi8(1);
 
-    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
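+    // pshufb LUT: entry i is 0x80 when the 4-bit index i has odd parity,
+    // 0x00 when it is even. Used below to reconstruct the 8th sign bit of
+    // each group, which iq2_xs stores implicitly as the parity of the 7
+    // explicit sign bits (previously handled via the keven_signs_q2xs table).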
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
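+    // Shuffle masks that broadcast each packed sign byte (one byte per group
+    // of 8 quants) across the 8 lanes it controls: mask 1 covers sign bytes
+    // 0/2/4/6 of a 128-bit half of full_sign_bits, mask 2 covers 8/10/12/14.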
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
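+    // One bit per lane: AND-ing the broadcast sign byte with this mask picks
+    // out the single sign bit that belongs to each of the 8 lanes.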
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
 
     uint64_t aux64;
 
     // somewhat hacky, but gives a significant boost in performance
-    __m128i aux_gindex, aux_sindex;
+    __m256i aux_gindex;
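+    // (the vector of extracted grid indices is round-tripped through memory
+    // and read back via the gindex alias as 16 scalar uint16_t table lookups,
+    // which is presumably the "hacky" part)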
     const uint16_t * gindex = (const uint16_t *)&aux_gindex;
-    const uint16_t * sindex = (const uint16_t *)&aux_sindex;
 
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
@@ -8550,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;
+            aux_gindex = _mm256_and_si256(q2_data, m511);
+
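+            // Each 16-bit word of q2_data holds a 9-bit grid index (masked
+            // out above) plus 7 explicit sign bits. (q2 >> 9) ^ (q2 >> 13)
+            // folds the 7 sign bits into a 4-bit pshufb index with the same
+            // parity, so the bit_helper lookup yields the implicit 8th sign
+            // bit as 0x80, completing all 8 sign bits in one byte per word.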
+            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
             const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
             const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
-            const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2);  q2 += 8;
-            aux_gindex = _mm_and_si128(q2_data, m511);
-            aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
-            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
-            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
-            const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
-            const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
-            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
-            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+            const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
+
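+            // Expand the packed sign bytes into per-quant multipliers: the
+            // shuffle broadcasts each sign byte over its 8 lanes, AND/CMPEQ
+            // turns the selected bit into 0xff (flip) or 0x00 (keep), and
+            // OR-ing with mone yields -1/+1 for _mm256_sign_epi8.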
+            __m256i signs;
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
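+            // maddubs multiplies the unsigned grid bytes by the sign-adjusted
+            // signed q8 bytes and pairwise-sums to int16; madd then applies
+            // the per-block int16 scales and widens the sums to int32.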
             const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
             const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
+            const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
 
             const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
             const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
 
             sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
             sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
         }
 
         accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);