
Commit 77c4ef1

metascroy authored and jainapurva committed

Clean up op interface

Differential Revision: D72179480
Pull Request resolved: #1998

1 parent acc3c79 commit 77c4ef1

File tree

10 files changed
+1006 -1429 lines changed

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h

Lines changed: 72 additions & 23 deletions
@@ -60,27 +60,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
 using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
-template <int mr, int kr, int sr>
-size_t
-activation_data_size(int m, int k, int group_size, bool has_weight_zeros) {
+size_t packed_activations_size(
+    int m,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
   (void)group_size; // unused
   (void)has_weight_zeros; // unused
   auto lhs_packing = get_lhs_packing();
   return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr);
 }
 
-template <int mr, int kr, int sr>
-void prepare_activation_data(
-    void* activation_data,
+size_t packed_activations_offset(
+    int m_idx,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
+  (void)group_size; // unused
+  (void)has_weight_zeros; // unused
+  auto lhs_pack = get_lhs_packing();
+  return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr);
+}
+
+void pack_activations(
+    void* packed_activations,
     int m,
     int k,
     int group_size,
     const float* activations,
-    bool has_weight_zeros) {
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
   (void)group_size; // unused
   (void)has_weight_zeros; // unused
   auto lhs_pack = get_lhs_packing();
-
   lhs_pack.run_lhs_pack(
       m,
       k,
@@ -90,33 +110,62 @@ void prepare_activation_data(
       /*m_index_start=*/0,
       activations,
       /*lhs_stride=*/k * sizeof(float),
-      activation_data);
+      packed_activations);
 }
 
-template <int nr, int kr, int sr>
-size_t weight_data_size(
+size_t packed_weights_size(
     int n,
     int k,
     int group_size,
+    int weight_nbit,
     bool has_weight_zeros,
-    bool has_bias) {
+    bool has_bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)weight_nbit; // unused
   (void)has_weight_zeros; // unused
   (void)has_bias; // unused
   auto rhs_pack = get_rhs_packing();
   return rhs_pack.get_rhs_packed_size(
-      n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
+      internal::adjust_n(n),
+      k,
+      nr,
+      kr,
+      sr,
+      group_size,
+      kai_datatype::kai_dt_bf16);
+}
+
+size_t packed_weights_offset(
+    int n_idx,
+    int k,
+    int group_size,
+    int weight_nbit,
+    bool has_weight_zeros,
+    bool has_bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)has_weight_zeros; // unused
+  (void)has_bias; // unused
+  auto rhs_pack = get_rhs_packing();
+  return rhs_pack.get_rhs_packed_offset(
+      n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
 }
 
-template <int nr, int kr, int sr>
-void prepare_weight_data(
-    void* weight_data,
+void pack_weights(
+    void* packed_weights,
     int n,
     int k,
     int group_size,
     const int8_t* weight_qvals,
     const float* weight_scales,
     const int8_t* weight_zeros,
-    const float* bias) {
+    const float* bias,
+    int nr,
+    int kr,
+    int sr) {
   if (group_size % 32 != 0) {
     throw std::runtime_error(
         "Group size must be a multiple of 32, but got group_size=" +
@@ -187,7 +236,7 @@ void prepare_weight_data(
       reinterpret_cast<const uint16_t*>(weight_scales_bf16_padded.data()),
       /*scale_stride=*/sizeof(uint16_t) *
           (internal::roundup(k, group_size) / group_size),
-      /*rhs_packed=*/weight_data,
+      /*rhs_packed=*/packed_weights,
       /*extra_bytes=*/0,
       /*qparams=*/&qparams);
 }
@@ -220,8 +269,8 @@ size_t get_preferred_alignement() {
       int n, \
       int k, \
       int group_size, \
-      const void* weight_data, \
-      const void* activation_data, \
+      const void* packed_weights, \
+      const void* packed_activations, \
       float clamp_min, \
       float clamp_max, \
       bool has_weight_zeros, \
@@ -235,11 +284,11 @@ size_t get_preferred_alignement() {
     } \
     get_ukernel().run_matmul( \
         m, \
-        internal::adjust_n(n), \
+        n, \
        k, \
        group_size, \
-        activation_data, \
-        weight_data, \
+        packed_activations, \
+        packed_weights, \
        output, \
        /*dst_stride_row=*/output_m_stride * sizeof(float), \
        /*dst_stride_col=*/sizeof(float), \
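
For context, a minimal caller sketch of the reworked interface in this header, where mr/kr/sr are now plain runtime arguments instead of template parameters. The enclosing namespaces, tile values (mr, kr, sr), and problem sizes below are illustrative assumptions, not taken from this commit:

#include <cstddef>
#include <vector>

// Hypothetical caller (sketch only): enclosing namespaces and the
// tile shape (mr, kr, sr) are assumptions for illustration.
using namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p;  // outer namespaces omitted/assumed

void example_pack_activations(const float* activations, int m, int k) {
  int group_size = 32;           // accepted but unused by this packing routine
  bool has_weight_zeros = false;
  int mr = 4, kr = 16, sr = 2;   // formerly template parameters, now runtime args

  // Size query and packing now share the same runtime-parameter signature.
  std::vector<char> packed(
      packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr));
  pack_activations(
      packed.data(), m, k, group_size, activations,
      has_weight_zeros, mr, kr, sr);
}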

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h

Lines changed: 18 additions & 6 deletions
@@ -49,15 +49,21 @@ inline size_t packed_activations_offset(
   return (m_idx / mr) * packed_activations_size_mr_rows;
 }
 
-template <int mr, int kr, int sr>
+template <int mr_, int kr_, int sr_>
 void pack_activations(
     void* packed_activations,
     int m,
     int k,
     int group_size,
     const float* activations,
-    bool has_weight_zeros) {
-  activation_packing::pack_activations<mr, kr, sr>(
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
+  (void)mr; // unused
+  (void)kr; // unused
+  (void)sr; // unused
+  activation_packing::pack_activations<mr_, kr_, sr_>(
       packed_activations, m, k, group_size, activations, has_weight_zeros);
 }
 
@@ -93,7 +99,7 @@ inline size_t packed_weights_offset(
   return (n_idx / nr) * packed_weights_size_nr_cols;
 }
 
-template <int weight_nbit, int nr, int kr, int sr>
+template <int weight_nbit, int nr_, int kr_, int sr_>
 void pack_weights(
     void* packed_weights,
     int n,
@@ -102,8 +108,14 @@ void pack_weights(
     const int8_t* weight_qvals,
     const float* weight_scales,
     const int8_t* weight_zeros,
-    const float* bias) {
-  weight_packing::pack_weights<weight_nbit, nr, kr, sr>(
+    const float* bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)nr; // unused
+  (void)kr; // unused
+  (void)sr; // unused
+  weight_packing::pack_weights<weight_nbit, nr_, kr_, sr_>(
       packed_weights,
       n,
       k,
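
And a short sketch of how the updated templated wrapper above might be invoked after this change: the template arguments still select the instantiation, while the trailing runtime mr/kr/sr are accepted for signature uniformity and ignored (per the (void) casts). The namespace alias and tile values here are assumptions for illustration:

// Hypothetical call site (sketch): enclosing namespace and tile values
// (mr, kr, sr) are assumptions, not taken from this commit.
namespace cw = channelwise_8bit_activation_groupwise_lowbit_weight;

void example_pack(void* packed_activations, const float* activations,
                  int m, int k, int group_size) {
  constexpr int mr = 4, kr = 16, sr = 2;
  // Runtime mr/kr/sr simply mirror the template instantiation; the wrapper
  // currently discards them.
  cw::pack_activations<mr, kr, sr>(
      packed_activations, m, k, group_size, activations,
      /*has_weight_zeros=*/false,
      /*mr=*/mr, /*kr=*/kr, /*sr=*/sr);
}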
