@@ -267,6 +267,8 @@ kernel void kernel_mul_mat_q4_0_f32(
267
267
uint2 tptg[[threads_per_threadgroup]]) {
268
268
const int nb = ne00/QK4_0;
269
269
270
+ const int8_t m8 = 8 ;
271
+
270
272
const int64_t r0 = tgpig.x ;
271
273
const int64_t r1 = tgpig.y ;
272
274
@@ -276,33 +278,34 @@ kernel void kernel_mul_mat_q4_0_f32(
276
278
const uint nth = tptg.x *tptg.y ;
277
279
const uint ith = tptg.y *tpitg.x + tpitg.y ;
278
280
279
- sum[ith] = 0 .0f ;
281
+ const int ix = tpitg.y /4 ; // 0 or 1
282
+ const int iy = tpitg.y - 4 *ix; // 0...3
280
283
281
- for ( int i = tpitg. x ; i < nb; i += tptg. x ) {
282
- device const uchar4 * x0p = (device const uchar4 *) (x + i)-> qs ;
283
- device const float4 * y0p = (device const float4 *) (y + i*QK4_0) ;
284
+ const int first = 4 * iy;
285
+
286
+ float sumf = 0 ;
284
287
285
- const float d = ( float )((x + i)-> d );
288
+ for ( int i = 2 *tpitg. x + ix; i < nb; i += 2 *tptg. x ) {
286
289
287
- const uchar4 x0v = *(x0p + tpitg.y );
288
- const float4 y0v = *(y0p + tpitg.y + 0 );
289
- const float4 y1v = *(y0p + tpitg.y + 4 );
290
+ const float d = (float )x[i].d ;
290
291
291
- float acc = 0 .0f ;
292
+ device const uint8_t * xl = x[i].qs + first;
293
+ device const float * yl = y + i * QK4_0 + first;
294
+
295
+ float2 acc = {0 .0f , 0 .0f };
292
296
293
297
for (int j = 0 ; j < 4 ; ++j) {
294
- const int x0 = x0v[j] & 0x0F ;
295
- const int x1 = x0v[j] >> 4 ;
296
298
297
- const float y0 = y0v [j];
298
- const float y1 = y1v [j];
299
+ acc[ 0 ] += yl[j+ 0 ] * (( int8_t )(xl [j] & 0xF ) - m8) ;
300
+ acc[ 1 ] += yl[j+ 16 ] * (( int8_t )(xl [j] >> 4 ) - m8) ;
299
301
300
- acc += (x0 - 8 )*y0 + (x1 - 8 )*y1 ;
301
302
}
302
303
303
- sum[ith ] += acc*d ;
304
+ sumf += d * (acc[ 0 ] + acc[ 1 ]) ;
304
305
}
305
306
307
+ sum[ith] = sumf;
308
+
306
309
//
307
310
// Accumulate the sum from all threads in the threadgroup
308
311
// This version is slightly faster than the commented out one below,
@@ -357,6 +360,7 @@ kernel void kernel_mul_mat_f16_f32(
357
360
uint3 tpig[[thread_position_in_grid]],
358
361
uint3 tpitg[[thread_position_in_threadgroup]],
359
362
uint3 tptg[[threads_per_threadgroup]]) {
363
+
360
364
const int64_t r0 = tgpig.x ;
361
365
const int64_t r1 = tgpig.y ;
362
366
const int64_t im = tgpig.z ;
0 commit comments