Skip to content

Commit a1e95d1

Browse files
committed
updated per reviews
1 parent 4327151 commit a1e95d1

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

src/plugins/intel_gpu/src/graph/fully_connected.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ format::type get_preferred_format(fully_connected_node const& node, const kernel
6666
}
6767

6868
if (input_layout.data_type == data_types::f32 &&
69-
(input_layout.format == format::bfyx || input_layout.format == format::bfzyx || input_layout.format == format::bfwzyx) &&
69+
one_of<cldnn::format>(input_layout.format, {format::bfyx, format::bfzyx, format::bfwzyx}) &&
7070
no_spatial_padding &&
7171
input_layout.batch() != 8)
7272
return input_layout.format;

src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,18 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
268268
auto prim = impl_params->typed_desc<fully_connected>();
269269
auto weights_layout = impl_params->get_input_layout(1);
270270
bool is_four_bit_weight = weights_layout.data_type == data_types::u4 || weights_layout.data_type == data_types::i4;
271+
auto shift_size = std::max<size_t>(prim->input_size - 2, 0);
272+
int per_oc = PER_OC << shift_size;
273+
int grouped = GROUPED << shift_size;
271274

272275
bool has_decompression_scale = !prim->decompression_scale.empty();
273276
if (has_decompression_scale) {
274277
ib >> _ds_group_size;
275278
ib >> make_data(&_ds_data_type, sizeof(dnnl::memory::data_type));
276279
if (!is_four_bit_weight)
277-
_attrs->set_scales(DNNL_ARG_WEIGHTS, PER_OC, dnnl::memory::dims{}, _ds_data_type);
280+
_attrs->set_scales(DNNL_ARG_WEIGHTS, per_oc, dnnl::memory::dims{}, _ds_data_type);
278281
else
279-
_attrs->set_scales(DNNL_ARG_WEIGHTS, GROUPED, {_ds_group_size, 1}, _ds_data_type);
282+
_attrs->set_scales(DNNL_ARG_WEIGHTS, grouped, {_ds_group_size, 1}, _ds_data_type);
280283
}
281284

282285
bool has_decompression_zp = !prim->decompression_zero_point.empty() || prim->decompression_zero_point_scalar.has_value();
@@ -293,9 +296,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
293296
} else {
294297
auto ngroups = dzp_layout.get_dim(1);
295298
if (ngroups == 1) {
296-
_attrs->set_zero_points(DNNL_ARG_WEIGHTS, PER_OC, dnnl::memory::dims{}, _dzp_data_type);
299+
_attrs->set_zero_points(DNNL_ARG_WEIGHTS, per_oc, dnnl::memory::dims{}, _dzp_data_type);
297300
} else {
298-
_attrs->set_zero_points(DNNL_ARG_WEIGHTS, GROUPED, {_ds_group_size, 1}, _dzp_data_type);
301+
_attrs->set_zero_points(DNNL_ARG_WEIGHTS, grouped, {_ds_group_size, 1}, _dzp_data_type);
299302
}
300303
}
301304
}
@@ -310,9 +313,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
310313
int src_group_size = innermost_len / src_scale_ngroups;
311314

312315
auto act_scale_data_type = convert_data_type(impl_params->get_input_layout(src_scale_idx).data_type);
313-
_attrs->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, act_scale_data_type);
316+
_attrs->set_scales(DNNL_ARG_SRC, grouped, dnnl::memory::dims{1, src_group_size}, act_scale_data_type);
314317
if (dynamic_quantized_activation_zp)
315-
_attrs->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
318+
_attrs->set_zero_points(DNNL_ARG_SRC, grouped, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
316319
}
317320

318321
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, weights_rank, has_bias, *_attrs);
@@ -349,6 +352,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
349352
auto weights_layout = impl_params.get_input_layout(1);
350353
is_four_bit_weight = weights_layout.data_type == data_types::u4 || weights_layout.data_type == data_types::i4;
351354
auto shift_size = std::max<size_t>(prim->input_size - 2, 0);
355+
int per_oc = PER_OC << shift_size;
356+
int grouped = GROUPED << shift_size;
357+
352358
if (!prim->decompression_scale.empty()) {
353359
auto decompression_scale_idx = ++idx;
354360
auto scale_layout = arg.get_dependency(decompression_scale_idx).get_output_layout();
@@ -358,10 +364,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
358364
group_size = ifm / ngroups;
359365
if (!is_four_bit_weight) {
360366
// 8-bit quantized weight
361-
attr->set_scales(DNNL_ARG_WEIGHTS, (PER_OC << shift_size), dnnl::memory::dims{}, ds_data_type);
367+
attr->set_scales(DNNL_ARG_WEIGHTS, per_oc, dnnl::memory::dims{}, ds_data_type);
362368
} else {
363369
// OneDNN does not support scalar zero-point for s4 and u8 type. Need to broadcast it.
364-
attr->set_scales(DNNL_ARG_WEIGHTS, (GROUPED << shift_size), {group_size, 1}, ds_data_type);
370+
attr->set_scales(DNNL_ARG_WEIGHTS, grouped, {group_size, 1}, ds_data_type);
365371
}
366372
}
367373

@@ -375,9 +381,9 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
375381
} else {
376382
auto ngroups = dzp_layout.get_dim(1);
377383
if (ngroups == 1) {
378-
attr->set_zero_points(DNNL_ARG_WEIGHTS, (PER_OC << shift_size), dnnl::memory::dims{}, dzp_data_type);
384+
attr->set_zero_points(DNNL_ARG_WEIGHTS, per_oc, dnnl::memory::dims{}, dzp_data_type);
379385
} else {
380-
attr->set_zero_points(DNNL_ARG_WEIGHTS, (GROUPED << shift_size), {group_size, 1}, dzp_data_type);
386+
attr->set_zero_points(DNNL_ARG_WEIGHTS, grouped, {group_size, 1}, dzp_data_type);
381387
}
382388
}
383389
}
@@ -391,10 +397,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
391397
int src_group_size = innermost_len / src_scale_ngroups;
392398

393399
auto act_scale_data_type = convert_data_type(impl_params.input_layouts[src_scale_idx].data_type);
394-
attr->set_scales(DNNL_ARG_SRC, (GROUPED << shift_size), dnnl::memory::dims{1, src_group_size}, act_scale_data_type);
400+
attr->set_scales(DNNL_ARG_SRC, grouped, dnnl::memory::dims{1, src_group_size}, act_scale_data_type);
395401

396402
if (prim->activation_zero_point.is_valid())
397-
attr->set_zero_points(DNNL_ARG_SRC, (GROUPED << shift_size), dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
403+
attr->set_zero_points(DNNL_ARG_SRC, grouped, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
398404
}
399405

400406

src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3334,6 +3334,8 @@ void test_compressed_int4_scale_dynamic_batch_gemv(bool is_caching_test,
33343334
ASSERT_EQ(out_l.feature(), 3);
33353335
ASSERT_EQ(out_l.spatial(0), 2);
33363336
ASSERT_EQ(out_l.spatial(1), 1);
3337+
ASSERT_EQ(out_l.spatial(2), 1);
3338+
ASSERT_EQ(out_l.spatial(3), 2);
33373339
} else {
33383340
ASSERT_EQ(output_prim_mem->get_layout().batch(), 6);
33393341
ASSERT_EQ(out_l.batch(), 6);

0 commit comments

Comments
 (0)