
Commit 5e4d50c

Reintroduce has_weight_zeros as a template param

Differential Revision: D71503133
Pull Request resolved: #1991
1 parent 70fc520

Summary: this commit turns has_weight_zeros from a runtime bool argument into a compile-time template parameter on the aarch64 channelwise-8-bit-activation / groupwise-lowbit-weight linear kernels, so the weight-zero-point correction can be compiled in or out with if constexpr. To keep runtime dispatch working, check_format changes from a template into an inline function that takes weight_nbit as an argument, and the universal kernel registration branches on format.has_weight_zeros to pick the matching instantiation.

File tree: 8 files changed, +92 −64 lines


torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h

Lines changed: 4 additions & 4 deletions
@@ -245,7 +245,7 @@ void kernel_1x4x16_f32_neondot(
       has_clamp);
 }
 
-template <int weight_nbit, bool has_lut>
+template <int weight_nbit, bool has_weight_zeros, bool has_lut>
 void kernel_1x8x16_f32_neondot(
     // Outputs
     float32_t* output,
@@ -260,10 +260,11 @@ void kernel_1x8x16_f32_neondot(
     // Ignored if has_clamp = false
     float clamp_min,
     float clamp_max,
-    bool has_weight_zeros,
+    bool has_weight_zeros_,
    bool has_bias,
    bool has_clamp) {
-  kernel::kernel_1x8x16_f32_neondot<weight_nbit, has_lut>(
+  (void)has_weight_zeros_; // unused
+  kernel::kernel_1x8x16_f32_neondot<weight_nbit, has_weight_zeros, has_lut>(
       output,
       output_m_stride,
       m,
@@ -274,7 +275,6 @@ void kernel_1x8x16_f32_neondot(
       packed_activations,
       clamp_min,
       clamp_max,
-      has_weight_zeros,
       has_bias,
       has_clamp);
 }
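The wrapper above keeps the runtime bool in its signature so existing call sites and function-pointer tables still compile, while behavior now comes from the template parameter; the (void) cast silences the unused-parameter warning. A minimal sketch of the pattern, with simplified names that are not the torchao code:

#include <cstdio>

namespace kernel {
// Stand-in for the real kernel: behavior is fixed at compile time.
template <bool has_weight_zeros>
void impl() {
  std::printf("kernel instantiated with has_weight_zeros=%d\n",
              has_weight_zeros);
}
} // namespace kernel

// The wrapper keeps the old runtime parameter (now trailing-underscored)
// so its signature stays unchanged, but ignores it; the template argument
// selects the instantiation.
template <bool has_weight_zeros>
void wrapper(bool has_weight_zeros_) {
  (void)has_weight_zeros_; // unused, kept for signature compatibility
  kernel::impl<has_weight_zeros>();
}

int main() {
  wrapper<true>(true);
  wrapper<false>(false);
  return 0;
}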

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h

Lines changed: 2 additions & 3 deletions
@@ -58,7 +58,7 @@ vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) {
 // Roughly inspired by
 // https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads
 
-template <int weight_nbit, bool has_lut>
+template <int weight_nbit, bool has_weight_zeros, bool has_lut>
 void kernel_1x8x16_f32_neondot(
     // Outputs
     float32_t* output,
@@ -73,7 +73,6 @@ void kernel_1x8x16_f32_neondot(
     // Ignored if has_clamp is false
     float clamp_min,
     float clamp_max,
-    bool has_weight_zeros,
     bool has_bias,
     bool has_clamp) {
   assert(k % group_size == 0);
@@ -267,7 +266,7 @@ void kernel_1x8x16_f32_neondot(
 
   int32x4_t term1_4567 = vmulq_n_s32(weight_qvals_sum, activation_zero);
 
-  if (has_weight_zeros) {
+  if constexpr (has_weight_zeros) {
     // Compute term2 and term3
 
     int32_t activation_qvals_sum = *((int32_t*)activation_ptr);
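Switching if (has_weight_zeros) to if constexpr (has_weight_zeros) is the payoff of the template parameter: in the <false> instantiation the zero-point correction branch is discarded at compile time instead of being tested inside the inner loop. A self-contained sketch of the effect, not the kernel itself:

#include <cstdio>

// term1 mirrors the always-present accumulation; the correction only
// exists in instantiations compiled with has_weight_zeros = true.
template <bool has_weight_zeros>
int accumulate(int term1, int correction) {
  int acc = term1;
  if constexpr (has_weight_zeros) {
    // Dead code in accumulate<false>: the compiler never emits this branch.
    acc -= correction;
  }
  return acc;
}

int main() {
  std::printf("%d\n", accumulate<false>(10, 3)); // prints 10
  std::printf("%d\n", accumulate<true>(10, 3));  // prints 7
  return 0;
}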

torchao/experimental/kernels/cpu/aarch64/linear/linear.h

Lines changed: 4 additions & 3 deletions
@@ -320,7 +320,7 @@ void prepare_weight_data(
       bias);
 }
 
-template <int weight_nbit>
+template <int weight_nbit, bool has_weight_zeros>
 void kernel(
     // Outputs
     float32_t* output,
@@ -335,12 +335,13 @@ void kernel(
     // Ignored if has_clamp = false
     float clamp_min,
     float clamp_max,
-    bool has_weight_zeros,
+    bool has_weight_zeros_,
     bool has_bias,
     bool has_clamp) {
+  (void)has_weight_zeros_; // unused
   torchao::kernels::cpu::aarch64::linear::
       channelwise_8bit_activation_groupwise_lowbit_weight::
-          kernel_1x8x16_f32_neondot<weight_nbit, /*has_lut*/ false>(
+          kernel_1x8x16_f32_neondot<weight_nbit, has_weight_zeros, /*has_lut*/ false>(
               output,
               output_m_stride,
               m,

torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp

Lines changed: 24 additions & 20 deletions
@@ -311,7 +311,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot
       bias_ptr);
 
   std::vector<float> output(m * n);
-  kernel<weight_nbit>(
+  kernel<weight_nbit, has_weight_zeros>(
       output.data(),
       /*output_m_stride=*/n,
       m,
@@ -388,13 +388,12 @@ TEST(
   }
 }
 
-template <int weight_nbit>
+template <int weight_nbit, bool has_weight_zeros>
 void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut(
     int m,
     int k,
     int n,
     int group_size,
-    bool has_weight_zeros,
     bool has_bias,
     bool has_clamp) {
   constexpr int mr = 1;
@@ -453,7 +452,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut(
       has_bias ? test_case.bias.data() : nullptr);
 
   std::vector<float> output(m * n);
-  kernel_1x8x16_f32_neondot<weight_nbit, /*has_lut*/ true>(
+  kernel_1x8x16_f32_neondot<weight_nbit, has_weight_zeros, /*has_lut*/ true>(
       output.data(),
       /*output_m_stride=*/n,
       m,
@@ -476,85 +475,90 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut(
 TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, LUT) {
   constexpr int weight_nbit = 4;
 
-  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<weight_nbit>(
+  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
+      weight_nbit,
+      /*has_weight_zeros*/ false>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/false,
       /*has_bias=*/false,
       /*has_clamp=*/false);
 
   // has_weight_zeros
-  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<weight_nbit>(
+  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
+      weight_nbit,
+      /*has_weight_zeros*/ true>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/true,
       /*has_bias=*/false,
       /*has_clamp=*/false);
 
   // has_bias
-  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<weight_nbit>(
+  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
+      weight_nbit,
+      /*has_weight_zeros=*/false>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/false,
       /*has_bias=*/true,
       /*has_clamp=*/false);
 
   // has_clamp
-  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<weight_nbit>(
+  test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
+      weight_nbit,
+      /*has_weight_zeros*/ false>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/false,
       /*has_bias=*/false,
       /*has_clamp=*/true);
 
   // n less than 8 (nr)
   for (int n = 1; n < 8; n++) {
-    test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<weight_nbit>(
+    test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
+        weight_nbit,
+        /*has_weight_zeros=*/false>(
        /*m=*/7,
        /*k=*/64,
        /*n=*/n,
        /*group_size=*/16,
-        /*has_weight_zeros=*/false,
        /*has_bias=*/false,
        /*has_clamp=*/false);
   }
 
   // Other bitwidths
   test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
-      /*weight_nbit*/ 1>(
+      /*weight_nbit*/ 1,
+      /*has_weight_zeros=*/false>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/false,
       /*has_bias=*/false,
       /*has_clamp=*/false);
 
   test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
-      /*weight_nbit*/ 2>(
+      /*weight_nbit*/ 2,
+      /*has_weight_zeros=*/false>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/false,
       /*has_bias=*/false,
       /*has_clamp=*/false);
 
   test_channelwise_8bit_activation_groupwise_lowbit_weight_lut<
-      /*weight_nbit*/ 3>(
+      /*weight_nbit*/ 3,
+      /*has_weight_zeros=*/false>(
       /*m=*/7,
       /*k=*/64,
       /*n=*/13,
       /*group_size=*/16,
-      /*has_weight_zeros=*/false,
       /*has_bias=*/false,
       /*has_clamp=*/false);
 }
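Because has_weight_zeros is now a template argument, each test case above names its value at compile time. If a test ever needs to drive the flag from a runtime value instead, it has to dispatch to the two instantiations explicitly; a sketch with a hypothetical helper, not part of this PR:

#include <cstdio>

// Stand-in for one templated test case.
template <bool has_weight_zeros>
void run_case(int m, int k, int n) {
  std::printf("m=%d k=%d n=%d has_weight_zeros=%d\n",
              m, k, n, has_weight_zeros);
}

// Hypothetical bridge from a runtime flag to the two instantiations.
void run_case(bool has_weight_zeros, int m, int k, int n) {
  if (has_weight_zeros) {
    run_case<true>(m, k, n);
  } else {
    run_case<false>(m, k, n);
  }
}

int main() {
  run_case(/*has_weight_zeros=*/true, 7, 64, 13);
  run_case(/*has_weight_zeros=*/false, 7, 64, 13);
  return 0;
}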

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 5 additions & 8 deletions
@@ -253,9 +253,11 @@ Tensor shared_embedding_out_cpu(
       torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr());
   auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat::
       from_packed_weights_header(header);
-  torchao::ops::linear_8bit_act_xbit_weight::check_format<weight_nbit>(
+
+  torchao::ops::linear_8bit_act_xbit_weight::check_format(
       format,
-      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal);
+      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
+      weight_nbit);
   constexpr int nr = 8;
   constexpr int kr = 16;
   constexpr int sr = 2;
@@ -316,12 +318,7 @@ Tensor shared_embedding_cpu(
     const Tensor& indices) {
   Tensor output_tensor = torch::empty({}, torch::kFloat32);
   shared_embedding_out_cpu<weight_nbit>(
-      packed_weights,
-      group_size,
-      n,
-      k,
-      indices,
-      output_tensor);
+      packed_weights, group_size, n, k, indices, output_tensor);
   return output_tensor;
 }
 #endif // USE_ATEN

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 49 additions & 22 deletions
@@ -89,35 +89,62 @@ void register_ukernel_config_universal(
   if (!cpuinfo_initialize()) {
     throw std::runtime_error("Failed to initialize cpuinfo!");
   }
-  check_format<weight_nbit>(
+
+  check_format(
       format,
-      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal);
+      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
+      weight_nbit);
 
   if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
 #if defined(TORCHAO_BUILD_CPU_AARCH64)
     if (cpuinfo_has_arm_neon_dot()) {
       log_registration(format, "universal");
       namespace kernel = torchao::kernels::cpu::aarch64::linear::
           channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
-      table.register_ukernel_config(
-          format,
-          uarch,
-          UKernelConfig{
-              /*preferred_alignment*/ 16,
-              /*nr*/ 8,
-              /*weight_packing_config*/
-              {/*weight_data_size_fn*/
-               &kernel::weight_data_size<weight_nbit>,
-               /*prepare_weight_data_fn*/
-               &kernel::prepare_weight_data<weight_nbit>},
-              /*linear_configs*/
-              {{{/*mr*/ 1,
-                 /*activation_data_size_fn*/
-                 &kernel::activation_data_size,
-                 /*prepare_activation_data_fn*/
-                 &kernel::prepare_activation_data,
-                 /*kernel*/
-                 &kernel::kernel<weight_nbit>}}}});
+
+      if (format.has_weight_zeros) {
+        constexpr bool has_weight_zeros = true;
+        table.register_ukernel_config(
+            format,
+            uarch,
+            UKernelConfig{
+                /*preferred_alignment*/ 16,
+                /*nr*/ 8,
+                /*weight_packing_config*/
+                {/*weight_data_size_fn*/
+                 &kernel::weight_data_size<weight_nbit>,
+                 /*prepare_weight_data_fn*/
+                 &kernel::prepare_weight_data<weight_nbit>},
+                /*linear_configs*/
+                {{{/*mr*/ 1,
+                   /*activation_data_size_fn*/
+                   &kernel::activation_data_size,
+                   /*prepare_activation_data_fn*/
+                   &kernel::prepare_activation_data,
+                   /*kernel*/
+                   &kernel::kernel<weight_nbit, has_weight_zeros>}}}});
+      } else {
+        constexpr bool has_weight_zeros = false;
+        table.register_ukernel_config(
+            format,
+            uarch,
+            UKernelConfig{
+                /*preferred_alignment*/ 16,
+                /*nr*/ 8,
+                /*weight_packing_config*/
+                {/*weight_data_size_fn*/
+                 &kernel::weight_data_size<weight_nbit>,
+                 /*prepare_weight_data_fn*/
+                 &kernel::prepare_weight_data<weight_nbit>},
+                /*linear_configs*/
+                {{{/*mr*/ 1,
+                   /*activation_data_size_fn*/
+                   &kernel::activation_data_size,
+                   /*prepare_activation_data_fn*/
+                   &kernel::prepare_activation_data,
+                   /*kernel*/
+                   &kernel::kernel<weight_nbit, has_weight_zeros>}}}});
+      }
       return;
     }
 #endif // TORCHAO_BUILD_CPU_AARCH64
@@ -166,7 +193,7 @@ void register_ukernel_config_kleidi(
   if (!cpuinfo_initialize()) {
     throw std::runtime_error("Failed to initialize cpuinfo!");
   }
-  check_format<weight_nbit>(format, torchao::ops::PackedWeightsType::kleidi_ai);
+  check_format(format, torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit);
   namespace op = torchao::kernels::cpu::aarch64::kleidi::
       kai_matmul_clamp_f32_qai8dxp_qsi4c32p;
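The two registration branches above differ only in the constexpr bool, so the UKernelConfig body is written twice. One way the duplication could be folded (a sketch with stand-in types, not a change made in this PR) is a templated factory instantiated once per flag value and selected by the runtime if/else:

#include <cstdio>

struct Config {
  void (*kernel)(); // stand-in for the real UKernelConfig fields
};

template <bool has_weight_zeros>
void kernel_fn() {
  std::printf("registered kernel<has_weight_zeros=%d>\n", has_weight_zeros);
}

// The config body lives in one place; the template parameter picks the
// kernel instantiation that goes into the table.
template <bool has_weight_zeros>
Config make_universal_config() {
  return Config{&kernel_fn<has_weight_zeros>};
}

void register_universal(bool format_has_weight_zeros) {
  Config c = format_has_weight_zeros ? make_universal_config<true>()
                                     : make_universal_config<false>();
  c.kernel(); // stand-in for table.register_ukernel_config(format, uarch, c)
}

int main() {
  register_universal(true);
  register_universal(false);
  return 0;
}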

torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h

Lines changed: 3 additions & 3 deletions
@@ -53,10 +53,10 @@ struct PackedWeightsFormat {
   }
 };
 
-template <int weight_nbit>
-void check_format(
+inline void check_format(
     PackedWeightsFormat format,
-    torchao::ops::PackedWeightsType type) {
+    torchao::ops::PackedWeightsType type,
+    int weight_nbit) {
   if (format.type != type) {
     throw std::runtime_error(
         "Kernel expects packed_weights type=" +

torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ UKernelConfig get_ukernel_config() {
                  /*prepare_activation_data_fn*/
                  &kernel::prepare_activation_data,
                  /*kernel*/
-                 &kernel::kernel<weight_nbit>}}}};
+                 &kernel::kernel<weight_nbit, has_weight_zeros>}}}};
 }
 
 template <
