
Commit 2d86424

Add complex dtype support to mul
Differential Revision: D73877440
Pull Request resolved: #10560
1 parent b8b43f6 commit 2d86424
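
For context: complex multiplication follows (a+bi)(c+di) = (ac-bd) + (ad+bc)i, which is exactly what the kernels below now apply elementwise. A minimal standalone C++ check of that arithmetic (using std::complex in place of ExecuTorch tensor types), matching the first element pair of the new test:

#include <complex>
#include <iostream>

int main() {
  // First element pair from the new test: (1+2i) * (2+3i).
  std::complex<float> a(1, 2), b(2, 3);
  std::cout << a * b << "\n";  // (1*2 - 2*3) + (1*3 + 2*2)i = (-4,7)
}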

File tree

3 files changed (+146, -45 lines)


kernels/optimized/cpu/op_mul.cpp (+74, -32)
@@ -120,21 +120,47 @@ Tensor& opt_mul_out(
         out,
         "Failed to resize output tensor.");

-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
+    if (executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        using Vec = executorch::vec::Vectorized<CTYPE>;
+        executorch::vec::map2<CTYPE>(
+            [](Vec x, Vec y) { return x * y; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            b.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+    } else {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        using Vec = executorch::vec::Vectorized<CTYPE>;
+        executorch::vec::map2<CTYPE>(
+            [](Vec x, Vec y) { return x * y; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            b.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+    }
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
-          ctx, mul_lambda, a, b, out, selected_optimized_path);
-    });
+    if (executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        auto mul_lambda = [](auto x, auto y) { return x * y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, mul_lambda, a, b, out, selected_optimized_path);
+      });
+    } else {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        auto mul_lambda = [](auto x, auto y) { return x * y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, mul_lambda, a, b, out, selected_optimized_path);
+      });
+    }
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
@@ -146,26 +172,42 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out);

-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
+    if (executorch::runtime::isComplexType(a_type) ||
+        executorch::runtime::isComplexType(b_type) ||
+        executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+            [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
+            a,
+            b,
+            out);
+      });
+    } else {
+      ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
+        ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
+          using CTYPE_IN = typename torch::executor::
+              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+          ET_SWITCH_REALHBBF16_TYPES(
+              out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
+                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
+                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+                      CTYPE_IN value = a_casted * b_casted;
+
+                      return static_cast<CTYPE_OUT>(value);
+                    },
+                    a,
+                    b,
+                    out);
+              });
         });
       });
-    });
+    }
   }

   return out;
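
All three branches above funnel through ET_SWITCH_*_TYPES macros, which turn the runtime ScalarType into a concrete element type for the lambda body. As a rough mental model only (a hypothetical standalone sketch, not ExecuTorch's actual macro expansion; the enum and function names here are invented):

#include <complex>
#include <iostream>

// Hypothetical stand-in for the runtime dtype tags used by the kernel.
enum class ScalarType { ComplexFloat, ComplexDouble };

// Rough model of an ET_SWITCH_COMPLEX*_TYPES dispatch: map the runtime dtype
// to a concrete C++ element type and invoke the kernel body with that type.
template <typename Fn>
void switch_complex_types(ScalarType t, Fn&& fn) {
  switch (t) {
    case ScalarType::ComplexFloat:
      fn(std::complex<float>{});
      break;
    case ScalarType::ComplexDouble:
      fn(std::complex<double>{});
      break;
  }
}

int main() {
  switch_complex_types(ScalarType::ComplexFloat, [](auto tag) {
    using CTYPE = decltype(tag);  // plays the role of CTYPE in the kernel
    CTYPE a(1, 2), b(2, 3);
    std::cout << a * b << "\n";   // prints (-4,7)
  });
}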

kernels/portable/cpu/op_mul.cpp (+30, -13)
@@ -47,25 +47,42 @@ Tensor& mul_out(
   ET_KERNEL_CHECK(
       ctx,
       (executorch::runtime::isRealType(compute_type) ||
+       executorch::runtime::isComplexType(compute_type) ||
        compute_type == ScalarType::Bool),
       InvalidArgument,
       out);

-  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a * val_b;
-        },
+  if (executorch::runtime::isComplexType(compute_type)) {
+    ET_KERNEL_CHECK(
         ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
+        a.scalar_type() == b.scalar_type() &&
+            a.scalar_type() == out.scalar_type(),
+        InvalidArgument,
         out);
-  });
+    ET_SWITCH_COMPLEXH_TYPES(out.scalar_type(), ctx, "mul.out", CTYPE, [&]() {
+      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+          [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
+          a,
+          b,
+          out);
+    });
+  } else {
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::REALHBBF16>(
+          [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+            return val_a * val_b;
+          },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
+    });
+  }

   return out;
 }
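
In the complex branch, apply_binary_elementwise_fn is instantiated with one CTYPE for both inputs and the output, since complex dtypes are required to match rather than promote. Setting aside broadcasting and dtype conversion (which the real helpers handle), the core computation reduces to a plain elementwise loop; a minimal sketch under those assumptions, with a hypothetical function name:

#include <complex>
#include <cstddef>

// Simplified model of the elementwise apply used by the complex branch:
// same dtype for a, b, and out; contiguous, same-shape data; no broadcast.
template <typename CTYPE, typename Op>
void apply_binary_elementwise(const CTYPE* a, const CTYPE* b, CTYPE* out,
                              std::size_t numel, Op op) {
  for (std::size_t i = 0; i < numel; ++i) {
    out[i] = op(a[i], b[i]);  // e.g. complex multiply
  }
}

int main() {
  std::complex<float> a[] = {{1, 2}, {3, 4}};
  std::complex<float> b[] = {{2, 3}, {4, 5}};
  std::complex<float> out[2];
  apply_binary_elementwise(a, b, out, 2,
                           [](auto x, auto y) { return x * y; });
  // out[0] == (-4,7) and out[1] == (-8,31), matching the values in the tests.
}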

kernels/test/op_mul_test.cpp (+42)
@@ -322,6 +322,38 @@ class OpMulOutTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
     EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
   }
+
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+    const std::vector<int32_t> sizes = {2, 2};
+
+    // Create complex tensors with real and imaginary parts
+    Tensor x =
+        tf.make(sizes, {CTYPE(1, 2), CTYPE(3, 4), CTYPE(5, 6), CTYPE(7, 8)});
+
+    Tensor y =
+        tf.make(sizes, {CTYPE(2, 3), CTYPE(4, 5), CTYPE(6, 7), CTYPE(8, 9)});
+
+    // Expected result: (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
+    // (1+2i) * (2+3i) = (1*2-2*3) + (1*3+2*2)i = -4 + 7i
+    // (3+4i) * (4+5i) = (3*4-4*5) + (3*5+4*4)i = -8 + 31i
+    // (5+6i) * (6+7i) = (5*6-6*7) + (5*7+6*6)i = -12 + 71i
+    // (7+8i) * (8+9i) = (7*8-8*9) + (7*9+8*8)i = -16 + 127i
+    Tensor expected = tf.make(
+        sizes, {CTYPE(-4, 7), CTYPE(-8, 31), CTYPE(-12, 71), CTYPE(-16, 127)});
+
+    Tensor out = tf.make(
+        {2, 2},
+        {
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+        });
+    op_mul_out(x, y, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
 };

 class OpMulScalarOutTest : public OperatorTest {
@@ -472,6 +504,16 @@ TEST_F(OpMulOutTest, BothScalarInputBroadcastTest) {
   test_both_scalar_input_broadcast<ScalarType::BFloat16>();
 }

+TEST_F(OpMulOutTest, AllComplexDtypesSupported) {
+#define TEST_ENTRY(ctype, dtype) test_complex_dtype<ctype, ScalarType::dtype>();
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    ET_FORALL_COMPLEX_TYPES(TEST_ENTRY);
+  } else {
+    ET_FORALL_COMPLEXH_TYPES(TEST_ENTRY);
+  }
+#undef TEST_ENTRY
+}
+
 TEST_F(OpMulOutTest, MismatchedOutputShapesDies) {
   if (SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen currently supports mismatched shapes";