@@ -268,15 +268,18 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
268
268
auto prim = impl_params->typed_desc <fully_connected>();
269
269
auto weights_layout = impl_params->get_input_layout (1 );
270
270
bool is_four_bit_weight = weights_layout.data_type == data_types::u4 || weights_layout.data_type == data_types::i4;
271
+ auto shift_size = std::max<size_t >(prim->input_size - 2 , 0 );
272
+ int per_oc = PER_OC << shift_size;
273
+ int grouped = GROUPED << shift_size;
271
274
272
275
bool has_decompression_scale = !prim->decompression_scale .empty ();
273
276
if (has_decompression_scale) {
274
277
ib >> _ds_group_size;
275
278
ib >> make_data (&_ds_data_type, sizeof (dnnl::memory::data_type));
276
279
if (!is_four_bit_weight)
277
- _attrs->set_scales (DNNL_ARG_WEIGHTS, PER_OC , dnnl::memory::dims{}, _ds_data_type);
280
+ _attrs->set_scales (DNNL_ARG_WEIGHTS, per_oc , dnnl::memory::dims{}, _ds_data_type);
278
281
else
279
- _attrs->set_scales (DNNL_ARG_WEIGHTS, GROUPED , {_ds_group_size, 1 }, _ds_data_type);
282
+ _attrs->set_scales (DNNL_ARG_WEIGHTS, grouped , {_ds_group_size, 1 }, _ds_data_type);
280
283
}
281
284
282
285
bool has_decompression_zp = !prim->decompression_zero_point .empty () || prim->decompression_zero_point_scalar .has_value ();
@@ -293,9 +296,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
293
296
} else {
294
297
auto ngroups = dzp_layout.get_dim (1 );
295
298
if (ngroups == 1 ) {
296
- _attrs->set_zero_points (DNNL_ARG_WEIGHTS, PER_OC , dnnl::memory::dims{}, _dzp_data_type);
299
+ _attrs->set_zero_points (DNNL_ARG_WEIGHTS, per_oc , dnnl::memory::dims{}, _dzp_data_type);
297
300
} else {
298
- _attrs->set_zero_points (DNNL_ARG_WEIGHTS, GROUPED , {_ds_group_size, 1 }, _dzp_data_type);
301
+ _attrs->set_zero_points (DNNL_ARG_WEIGHTS, grouped , {_ds_group_size, 1 }, _dzp_data_type);
299
302
}
300
303
}
301
304
}
@@ -310,9 +313,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
310
313
int src_group_size = innermost_len / src_scale_ngroups;
311
314
312
315
auto act_scale_data_type = convert_data_type (impl_params->get_input_layout (src_scale_idx).data_type );
313
- _attrs->set_scales (DNNL_ARG_SRC, GROUPED , dnnl::memory::dims{1 , src_group_size}, act_scale_data_type);
316
+ _attrs->set_scales (DNNL_ARG_SRC, grouped , dnnl::memory::dims{1 , src_group_size}, act_scale_data_type);
314
317
if (dynamic_quantized_activation_zp)
315
- _attrs->set_zero_points (DNNL_ARG_SRC, GROUPED , dnnl::memory::dims{1 , src_group_size}, dnnl::memory::data_type::u8 );
318
+ _attrs->set_zero_points (DNNL_ARG_SRC, grouped , dnnl::memory::dims{1 , src_group_size}, dnnl::memory::data_type::u8 );
316
319
}
317
320
318
321
auto prim_desc = get_matmul_primitive_descriptor (*impl_params, ib.get_engine (), input_size, weights_rank, has_bias, *_attrs);
@@ -349,6 +352,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
349
352
auto weights_layout = impl_params.get_input_layout (1 );
350
353
is_four_bit_weight = weights_layout.data_type == data_types::u4 || weights_layout.data_type == data_types::i4;
351
354
auto shift_size = std::max<size_t >(prim->input_size - 2 , 0 );
355
+ int per_oc = PER_OC << shift_size;
356
+ int grouped = GROUPED << shift_size;
357
+
352
358
if (!prim->decompression_scale .empty ()) {
353
359
auto decompression_scale_idx = ++idx;
354
360
auto scale_layout = arg.get_dependency (decompression_scale_idx).get_output_layout ();
@@ -358,10 +364,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
358
364
group_size = ifm / ngroups;
359
365
if (!is_four_bit_weight) {
360
366
// 8-bit quantized weight
361
- attr->set_scales (DNNL_ARG_WEIGHTS, (PER_OC << shift_size) , dnnl::memory::dims{}, ds_data_type);
367
+ attr->set_scales (DNNL_ARG_WEIGHTS, per_oc , dnnl::memory::dims{}, ds_data_type);
362
368
} else {
363
369
// OneDNN does not support scalar zero-point for s4 and u8 type. Need to broadcast it.
364
- attr->set_scales (DNNL_ARG_WEIGHTS, (GROUPED << shift_size) , {group_size, 1 }, ds_data_type);
370
+ attr->set_scales (DNNL_ARG_WEIGHTS, grouped , {group_size, 1 }, ds_data_type);
365
371
}
366
372
}
367
373
@@ -375,9 +381,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
375
381
} else {
376
382
auto ngroups = dzp_layout.get_dim (1 );
377
383
if (ngroups == 1 ) {
378
- attr->set_zero_points (DNNL_ARG_WEIGHTS, (PER_OC << shift_size) , dnnl::memory::dims{}, dzp_data_type);
384
+ attr->set_zero_points (DNNL_ARG_WEIGHTS, per_oc , dnnl::memory::dims{}, dzp_data_type);
379
385
} else {
380
- attr->set_zero_points (DNNL_ARG_WEIGHTS, (GROUPED << shift_size) , {group_size, 1 }, dzp_data_type);
386
+ attr->set_zero_points (DNNL_ARG_WEIGHTS, grouped , {group_size, 1 }, dzp_data_type);
381
387
}
382
388
}
383
389
}
@@ -391,10 +397,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
391
397
int src_group_size = innermost_len / src_scale_ngroups;
392
398
393
399
auto act_scale_data_type = convert_data_type (impl_params.input_layouts [src_scale_idx].data_type );
394
- attr->set_scales (DNNL_ARG_SRC, (GROUPED << shift_size) , dnnl::memory::dims{1 , src_group_size}, act_scale_data_type);
400
+ attr->set_scales (DNNL_ARG_SRC, grouped , dnnl::memory::dims{1 , src_group_size}, act_scale_data_type);
395
401
396
402
if (prim->activation_zero_point .is_valid ())
397
- attr->set_zero_points (DNNL_ARG_SRC, (GROUPED << shift_size) , dnnl::memory::dims{1 , src_group_size}, dnnl::memory::data_type::u8 );
403
+ attr->set_zero_points (DNNL_ARG_SRC, grouped , dnnl::memory::dims{1 , src_group_size}, dnnl::memory::data_type::u8 );
398
404
}
399
405
400
406
0 commit comments