Skip to content

Commit 5dee6be

Browse files
support half and bf16 in to_dim_order_copy (#7689) (#7713)
Differential Revision: D68245619 Pull Request resolved: #7693 (cherry picked from commit 9c04329) Co-authored-by: Gasoonjia <[email protected]>
1 parent a0a4684 commit 5dee6be

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

kernels/portable/cpu/op__to_dim_order_copy.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,17 @@ Tensor& _to_dim_order_copy_out(
9696
InvalidArgument,
9797
out);
9898

99-
ET_SWITCH_REALHB_TYPES(
99+
if (self.numel() == 0) {
100+
return out;
101+
}
102+
103+
ET_SWITCH_REALHBBF16_TYPES(
100104
self.scalar_type(),
101105
ctx,
102106
"dim_order_ops::_to_dim_order_copy.out",
103107
CTYPE_IN,
104108
[&] {
105-
ET_SWITCH_REALHB_TYPES(
109+
ET_SWITCH_REALHBBF16_TYPES(
106110
out.scalar_type(),
107111
ctx,
108112
"dim_order_ops::_to_dim_order_copy.out",

kernels/test/op__to_dim_order_copy_test.cpp

+15-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ typedef std::map<
3636
std::type_index,
3737
std::variant<
3838
std::vector<float>,
39-
std::vector<double>>>
39+
std::vector<double>,
40+
std::vector<exec_aten::Half>,
41+
std::vector<exec_aten::BFloat16>>>
4042
FloatingTypeToDataMap;
4143

4244
typedef std::map<
@@ -381,9 +383,9 @@ TEST_F(OpToDimOrderCopyTest, NanInfSupported) {
381383
ScalarType::OUTPUT_DTYPE>(test_cases);
382384

383385
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
384-
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
386+
ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
385387

386-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
388+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
387389

388390
#undef TEST_ENTRY
389391
#undef TEST_KERNEL
@@ -413,6 +415,13 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
413415
-0.30919688936285893988};
414416
// clang-format on
415417

418+
std::vector<exec_aten::Half> half_data;
419+
std::vector<exec_aten::BFloat16> bf16_data;
420+
for (auto d : double_data) {
421+
half_data.emplace_back(d);
422+
bf16_data.emplace_back(d);
423+
}
424+
416425
std::vector<int64_t> int64_data = {
417426
-1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0};
418427
std::vector<int32_t> int32_data = {
@@ -426,6 +435,8 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
426435
FloatingTypeToDataMap floating_point_data;
427436
floating_point_data[typeid(float)] = float_data;
428437
floating_point_data[typeid(double)] = double_data;
438+
floating_point_data[typeid(exec_aten::Half)] = half_data;
439+
floating_point_data[typeid(exec_aten::BFloat16)] = bf16_data;
429440

430441
// Gathering all int data together for better traversial
431442
IntTypeToDataMap int_data;
@@ -444,7 +455,7 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
444455
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
445456
ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
446457

447-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
458+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
448459
}
449460

450461
TEST_F(OpToDimOrderCopyTest, MismatchedSizesDie) {

0 commit comments

Comments
 (0)