@@ -36,7 +36,9 @@ typedef std::map<
36
36
std::type_index,
37
37
std::variant<
38
38
std::vector<float >,
39
- std::vector<double >>>
39
+ std::vector<double >,
40
+ std::vector<exec_aten::Half>,
41
+ std::vector<exec_aten::BFloat16>>>
40
42
FloatingTypeToDataMap;
41
43
42
44
typedef std::map<
@@ -381,9 +383,9 @@ TEST_F(OpToDimOrderCopyTest, NanInfSupported) {
381
383
ScalarType::OUTPUT_DTYPE>(test_cases);
382
384
383
385
#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
384
- ET_FORALL_FLOAT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
386
+ ET_FORALL_FLOATHBF16_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
385
387
386
- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
388
+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
387
389
388
390
#undef TEST_ENTRY
389
391
#undef TEST_KERNEL
@@ -413,6 +415,13 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
413
415
-0.30919688936285893988 };
414
416
// clang-format on
415
417
418
+ std::vector<exec_aten::Half> half_data;
419
+ std::vector<exec_aten::BFloat16> bf16_data;
420
+ for (auto d : double_data) {
421
+ half_data.emplace_back (d);
422
+ bf16_data.emplace_back (d);
423
+ }
424
+
416
425
std::vector<int64_t > int64_data = {
417
426
-1 , -4 , 2 , -2 , 3 , 3 , -3 , -4 , 3 , 3 , 0 , 2 , 0 , -1 , 0 };
418
427
std::vector<int32_t > int32_data = {
@@ -426,6 +435,8 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
426
435
FloatingTypeToDataMap floating_point_data;
427
436
floating_point_data[typeid (float )] = float_data;
428
437
floating_point_data[typeid (double )] = double_data;
438
+ floating_point_data[typeid (exec_aten::Half)] = half_data;
439
+ floating_point_data[typeid (exec_aten::BFloat16)] = bf16_data;
429
440
430
441
// Gathering all int data together for better traversial
431
442
IntTypeToDataMap int_data;
@@ -444,7 +455,7 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
444
455
#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
445
456
ET_FORALL_INT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
446
457
447
- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
458
+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
448
459
}
449
460
450
461
TEST_F (OpToDimOrderCopyTest, MismatchedSizesDie) {
0 commit comments