
Commit 68d785c

Write weight packing/unpacking functions for universal kernels

Differential Revision: D71163696
Pull Request resolved: #1921

1 parent ab3792e · commit 68d785c

7 files changed: +605 -372 lines

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h

Lines changed: 22 additions & 75 deletions
@@ -10,6 +10,7 @@
 
 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
 #include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h>
 #include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
 #include <cassert>
@@ -149,8 +150,8 @@ void kernel_impl(
       int32_t activation_qvals_sum = *((int32_t*)activation_ptr);
       activation_ptr += sizeof(int32_t);
 
-      int8_t weight_zero = *((int8_t*)weight_data_byte_ptr);
-      weight_data_byte_ptr += sizeof(int8_t);
+      int32_t weight_zero = *((int32_t*)weight_data_byte_ptr);
+      weight_data_byte_ptr += sizeof(int32_t);
 
       res += (weight_scale * activation_scale) *
           (qval_dot - (activation_zero * weight_qvals_sum) -
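The only kernel-side change here is the width of the per-group weight zero point: it is now stored and loaded as an int32_t so it drops straight into the kernel's int32 accumulation, which is what the 1x4x16 kernel already did. A minimal sketch of reading that per-group metadata back, assuming the old 1x1x32 ordering (scale, qvals sum, optional zero); the struct and helper are illustrative, not part of the diff:

// Illustrative only: read one group's trailing metadata from the packed
// stream, assuming the order used by the old 1x1x32 layout with the zero
// point widened to int32 as in this commit.
#include <cstdint>
#include <cstring>

struct GroupMeta {
  float scale = 0.0f;
  int32_t qvals_sum = 0;
  int32_t zero = 0;  // int32 so the kernel can use it directly in int32 math
};

inline const char* read_group_meta(const char* p, bool has_weight_zeros, GroupMeta& out) {
  std::memcpy(&out.scale, p, sizeof(float));
  p += sizeof(float);
  std::memcpy(&out.qvals_sum, p, sizeof(int32_t));
  p += sizeof(int32_t);
  if (has_weight_zeros) {
    std::memcpy(&out.zero, p, sizeof(int32_t));
    p += sizeof(int32_t);
  }
  return p;
}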
@@ -190,31 +191,14 @@ size_t inline weight_data_size_impl(
     int weight_nbit,
     bool has_weight_zeros,
     bool has_bias) {
-  assert(k % group_size == 0);
-  assert(k % 32 == 0);
-  int groups_per_col = k / group_size;
-  int col_size = 0;
-
-  // qvals
-  // (k * weight_bit) bits -> ((k / 8) * weight_bit) bytes
-  col_size += (k / 8) * weight_nbit;
-
-  // scales
-  col_size += sizeof(float) * groups_per_col;
-
-  // qvals_sum
-  col_size += sizeof(int32_t) * groups_per_col;
-
-  // zeros
-  if (has_weight_zeros) {
-    col_size += sizeof(int8_t) * groups_per_col;
-  }
-
-  if (has_bias) {
-    col_size += sizeof(float);
-  }
-
-  return col_size * n;
+  return torchao::kernels::cpu::aarch64::linear::packing::packed_weights_size(
+      n,
+      k,
+      group_size,
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      /*nr*/ 1);
 }
 
 template <int weight_nbit>
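The body that replaces the hand-rolled arithmetic lives in the new pack_weights.h, which is not shown in this diff, but the removed code from both kernels suggests what packed_weights_size presumably computes: per-column qval bits plus per-group scale, qvals sum, and optional zero point, plus an optional bias, with n padded up to a multiple of nr. A hedged reconstruction of that formula (assuming the universal layout stores zero points as int32, in line with the kernel change above):

// Sketch of the packed-size arithmetic implied by the removed code; the
// actual implementation is in pack_weights.h and may differ in detail.
#include <cstddef>
#include <cstdint>

inline size_t packed_weights_size_sketch(int n, int k, int group_size,
                                         int weight_nbit, bool has_weight_zeros,
                                         bool has_bias, int nr) {
  int groups_per_col = k / group_size;
  size_t col_size = 0;
  col_size += (k / 8) * weight_nbit;               // bit-packed qvals
  col_size += sizeof(float) * groups_per_col;      // scales
  col_size += sizeof(int32_t) * groups_per_col;    // qvals sums
  if (has_weight_zeros) {
    col_size += sizeof(int32_t) * groups_per_col;  // zero points (int32)
  }
  if (has_bias) {
    col_size += sizeof(float);                     // bias
  }
  int n_padded = ((n + nr - 1) / nr) * nr;         // pad n to a multiple of nr
  return col_size * n_padded;
}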
@@ -227,56 +211,19 @@ void prepare_weight_data_impl(
     int group_size,
     const int8_t* weight_qvals,
     const float* weight_scales,
+    // Ignored if has_weight_zeros = false
     const int8_t* weight_zeros,
     const float* bias) {
-  assert(k % group_size == 0);
-  assert(group_size % 32 == 0);
-
-  bool has_weight_zeros = (weight_zeros != nullptr);
-  bool has_bias = (bias != nullptr);
-
-  auto weight_data_byte_ptr = (char*)weight_data;
-  constexpr int bytes_per_32_weight_values = 4 * weight_nbit;
-
-  int8x16_t wq0, wq1;
-
-  const int8_t* qvals_ptr = weight_qvals;
-  const float* scales_ptr = weight_scales;
-  const int8_t* zeros_ptr = weight_zeros;
-  const float* bias_ptr = bias;
-
-  for (int n_idx = 0; n_idx < n; n_idx++) {
-    for (int k_idx = 0; k_idx < k; k_idx += group_size) {
-      int32_t group_qvals_sum = 0;
-      for (int i = 0; i < group_size; i += 32) {
-        wq0 = vld1q_s8(qvals_ptr);
-        wq1 = vld1q_s8(qvals_ptr + 16);
-        qvals_ptr += 32;
-
-        group_qvals_sum += vaddlvq_s8(wq0) + vaddlvq_s8(wq1);
-
-        torchao::bitpacking::vec_pack_32_lowbit_values<weight_nbit>(
-            /*packed=*/(uint8_t*)weight_data_byte_ptr,
-            /*unpacked0=*/wq0,
-            /*unpacked1=*/wq1);
-        weight_data_byte_ptr += bytes_per_32_weight_values;
-      }
-      *((float*)weight_data_byte_ptr) = *scales_ptr++;
-      weight_data_byte_ptr += sizeof(float);
-
-      *((int32_t*)weight_data_byte_ptr) = group_qvals_sum;
-      weight_data_byte_ptr += sizeof(int32_t);
-
-      if (has_weight_zeros) {
-        *((int8_t*)weight_data_byte_ptr) = *zeros_ptr++;
-        weight_data_byte_ptr += sizeof(int8_t);
-      }
-    }
-    if (has_bias) {
-      *((float*)weight_data_byte_ptr) = *bias_ptr++;
-      weight_data_byte_ptr += sizeof(float);
-    }
-  }
+  torchao::kernels::cpu::aarch64::linear::packing::
+      pack_weights<weight_nbit, /*nr*/ 1, /*kr*/ 32, /*sr*/ 2>(
+          weight_data,
+          n,
+          k,
+          group_size,
+          weight_qvals,
+          weight_scales,
+          weight_zeros,
+          bias);
 }
 
 } // namespace
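Taken together, the two call sites above show the shape of the new universal entry points: packed_weights_size(n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr) to size the buffer, and pack_weights<weight_nbit, nr, kr, sr>(...) to fill it. A usage sketch under those signatures, with the buffer handling and function name being illustrative assumptions rather than part of this commit:

// Illustrative usage of the universal packing entry points, based only on
// the call sites visible in this diff.
#include <cstddef>
#include <cstdint>
#include <vector>
#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>

template <int weight_nbit>
void pack_example(
    int n, int k, int group_size,
    const int8_t* weight_qvals, const float* weight_scales,
    const int8_t* weight_zeros,  // may be nullptr
    const float* bias) {         // may be nullptr
  namespace packing = torchao::kernels::cpu::aarch64::linear::packing;
  // Size the packed buffer for the 1x1x32 kernel (nr = 1).
  size_t size = packing::packed_weights_size(
      n, k, group_size, weight_nbit,
      /*has_weight_zeros=*/weight_zeros != nullptr,
      /*has_bias=*/bias != nullptr,
      /*nr*/ 1);
  std::vector<char> weight_data(size);
  // Pack qvals, scales, qvals sums, optional zeros, and optional bias.
  packing::pack_weights<weight_nbit, /*nr*/ 1, /*kr*/ 32, /*sr*/ 2>(
      weight_data.data(), n, k, group_size,
      weight_qvals, weight_scales, weight_zeros, bias);
}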

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h

Lines changed: 19 additions & 148 deletions
@@ -10,8 +10,7 @@
 
 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
 #include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h>
-#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
-#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <cassert>
 #include <cstring>
 
@@ -257,33 +256,14 @@ size_t inline weight_data_size_impl(
     int weight_nbit,
     bool has_weight_zeros,
     bool has_bias) {
-  assert(k % group_size == 0);
-  int groups_per_col = k / group_size;
-  int col_size = 0;
-
-  // qvals
-  col_size += (k / 8) * weight_nbit;
-
-  // scales
-  col_size += sizeof(float) * groups_per_col;
-
-  // qvals_sum
-  col_size += sizeof(int32_t) * groups_per_col;
-
-  // zeros
-  if (has_weight_zeros) {
-    col_size += sizeof(int32_t) * groups_per_col;
-  }
-
-  // bias
-  if (has_bias) {
-    col_size += sizeof(float);
-  }
-
-  // Replace n with next multiple of 4 >= n
-  n = ((n + 3) / 4) * 4;
-
-  return col_size * n;
+  return torchao::kernels::cpu::aarch64::linear::packing::packed_weights_size(
+      n,
+      k,
+      group_size,
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      /*nr*/ 4);
 }
 
 template <int weight_nbit>
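As a worked example of the nr = 4 padding the removed code made explicit: with n = 6, k = 256, group_size = 64, weight_nbit = 4, and both weight zeros and bias present, each column takes (256 / 8) * 4 = 128 bytes of packed qvals, plus 4 groups * (4 + 4 + 4) = 48 bytes of scales, qvals sums, and zero points, plus 4 bytes of bias, i.e. 180 bytes per column; n is padded from 6 up to 8 columns, so the packed buffer is 8 * 180 = 1440 bytes (assuming the new packed_weights_size keeps the same accounting).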
@@ -299,125 +279,16 @@ void prepare_weight_data_impl(
     // Ignored if has_weight_zeros = false
     const int8_t* weight_zeros,
     const float* bias) {
-  assert(k % group_size == 0);
-  assert(group_size % 16 == 0);
-
-  bool has_weight_zeros = (weight_zeros != nullptr);
-  bool has_bias = (bias != nullptr);
-
-  int groups_per_k = k / group_size;
-  constexpr int bytes_per_64_weight_values = 8 * weight_nbit;
-
-  auto weight_data_byte_ptr = (char*)weight_data;
-  const int8_t* qvals_ptr = weight_qvals;
-  const float* scales_ptr = weight_scales;
-  const int8_t* zeros_ptr = weight_zeros;
-  const float* bias_ptr = bias;
-
-  int8_t interleaved_buffer[64];
-  int8_t buffer[64];
-
-  for (int n_idx = 0; n_idx < n; n_idx += 4) {
-    for (int k_idx = 0; k_idx < k; k_idx += group_size) {
-      // Loop over group in chunks of 16, processing 4 columns at at time
-      int qvals_sum[4] = {0, 0, 0, 0};
-      for (int i = 0; i < group_size; i += 16) {
-        std::memset(buffer, 0, 64);
-        // Loop over 4 cols
-        #pragma unroll(4)
-        for (int j = 0; j < 4; j++) {
-          if (n_idx + j < n) {
-            // If qvals_ptr are pre-packed in a naive way, this is where
-            // unpacking can occur
-            std::memcpy(buffer + 16 * j, qvals_ptr + k * j, 16);
-            qvals_sum[j] +=
-                torchao::kernels::cpu::aarch64::reduction::compute_sum(
-                    buffer + 16 * j, 16);
-          }
-        }
-        torchao::kernels::cpu::valpacking::interleave_data(
-            /*data_interleaved=*/interleaved_buffer,
-            /*data=*/buffer,
-            /*bytes_per_val=*/1,
-            /*vals_per_channel=*/16,
-            /*vals_per_group=*/16,
-            /*vals_per_chunk=*/8,
-            /*channels=*/4,
-            /*channel_stride_in_vals=*/16);
-        torchao::bitpacking::vec_pack_64_lowbit_values<weight_nbit>(
-            (uint8_t*)weight_data_byte_ptr,
-            vld1q_s8(interleaved_buffer),
-            vld1q_s8(interleaved_buffer + 16),
-            vld1q_s8(interleaved_buffer + 32),
-            vld1q_s8(interleaved_buffer + 48));
-        qvals_ptr += 16;
-        weight_data_byte_ptr += bytes_per_64_weight_values;
-      } // loop over group
-
-      // Store weight scales
-      #pragma unroll(4)
-      for (int j = 0; j < 4; j++) {
-        float32_t scale = 0.0;
-        if (n_idx + j < n) {
-          scale = *(scales_ptr + j * groups_per_k);
-        }
-        *((float*)weight_data_byte_ptr) = scale;
-        weight_data_byte_ptr += sizeof(float);
-      }
-      scales_ptr += 1;
-
-      // Store weight qvals_sum
-      #pragma unroll(4)
-      for (int j = 0; j < 4; j++) {
-        *((int*)weight_data_byte_ptr) = qvals_sum[j];
-        weight_data_byte_ptr += sizeof(int);
-      }
-
-      // Store weight zeros
-      // I went back and forth on how to store weight_zero.
-      // Kernel computation is done in int32, so I'm converting these to
-      // int32 before storing (load 4 int32s in kernel).
-      // In the 1x8 kernel, we may want to store as int16_t, which reduces
-      // a load in the kernel (load 8 int16_ts in kernel, instead of 2
-      // load 4 int32_ts), but adds 2 moves (int16 to int32).
-      if (has_weight_zeros) {
-        #pragma unroll(4)
-        for (int j = 0; j < 4; j++) {
-          int32_t zero = 0;
-          if (n_idx + j < n) {
-            zero = (int)(*(zeros_ptr + j * groups_per_k));
-          }
-          *((int32_t*)weight_data_byte_ptr) = zero;
-          weight_data_byte_ptr += sizeof(int32_t);
-        }
-        zeros_ptr += 1;
-      }
-    } // k_idx
-    if (has_bias) {
-      #pragma unroll(4)
-      for (int j = 0; j < 4; j++) {
-        float bias_ = 0.0;
-        if (n_idx + j < n) {
-          bias_ = *(bias_ptr + j);
-        }
-        *((float*)weight_data_byte_ptr) = bias_;
-        weight_data_byte_ptr += sizeof(float);
-      }
-      bias_ptr += 1;
-    }
-
-    // In the previous loop over k, we processed 4 columns at a time,
-    // but only advanced our pointers over the first column.
-    // So we advance over the other 3 columns here.
-    qvals_ptr += 3 * k;
-    scales_ptr += 3 * groups_per_k;
-    if (has_weight_zeros) {
-      zeros_ptr += 3 * groups_per_k;
-    }
-    if (has_bias) {
-      bias_ptr += 3;
-    }
-  } // n_idx
+  torchao::kernels::cpu::aarch64::linear::packing::
+      pack_weights<weight_nbit, /*nr*/ 4, /*kr*/ 16, /*sr*/ 2>(
+          weight_data,
+          n,
+          k,
+          group_size,
+          weight_qvals,
+          weight_scales,
+          weight_zeros,
+          bias);
 }
 
 } // namespace
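The removed loop also documents the layout the 1x4x16 kernel expects: within each group, 16 qvals per column are interleaved across the 4 columns in chunks of 8 (the valpacking call with vals_per_chunk = 8 and channels = 4), which plausibly corresponds to the nr = 4, kr = 16, sr = 2 template parameters of the new pack_weights. A small sketch of that interleaving order, for illustration only; the real implementation is the NEON path in pack_weights.h / valpack.h:

// Illustrative only: interleave nr columns of kr qvals each in sr chunks,
// i.e. out = col0[0..kr/sr), col1[0..kr/sr), ..., col0[kr/sr..2*kr/sr), ...
#include <cstdint>
#include <vector>

std::vector<int8_t> interleave_columns(const std::vector<std::vector<int8_t>>& cols,
                                       int nr, int kr, int sr) {
  const int chunk = kr / sr;  // 8 values per chunk for kr = 16, sr = 2
  std::vector<int8_t> out;
  out.reserve(static_cast<size_t>(nr) * kr);
  for (int c = 0; c < sr; c++) {
    for (int j = 0; j < nr; j++) {
      for (int i = 0; i < chunk; i++) {
        out.push_back(cols[j][c * chunk + i]);
      }
    }
  }
  return out;
}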
