
Commit 2bee76e

Revert "Revert "Extract parallel_for_each_reduce_over_dim_output_index from argmin parallelization PoC (#9139)"" (#9274)
This reverts commit a70d1d5d4f57c2ce9474bc914c8a7a1bbb73885b and adds the missing namespacing on the optional argument to the new function in reduce_util.h.
Parent commit: 09a3a5a
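
The namespacing fix the message refers to is visible in the reduce_util.h hunk below, where the new function's dim parameter is written as executorch::aten::optional<int64_t>. As a purely hypothetical sketch of why a bare "optional" in a shared header is fragile (the names and namespaces here are invented for illustration and are not ExecuTorch code):

// Hypothetical stand-alone example: a type alias that lives in a nested
// namespace only resolves by its bare name where a using-declaration is in
// scope, so a shared header should spell the namespace out.
#include <cstdint>
#include <optional>

namespace mylib::aten {
template <typename T>
using optional = std::optional<T>;  // the alias lives in mylib::aten
}  // namespace mylib::aten

namespace mylib {
// Without a using-declaration in scope, the bare name would not resolve here:
//   bool has_dim(optional<std::int64_t> dim);  // error: unknown type name
// Fully qualifying the parameter type works in every includer:
inline bool has_dim(aten::optional<std::int64_t> dim) {
  return dim.has_value();
}
}  // namespace mylib

int main() {
  return mylib::has_dim(std::int64_t{3}) ? 0 : 1;  // exits 0: dim is set
}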

4 files changed: +24 −13 lines

kernels/portable/cpu/op_argmin.cpp

+2 −12
@@ -12,7 +12,6 @@
 
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
@@ -48,17 +47,8 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    // REVIEW: this is the parallelization strategy ATen uses
-    // specifically when the reduction is along the last dimension and
-    // that dimension is contiguous. Is there any particular reason we
-    // shouldn't just always use this strategy since we aren't
-    // otherwise capable of parallelizing reductions?
-    const int64_t reduction_size = get_reduced_dim_product(in, dim);
-    const auto grain_size = std::max(
-        static_cast<int64_t>(1),
-        executorch::extension::internal::GRAIN_SIZE / reduction_size);
-    const bool success = executorch::extension::parallel_for(
-        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, out, [&](const auto begin, const auto end) {
           for (const auto out_ix : c10::irange(begin, end)) {
             std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
                 [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {

kernels/portable/cpu/util/reduce_util.h

+19
@@ -10,6 +10,7 @@
 
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <cstring>
 #include <tuple>
 
@@ -811,5 +812,23 @@ bool check_prod_out_args(
 
 #endif
 
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim or
+ * map_reduce_over_dim for each output element. Automatically
+ * calculates appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_output_index(
+    const Tensor& in,
+    executorch::aten::optional<int64_t> dim,
+    const Tensor& out,
+    const Func& func) {
+  const int64_t reduction_size = get_reduced_dim_product(in, dim);
+  const auto grain_size = std::max(
+      static_cast<int64_t>(1),
+      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
 } // namespace executor
 } // namespace torch
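
For context on how the extracted helper is meant to be used, here is a minimal sketch modeled on the op_argmin.cpp hunk above. The kernel fill_reduced_sizes and its loop body are hypothetical stand-ins for illustration; only parallel_for_each_reduce_over_dim_output_index, get_reduced_dim_product, and mutable_data_ptr come from the ExecuTorch code shown in this commit:

#include <executorch/kernels/portable/cpu/util/reduce_util.h>

namespace torch {
namespace executor {

// Hypothetical kernel: writes the size of the reduced dimension(s) of `in`
// into every element of `out`, parallelized over output indices.
bool fill_reduced_sizes(
    const Tensor& in,
    executorch::aten::optional<int64_t> dim,
    Tensor& out) {
  long* out_data = out.mutable_data_ptr<long>();
  // The helper computes grain_size = max(1, GRAIN_SIZE / reduction_size) and
  // forwards to executorch::extension::parallel_for over [0, out.numel()).
  return parallel_for_each_reduce_over_dim_output_index(
      in, dim, out, [&](const auto begin, const auto end) {
        for (auto out_ix = begin; out_ix < end; ++out_ix) {
          out_data[out_ix] =
              static_cast<long>(get_reduced_dim_product(in, dim));
        }
      });
}

} // namespace executor
} // namespace torch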

kernels/portable/cpu/util/targets.bzl

+3
@@ -314,6 +314,9 @@ def define_common_targets():
            "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
            "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
        ],
+        exported_deps = [
+            "//executorch/runtime/kernel:thread_parallel_interface",
+        ],
        exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
        visibility = [
            "//executorch/extension/llm/custom_ops/...",

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

-1
@@ -284,7 +284,6 @@ ATEN_OPS = (
        name = "op_argmin",
        deps = [
            "//executorch/kernels/portable/cpu/util:reduce_util",
-            "//executorch/runtime/kernel:thread_parallel_interface",
        ],
    ),
    op_target(
