@@ -60,27 +60,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
 using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
-template <int mr, int kr, int sr>
-size_t
-activation_data_size(int m, int k, int group_size, bool has_weight_zeros) {
+size_t packed_activations_size(
+    int m,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
   (void)group_size; // unused
   (void)has_weight_zeros; // unused
   auto lhs_packing = get_lhs_packing();
   return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr);
 }
 
-template <int mr, int kr, int sr>
-void prepare_activation_data(
-    void* activation_data,
+size_t packed_activations_offset(
+    int m_idx,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
+  (void)group_size; // unused
+  (void)has_weight_zeros; // unused
+  auto lhs_pack = get_lhs_packing();
+  return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr);
+}
+
+void pack_activations(
+    void* packed_activations,
     int m,
     int k,
     int group_size,
     const float* activations,
-    bool has_weight_zeros) {
+    bool has_weight_zeros,
+    int mr,
+    int kr,
+    int sr) {
   (void)group_size; // unused
   (void)has_weight_zeros; // unused
   auto lhs_pack = get_lhs_packing();
-
   lhs_pack.run_lhs_pack(
       m,
       k,
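
The net effect of this hunk: mr/kr/sr move from compile-time template parameters to plain runtime arguments, and a new packed_activations_offset query is exposed. A minimal caller-side sketch of the new shape, assuming a get_ukernel() accessor like the one the kernel macro below uses and the standard KleidiAI get_mr/get_kr/get_sr getters on the ukernel struct (the wrapper pack_lhs_example itself is hypothetical, not from this PR):

    #include <vector>

    // Hypothetical helper: size and fill an LHS packing buffer using tiling
    // parameters read off the selected ukernel at runtime, which the old
    // template<int mr, int kr, int sr> API fixed at compile time.
    void pack_lhs_example(
        const float* activations, int m, int k, int group_size) {
      auto uk = get_ukernel();
      int mr = static_cast<int>(uk.get_mr());
      int kr = static_cast<int>(uk.get_kr());
      int sr = static_cast<int>(uk.get_sr());
      std::vector<char> buf(packed_activations_size(
          m, k, group_size, /*has_weight_zeros=*/false, mr, kr, sr));
      pack_activations(
          buf.data(), m, k, group_size, activations,
          /*has_weight_zeros=*/false, mr, kr, sr);
    }
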
@@ -90,33 +110,62 @@ void prepare_activation_data(
       /*m_index_start=*/0,
       activations,
       /*lhs_stride=*/k * sizeof(float),
-      activation_data);
+      packed_activations);
 }
 
-template <int nr, int kr, int sr>
-size_t weight_data_size(
+size_t packed_weights_size(
     int n,
     int k,
     int group_size,
+    int weight_nbit,
     bool has_weight_zeros,
-    bool has_bias) {
+    bool has_bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)weight_nbit; // unused
   (void)has_weight_zeros; // unused
   (void)has_bias; // unused
   auto rhs_pack = get_rhs_packing();
   return rhs_pack.get_rhs_packed_size(
-      n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
+      internal::adjust_n(n),
+      k,
+      nr,
+      kr,
+      sr,
+      group_size,
+      kai_datatype::kai_dt_bf16);
+}
+
+size_t packed_weights_offset(
+    int n_idx,
+    int k,
+    int group_size,
+    int weight_nbit,
+    bool has_weight_zeros,
+    bool has_bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)has_weight_zeros; // unused
+  (void)has_bias; // unused
+  auto rhs_pack = get_rhs_packing();
+  return rhs_pack.get_rhs_packed_offset(
+      n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
 }
 
-template <int nr, int kr, int sr>
-void prepare_weight_data(
-    void* weight_data,
+void pack_weights(
+    void* packed_weights,
     int n,
     int k,
     int group_size,
     const int8_t* weight_qvals,
     const float* weight_scales,
     const int8_t* weight_zeros,
-    const float* bias) {
+    const float* bias,
+    int nr,
+    int kr,
+    int sr) {
   if (group_size % 32 != 0) {
     throw std::runtime_error(
         "Group size must be a multiple of 32, but got group_size=" +
@@ -187,7 +236,7 @@ void prepare_weight_data(
       reinterpret_cast<const uint16_t*>(weight_scales_bf16_padded.data()),
       /*scale_stride=*/sizeof(uint16_t) *
           (internal::roundup(k, group_size) / group_size),
-      /*rhs_packed=*/weight_data,
+      /*rhs_packed=*/packed_weights,
       /*extra_bytes=*/0,
       /*qparams=*/&qparams);
 }
@@ -220,8 +269,8 @@ size_t get_preferred_alignement() {
       int n, \
       int k, \
       int group_size, \
-      const void* weight_data, \
-      const void* activation_data, \
+      const void* packed_weights, \
+      const void* packed_activations, \
       float clamp_min, \
       float clamp_max, \
       bool has_weight_zeros, \
@@ -235,11 +284,11 @@ size_t get_preferred_alignement() {
   } \
   get_ukernel().run_matmul( \
       m, \
-      internal::adjust_n(n), \
+      n, \
       k, \
       group_size, \
-      activation_data, \
-      weight_data, \
+      packed_activations, \
+      packed_weights, \
       output, \
       /*dst_stride_row=*/output_m_stride * sizeof(float), \
       /*dst_stride_col=*/sizeof(float), \
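
A note on this last hunk: padding of n to KleidiAI's tile granularity now happens once, inside packed_weights_size (via internal::adjust_n), while run_matmul receives the caller's true n, so the kernel computes exactly n output columns rather than the padded count.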