
Commit 208a341

Specialize for non-mixed-dtype in elementwise_util (#9388)
Mixed dtype should be uncommon; this change specializes for the common case. Prepares us to tackle #9241.

Test Plan: automated tests on this PR verify that we didn't break the now-deprecated runtime_out_dtypes mode; tests on the next PR will verify that everything works after migration. Also included is the migration of exactly one operator, op_mul, to verify that the new code compiles.

To check performance, I edited examples/models/toy_model/model.py so that MulModule used inputs of size (3000, 2000) instead of (3, 2). I exported it with `python3 -m examples.portable.scripts.export --model_name mul` and saved the resulting `mul.pte`. Then I built in release mode with optimized kernels on, but with mul.out removed from kernels/optimized/optimized.yaml, so that we would use the optimized_portable_kernels build of kernels/portable/op_mul.cpp. Finally, I ran 3 trials on my M1 MacBook Pro using `cmake-out/executor_runner --model_path mul3kby2k.pte --num_executions 1000 --cpu_threads 2`.

Resulting times for 1000 iterations, in ms:

- Previous diff: 8295, 8187, 8139
- This diff: 2953, 2806, 2861

(For comparison, the actual optimized mul kernel took around 1000 ms to run 1000 iterations, and #9432 later in the stack arrived at similar numbers.)
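To see the shape of the migration at a glance: the supported output dtypes move from a runtime argument to a template argument. A schematic of the before/after call, where names such as `fn`, `a_dtypes`, `b_dtypes`, and `out_dtypes` are placeholders rather than code from this commit (the concrete migration for op_mul is in the diff below):

```cpp
// Before (now deprecated): out_dtypes is passed at runtime.
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
    fn, ctx, a, a_dtypes, b, b_dtypes, out, out_dtypes);

// After: out_dtypes is a template argument, so the common
// non-mixed-dtype case can be specialized at compile time.
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
    fn, ctx, a, a_dtypes, b, b_dtypes, out);
```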
1 parent 59870c5 commit 208a341

File tree

3 files changed: +127 −19 lines changed

kernels/portable/cpu/op_mul.cpp (+5, -3)

```diff
@@ -52,7 +52,10 @@ Tensor& mul_out(
       out);
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return val_a * val_b;
         },
@@ -61,8 +64,7 @@ Tensor& mul_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });
 
   return out;
```

kernels/portable/cpu/util/dtype_util.h (+19)

```diff
@@ -290,6 +290,25 @@ bool check_tensor_dtype(
     SupportedTensorDtypes dtypes,
     const ScalarType compute_type);
 
+/// Return the one output type we are willing to emit specialized code
+/// to handle, given a compute type of CTYPE_COMMON and supported
+/// output types of out_dtypes.
+template <typename CTYPE_COMPUTE>
+inline constexpr ScalarType specialized_output_scalar_type(
+    SupportedTensorDtypes out_dtypes) {
+  switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return ScalarType::Bool;
+    case SupportedTensorDtypes::REALHBBF16:
+    case SupportedTensorDtypes::REALHBF16:
+    case SupportedTensorDtypes::FLOATHBF16:
+    case SupportedTensorDtypes::INTB:
+    case SupportedTensorDtypes::SAME_AS_COMPUTE:
+    case SupportedTensorDtypes::SAME_AS_COMMON:
+      return CppTypeToScalarType<CTYPE_COMPUTE>::value;
+  }
+}
+
 } // namespace internal
 } // namespace utils
 } // namespace native
```
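As a concrete illustration of the mapping (my own example, not part of the commit, assuming the `SupportedTensorDtypes` enum and the `CppTypeToScalarType` helper used above are in scope): the function collapses every supported-output-dtype set to a single compile-time choice.

```cpp
// Illustrative checks: BOOL_OR_BYTE specializes to Bool output;
// every other supported set specializes to the compute type itself.
static_assert(
    specialized_output_scalar_type<float>(
        SupportedTensorDtypes::BOOL_OR_BYTE) == ScalarType::Bool);
static_assert(
    specialized_output_scalar_type<float>(
        SupportedTensorDtypes::REALHBBF16) == ScalarType::Float);
static_assert(
    specialized_output_scalar_type<int32_t>(
        SupportedTensorDtypes::SAME_AS_COMPUTE) == ScalarType::Int);
```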

kernels/portable/cpu/util/elementwise_util.h (+103, -16)

```diff
@@ -51,6 +51,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }
 
 namespace internal {
+template <
+    typename CTYPE_COMPUTE,
+    typename CTYPE_OUT,
+    typename Op,
+    typename... Args>
+inline void dtype_specialized_elementwise_fn_impl(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  constexpr auto kNumInputs = sizeof...(inputs);
+  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+
+  ::executorch::extension::parallel_for(
+      0,
+      out.numel(),
+      ::executorch::extension::internal::GRAIN_SIZE,
+      [&](const auto begin, const auto end) {
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+        CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+        const auto range =
+            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
+        auto begin_it = range.begin();
+        begin_it += begin;
+        for (; (*begin_it)[0] < end; ++begin_it) {
+          const auto& indexes = *begin_it;
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+          for (const auto idx : c10::irange(kNumInputs)) {
+            loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
+          }
+          data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
+        }
+      });
+}
+
 template <typename CTYPE_COMPUTE, typename Op, typename... Args>
 inline bool validate_elementwise_fn_inputs(
     const Op& compute_fun,
@@ -81,18 +119,12 @@ template <
     const char* op_name,
     typename Op,
     typename... Args>
-inline void apply_elementwise_fn(
+inline void apply_elementwise_fn_generic_impl(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes,
     Args... inputs) {
-  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
-      compute_fun, ctx, out, out_dtypes, inputs...);
-  if (!inputs_valid) {
-    return;
-  }
-
   constexpr auto kNumInputs = sizeof...(inputs);
 
   struct InputInfo {
@@ -138,6 +170,63 @@ inline void apply_elementwise_fn(
   });
 }
 
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn_runtime_out_dtypes(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+  if (!inputs_valid) {
+    return;
+  }
+
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    SupportedTensorDtypes out_dtypes,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+  if (!inputs_valid) {
+    return;
+  }
+
+  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
+  const bool all_inputs_compute_dtype =
+      ((inputs.first->scalar_type() == compute_type) && ...);
+
+  constexpr ScalarType out_specialized_scalar_type =
+      specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
+  if (all_inputs_compute_dtype &&
+      out.scalar_type() == out_specialized_scalar_type) {
+    using CTYPE_OUT =
+        typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
+    dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
+        compute_fun, ctx, out, inputs...);
+    return;
+  }
+
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+}
+
 /// DEPRECATED: prefer the variant with out_dtypes in the template argument.
 template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
 inline void apply_unitensor_elementwise_fn(
@@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn(
     SupportedTensorDtypes a_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
@@ -162,8 +251,8 @@
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
-      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
+      compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
 }
 
 /**
@@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun,
       ctx,
       out,
@@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
-      out_dtypes,
       std::make_pair(&a, a_dtypes),
      std::make_pair(&b, b_dtypes));
 }
@@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun,
       ctx,
       out,
@@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
      compute_fun,
      ctx,
      out,
-      out_dtypes,
      std::make_pair(&a, a_dtypes),
      std::make_pair(&b, b_dtypes),
      std::make_pair(&c, c_dtypes));
```
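The dispatch logic above can be hard to see through the variadic plumbing. Below is a self-contained sketch (my own simplification, not code from the commit) of the technique: check once, up front, whether every input already has the compute dtype and the output has the one specialized dtype; if so, run a tight loop on raw typed pointers, otherwise fall back to a generic path that converts each element on load and store. Names like `Tensor`, `mul_specialized`, and `fast_path_applies` are invented for the sketch.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative stand-ins; real ExecuTorch tensors are untyped buffers
// plus a runtime ScalarType tag.
enum class ScalarType { Float, Double };

struct Tensor {
  ScalarType dtype;
  std::vector<float> f32;   // populated when dtype == Float
  std::vector<double> f64;  // populated when dtype == Double
};

// Generic path: per-element conversion to and from the compute type,
// analogous to apply_elementwise_fn_generic_impl.
template <typename Fn>
void mul_generic(const Tensor& a, const Tensor& b, Tensor& out, Fn fn) {
  auto load = [](const Tensor& t, size_t i) -> float {
    return t.dtype == ScalarType::Float ? t.f32[i]
                                        : static_cast<float>(t.f64[i]);
  };
  const size_t n = a.dtype == ScalarType::Float ? a.f32.size() : a.f64.size();
  for (size_t i = 0; i < n; ++i) {
    const float r = fn(load(a, i), load(b, i));
    if (out.dtype == ScalarType::Float) {
      out.f32[i] = r;
    } else {
      out.f64[i] = static_cast<double>(r);
    }
  }
}

// Specialized path: every dtype is known to be the compute type, so we
// operate directly on typed pointers with no per-element dispatch,
// analogous to dtype_specialized_elementwise_fn_impl.
template <typename Fn>
void mul_specialized(const Tensor& a, const Tensor& b, Tensor& out, Fn fn) {
  const float* pa = a.f32.data();
  const float* pb = b.f32.data();
  float* po = out.f32.data();
  for (size_t i = 0; i < a.f32.size(); ++i) {
    po[i] = fn(pa[i], pb[i]);
  }
}

template <typename Fn>
void mul_out(const Tensor& a, const Tensor& b, Tensor& out, Fn fn) {
  // The one up-front check that replaces per-element dtype dispatch.
  const bool fast_path_applies = a.dtype == ScalarType::Float &&
      b.dtype == ScalarType::Float && out.dtype == ScalarType::Float;
  if (fast_path_applies) {
    mul_specialized(a, b, out, fn);
  } else {
    mul_generic(a, b, out, fn);
  }
}

int main() {
  Tensor a{ScalarType::Float, {1.f, 2.f, 3.f}, {}};
  Tensor b{ScalarType::Float, {4.f, 5.f, 6.f}, {}};
  Tensor out{ScalarType::Float, {0.f, 0.f, 0.f}, {}};
  mul_out(a, b, out, [](float x, float y) { return x * y; });
  std::printf("%g %g %g\n", out.f32[0], out.f32[1], out.f32[2]);  // 4 10 18
}
```

Per the numbers in the commit message, avoiding the per-element load/store indirection on the common path is worth roughly a 2.8x speedup for a large float mul, even though both paths compute the same values.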
