@@ -51,6 +51,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
51
51
}
52
52
53
53
namespace internal {
54
+ template <
55
+ typename CTYPE_COMPUTE,
56
+ typename CTYPE_OUT,
57
+ typename Op,
58
+ typename ... Args>
59
+ inline void dtype_specialized_elementwise_fn_impl (
60
+ const Op& compute_fun,
61
+ KernelRuntimeContext& ctx,
62
+ const Tensor& out,
63
+ Args... inputs) {
64
+ constexpr auto kNumInputs = sizeof ...(inputs);
65
+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMPUTE)) && ...));
66
+
67
+ ::executorch::extension::parallel_for (
68
+ 0 ,
69
+ out.numel(),
70
+ ::executorch::extension::internal::GRAIN_SIZE,
71
+ [&](const auto begin, const auto end) {
72
+ std::array<const CTYPE_COMPUTE*, kNumInputs > inputs_data_ptrs = {
73
+ inputs.first ->template const_data_ptr <CTYPE_COMPUTE>()...};
74
+
75
+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
76
+
77
+ const auto range =
78
+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
79
+ auto begin_it = range.begin ();
80
+ begin_it += begin;
81
+ for (; (*begin_it)[0 ] < end; ++begin_it) {
82
+ const auto & indexes = *begin_it;
83
+ std::array<CTYPE_COMPUTE, kNumInputs > loaded_inputs;
84
+ for (const auto idx : c10::irange (kNumInputs )) {
85
+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
86
+ }
87
+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
88
+ }
89
+ });
90
+ }
91
+
54
92
template <typename CTYPE_COMPUTE, typename Op, typename ... Args>
55
93
inline bool validate_elementwise_fn_inputs (
56
94
const Op& compute_fun,
@@ -81,18 +119,12 @@ template <
81
119
const char * op_name,
82
120
typename Op,
83
121
typename ... Args>
84
- inline void apply_elementwise_fn (
122
+ inline void apply_elementwise_fn_generic_impl (
85
123
const Op& compute_fun,
86
124
KernelRuntimeContext& ctx,
87
125
const Tensor& out,
88
126
SupportedTensorDtypes out_dtypes,
89
127
Args... inputs) {
90
- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
91
- compute_fun, ctx, out, out_dtypes, inputs...);
92
- if (!inputs_valid) {
93
- return ;
94
- }
95
-
96
128
constexpr auto kNumInputs = sizeof ...(inputs);
97
129
98
130
struct InputInfo {
@@ -138,6 +170,63 @@ inline void apply_elementwise_fn(
138
170
});
139
171
}
140
172
173
+ template <
174
+ typename CTYPE_COMPUTE,
175
+ const char * op_name,
176
+ typename Op,
177
+ typename ... Args>
178
+ inline void apply_elementwise_fn_runtime_out_dtypes (
179
+ const Op& compute_fun,
180
+ KernelRuntimeContext& ctx,
181
+ const Tensor& out,
182
+ SupportedTensorDtypes out_dtypes,
183
+ Args... inputs) {
184
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
185
+ compute_fun, ctx, out, out_dtypes, inputs...);
186
+ if (!inputs_valid) {
187
+ return ;
188
+ }
189
+
190
+ apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
191
+ compute_fun, ctx, out, out_dtypes, inputs...);
192
+ }
193
+
194
+ template <
195
+ typename CTYPE_COMPUTE,
196
+ const char * op_name,
197
+ SupportedTensorDtypes out_dtypes,
198
+ typename Op,
199
+ typename ... Args>
200
+ inline void apply_elementwise_fn (
201
+ const Op& compute_fun,
202
+ KernelRuntimeContext& ctx,
203
+ const Tensor& out,
204
+ Args... inputs) {
205
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
206
+ compute_fun, ctx, out, out_dtypes, inputs...);
207
+ if (!inputs_valid) {
208
+ return ;
209
+ }
210
+
211
+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
212
+ const bool all_inputs_compute_dtype =
213
+ ((inputs.first ->scalar_type () == compute_type) && ...);
214
+
215
+ constexpr ScalarType out_specialized_scalar_type =
216
+ specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
217
+ if (all_inputs_compute_dtype &&
218
+ out.scalar_type () == out_specialized_scalar_type) {
219
+ using CTYPE_OUT =
220
+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
221
+ dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
222
+ compute_fun, ctx, out, inputs...);
223
+ return ;
224
+ }
225
+
226
+ apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
227
+ compute_fun, ctx, out, out_dtypes, inputs...);
228
+ }
229
+
141
230
/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
142
231
template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
143
232
inline void apply_unitensor_elementwise_fn (
@@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn(
147
236
SupportedTensorDtypes a_dtypes,
148
237
const Tensor& out,
149
238
SupportedTensorDtypes out_dtypes) {
150
- internal::apply_elementwise_fn <CTYPE_COMPUTE, op_name>(
239
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMPUTE, op_name>(
151
240
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
152
241
}
153
242
@@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn(
162
251
const Tensor& a,
163
252
SupportedTensorDtypes a_dtypes,
164
253
const Tensor& out) {
165
- internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
166
- compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
254
+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes >(
255
+ compute_fun, ctx, out, std::make_pair (&a, a_dtypes));
167
256
}
168
257
169
258
/* *
@@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
179
268
SupportedTensorDtypes b_dtypes,
180
269
const Tensor& out,
181
270
SupportedTensorDtypes out_dtypes) {
182
- internal::apply_elementwise_fn <CTYPE_COMPUTE, op_name>(
271
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMPUTE, op_name>(
183
272
compute_fun,
184
273
ctx,
185
274
out,
@@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
206
295
const Tensor& b,
207
296
SupportedTensorDtypes b_dtypes,
208
297
const Tensor& out) {
209
- internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
298
+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes >(
210
299
compute_fun,
211
300
ctx,
212
301
out,
213
- out_dtypes,
214
302
std::make_pair (&a, a_dtypes),
215
303
std::make_pair (&b, b_dtypes));
216
304
}
@@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
230
318
SupportedTensorDtypes c_dtypes,
231
319
const Tensor& out,
232
320
SupportedTensorDtypes out_dtypes) {
233
- internal::apply_elementwise_fn <CTYPE_COMPUTE, op_name>(
321
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMPUTE, op_name>(
234
322
compute_fun,
235
323
ctx,
236
324
out,
@@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
275
363
const Tensor& c,
276
364
SupportedTensorDtypes c_dtypes,
277
365
const Tensor& out) {
278
- internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
366
+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes >(
279
367
compute_fun,
280
368
ctx,
281
369
out,
282
- out_dtypes,
283
370
std::make_pair (&a, a_dtypes),
284
371
std::make_pair (&b, b_dtypes),
285
372
std::make_pair (&c, c_dtypes));
0 commit comments