@@ -120,21 +120,47 @@ Tensor& opt_mul_out(
         out,
         "Failed to resize output tensor.");
 
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
+    if (executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        using Vec = executorch::vec::Vectorized<CTYPE>;
+        executorch::vec::map2<CTYPE>(
+            [](Vec x, Vec y) { return x * y; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            b.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+    } else {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        using Vec = executorch::vec::Vectorized<CTYPE>;
+        executorch::vec::map2<CTYPE>(
+            [](Vec x, Vec y) { return x * y; },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            b.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+    }
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      auto mul_lambda = [](auto x, auto y) { return x * y; };
-      return torch::executor::handle_broadcast_elementwise<CTYPE>(
-          ctx, mul_lambda, a, b, out, selected_optimized_path);
-    });
+    if (executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        auto mul_lambda = [](auto x, auto y) { return x * y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, mul_lambda, a, b, out, selected_optimized_path);
+      });
+    } else {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        auto mul_lambda = [](auto x, auto y) { return x * y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, mul_lambda, a, b, out, selected_optimized_path);
+      });
+    }
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
@@ -146,26 +172,42 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out);
 
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
+    if (executorch::runtime::isComplexType(a_type) ||
+        executorch::runtime::isComplexType(b_type) ||
+        executorch::runtime::isComplexType(out_type)) {
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+        apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+            [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
+            a,
+            b,
+            out);
+      });
+    } else {
+      ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
+        ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
+          using CTYPE_IN = typename torch::executor::
+              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+          ET_SWITCH_REALHBBF16_TYPES(
+              out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
+                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
+                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+                      CTYPE_IN value = a_casted * b_casted;
+
+                      return static_cast<CTYPE_OUT>(value);
+                    },
+                    a,
+                    b,
+                    out);
+              });
         });
       });
-    });
+    }
   }
 
   return out;
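
For reference, the new COMPLEXH branches above simply rely on operator* of the complex element type selected by the switch macro. A minimal standalone sketch of that element-wise semantics, using std::complex as a stand-in for the kernel's complex CTYPE (an assumption for illustration only, not part of this patch):

// Standalone illustration: the complex product the new branches compute
// per element, shown with std::complex instead of the ExecuTorch types.
#include <cassert>
#include <complex>

int main() {
  std::complex<float> a{1.0f, 2.0f};   // 1 + 2i
  std::complex<float> b{3.0f, -1.0f};  // 3 - 1i
  // (1 + 2i) * (3 - 1i) = 5 + 5i
  assert(a * b == std::complex<float>(5.0f, 5.0f));
  return 0;
}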