 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
 #include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h>
-#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
-#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <cassert>
 #include <cstring>
@@ -257,33 +256,14 @@ size_t inline weight_data_size_impl(
     int weight_nbit,
     bool has_weight_zeros,
     bool has_bias) {
-  assert(k % group_size == 0);
-  int groups_per_col = k / group_size;
-  int col_size = 0;
-
-  // qvals
-  col_size += (k / 8) * weight_nbit;
-
-  // scales
-  col_size += sizeof(float) * groups_per_col;
-
-  // qvals_sum
-  col_size += sizeof(int32_t) * groups_per_col;
-
-  // zeros
-  if (has_weight_zeros) {
-    col_size += sizeof(int32_t) * groups_per_col;
-  }
-
-  // bias
-  if (has_bias) {
-    col_size += sizeof(float);
-  }
-
-  // Replace n with next multiple of 4 >= n
-  n = ((n + 3) / 4) * 4;
-
-  return col_size * n;
+  return torchao::kernels::cpu::aarch64::linear::packing::packed_weights_size(
+      n,
+      k,
+      group_size,
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      /*nr*/ 4);
 }
 
 template <int weight_nbit>
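
The deleted body above spelled out the per-column packed layout: bitpacked qvals, then one float scale and one int32 qvals_sum per group, an optional int32 zero per group, and an optional float bias, with n padded up to the nr=4 tile. For reference, a minimal standalone sketch of that size computation (the function name is illustrative; the new packed_weights_size() helper is assumed to return the same value for nr=4):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Per-column byte count of the packed layout described above.
    size_t packed_size_sketch(
        int n, int k, int group_size, int weight_nbit,
        bool has_weight_zeros, bool has_bias) {
      assert(k % group_size == 0);
      int groups_per_col = k / group_size;
      size_t col_size = 0;
      col_size += (k / 8) * weight_nbit;            // bitpacked qvals
      col_size += sizeof(float) * groups_per_col;   // scales
      col_size += sizeof(int32_t) * groups_per_col; // qvals_sums
      if (has_weight_zeros) {
        col_size += sizeof(int32_t) * groups_per_col; // zeros
      }
      if (has_bias) {
        col_size += sizeof(float); // bias
      }
      n = ((n + 3) / 4) * 4; // pad n to the nr=4 tile
      return col_size * static_cast<size_t>(n);
    }

For example, with n=8, k=256, group_size=64, weight_nbit=4, and both zeros and bias enabled: col_size = 128 + 16 + 16 + 16 + 4 = 180 bytes, so the packed buffer is 180 * 8 = 1440 bytes.
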
@@ -299,125 +279,16 @@ void prepare_weight_data_impl(
     // Ignored if has_weight_zeros = false
     const int8_t* weight_zeros,
     const float* bias) {
-  assert(k % group_size == 0);
-  assert(group_size % 16 == 0);
-
-  bool has_weight_zeros = (weight_zeros != nullptr);
-  bool has_bias = (bias != nullptr);
-
-  int groups_per_k = k / group_size;
-  constexpr int bytes_per_64_weight_values = 8 * weight_nbit;
-
-  auto weight_data_byte_ptr = (char*)weight_data;
-  const int8_t* qvals_ptr = weight_qvals;
-  const float* scales_ptr = weight_scales;
-  const int8_t* zeros_ptr = weight_zeros;
-  const float* bias_ptr = bias;
-
-  int8_t interleaved_buffer[64];
-  int8_t buffer[64];
-
-  for (int n_idx = 0; n_idx < n; n_idx += 4) {
-    for (int k_idx = 0; k_idx < k; k_idx += group_size) {
-      // Loop over group in chunks of 16, processing 4 columns at a time
-      int qvals_sum[4] = {0, 0, 0, 0};
-      for (int i = 0; i < group_size; i += 16) {
-        std::memset(buffer, 0, 64);
-        // Loop over 4 cols
-        #pragma unroll(4)
-        for (int j = 0; j < 4; j++) {
-          if (n_idx + j < n) {
-            // If qvals_ptr are pre-packed in a naive way, this is where
-            // unpacking can occur
-            std::memcpy(buffer + 16 * j, qvals_ptr + k * j, 16);
-            qvals_sum[j] +=
-                torchao::kernels::cpu::aarch64::reduction::compute_sum(
-                    buffer + 16 * j, 16);
-          }
-        }
-        torchao::kernels::cpu::valpacking::interleave_data(
-            /*data_interleaved=*/interleaved_buffer,
-            /*data=*/buffer,
-            /*bytes_per_val=*/1,
-            /*vals_per_channel=*/16,
-            /*vals_per_group=*/16,
-            /*vals_per_chunk=*/8,
-            /*channels=*/4,
-            /*channel_stride_in_vals=*/16);
-        torchao::bitpacking::vec_pack_64_lowbit_values<weight_nbit>(
-            (uint8_t*)weight_data_byte_ptr,
-            vld1q_s8(interleaved_buffer),
-            vld1q_s8(interleaved_buffer + 16),
-            vld1q_s8(interleaved_buffer + 32),
-            vld1q_s8(interleaved_buffer + 48));
-        qvals_ptr += 16;
-        weight_data_byte_ptr += bytes_per_64_weight_values;
-      } // loop over group
-
-      // Store weight scales
-      #pragma unroll(4)
-      for (int j = 0; j < 4; j++) {
-        float32_t scale = 0.0;
-        if (n_idx + j < n) {
-          scale = *(scales_ptr + j * groups_per_k);
-        }
-        *((float*)weight_data_byte_ptr) = scale;
-        weight_data_byte_ptr += sizeof(float);
-      }
-      scales_ptr += 1;
-
-      // Store weight qvals_sum
-      #pragma unroll(4)
-      for (int j = 0; j < 4; j++) {
-        *((int*)weight_data_byte_ptr) = qvals_sum[j];
-        weight_data_byte_ptr += sizeof(int);
-      }
-
-      // Store weight zeros
-      // I went back and forth on how to store weight_zero.
-      // Kernel computation is done in int32, so I'm converting these to
-      // int32 before storing (load 4 int32s in kernel).
-      // In the 1x8 kernel, we may want to store as int16_t, which saves
-      // a load in the kernel (load 8 int16_ts, instead of 2 loads of
-      // 4 int32_ts), but adds 2 moves (int16 to int32).
-      if (has_weight_zeros) {
-        #pragma unroll(4)
-        for (int j = 0; j < 4; j++) {
-          int32_t zero = 0;
-          if (n_idx + j < n) {
-            zero = (int)(*(zeros_ptr + j * groups_per_k));
-          }
-          *((int32_t*)weight_data_byte_ptr) = zero;
-          weight_data_byte_ptr += sizeof(int32_t);
-        }
-        zeros_ptr += 1;
-      }
-    } // k_idx
-    if (has_bias) {
-      #pragma unroll(4)
-      for (int j = 0; j < 4; j++) {
-        float bias_ = 0.0;
-        if (n_idx + j < n) {
-          bias_ = *(bias_ptr + j);
-        }
-        *((float*)weight_data_byte_ptr) = bias_;
-        weight_data_byte_ptr += sizeof(float);
-      }
-      bias_ptr += 1;
-    }
-
-    // In the previous loop over k, we processed 4 columns at a time,
-    // but only advanced our pointers over the first column.
-    // So we advance over the other 3 columns here.
-    qvals_ptr += 3 * k;
-    scales_ptr += 3 * groups_per_k;
-    if (has_weight_zeros) {
-      zeros_ptr += 3 * groups_per_k;
-    }
-    if (has_bias) {
-      bias_ptr += 3;
-    }
-  } // n_idx
+  torchao::kernels::cpu::aarch64::linear::packing::
+      pack_weights<weight_nbit, /*nr*/ 4, /*kr*/ 16, /*sr*/ 2>(
+          weight_data,
+          n,
+          k,
+          group_size,
+          weight_qvals,
+          weight_scales,
+          weight_zeros,
+          bias);
 }
 
 } // namespace
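
The pack_weights<weight_nbit, nr, kr, sr> call folds the deleted interleave-then-bitpack loop into the shared packing library. The deleted valpacking call split each column's 16 qvals into chunks of 8 and round-robined those chunks across the 4 columns of the tile, which appears to correspond to kr=16 and sr=2 in the new template parameters. A minimal sketch of that reordering for one 4x16 tile (the function name is illustrative, not the library's API):

    #include <cstdint>
    #include <cstring>

    // Reorders one nr=4 by kr=16 tile of int8 qvals the way the deleted
    // interleave_data call did: each column's kr values are split into
    // sr=2 chunks of kr/sr=8, and chunks are interleaved across columns.
    void interleave_4x16_sketch(int8_t out[64], const int8_t in[64]) {
      constexpr int nr = 4;          // columns per tile
      constexpr int kr = 16;         // values per column
      constexpr int sr = 2;          // chunks per column
      constexpr int chunk = kr / sr; // 8 values per chunk
      int8_t* dst = out;
      for (int s = 0; s < sr; s++) {   // chunk index within each column
        for (int c = 0; c < nr; c++) { // column index
          std::memcpy(dst, in + c * kr + s * chunk, chunk);
          dst += chunk;
        }
      }
    }

Here `in` holds col0[0..15] col1[0..15] col2[0..15] col3[0..15], and `out` holds col0[0..7] col1[0..7] col2[0..7] col3[0..7] followed by the four [8..15] chunks. That 64-byte interleaved tile is what the removed code loaded as four vld1q_s8 vectors and handed to vec_pack_64_lowbit_values, producing 8 * weight_nbit packed bytes per tile.
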