
Revert "Revert "Extract parallel_for_each_reduce_over_dim_output_index from argmin parallelization PoC (#9139)"" #9274


Merged: 5 commits, Mar 14, 2025
kernels/portable/cpu/op_argmin.cpp: 2 additions, 12 deletions
@@ -12,7 +12,6 @@
 
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
@@ -48,17 +47,8 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    // REVIEW: this is the parallelization strategy ATen uses
-    // specifically when the reduction is along the last dimension and
-    // that dimension is contiguous. Is there any particular reason we
-    // shouldn't just always use this strategy since we aren't
-    // otherwise capable of parallelizing reductions?
-    const int64_t reduction_size = get_reduced_dim_product(in, dim);
-    const auto grain_size = std::max(
-        static_cast<int64_t>(1),
-        executorch::extension::internal::GRAIN_SIZE / reduction_size);
-    const bool success = executorch::extension::parallel_for(
-        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, out, [&](const auto begin, const auto end) {
           for (const auto out_ix : c10::irange(begin, end)) {
             std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
                 [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
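The diff above is truncated, so for context here is a minimal sketch of the complete call site in op_argmin.cpp after this change. The reduce lambda is reconstructed for illustration only, using a plain comparison (the actual kernel also handles NaN propagation), so details may differ from the committed file:

    const bool success = parallel_for_each_reduce_over_dim_output_index(
        in, dim, out, [&](const auto begin, const auto end) {
          for (const auto out_ix : c10::irange(begin, end)) {
            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
                  // Keep whichever of (v, ix) and (acc_val, acc_ix) has the
                  // smaller value; NaN handling is omitted in this sketch.
                  return v < acc_val ? std::tuple<CTYPE, long>{v, ix}
                                     : std::tuple<CTYPE, long>{acc_val, acc_ix};
                },
                in,
                dim,
                out_ix);
            // The reduced tuple carries (value, index); argmin stores the index.
            out_data[out_ix] = std::get<1>(acc);
          }
        });
    // `success` is then checked with the kernel's usual error-handling macro.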
kernels/portable/cpu/util/reduce_util.h: 19 additions, 0 deletions
@@ -10,6 +10,7 @@
 
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <cstring>
 #include <tuple>
 
@@ -811,5 +812,23 @@ bool check_prod_out_args(
 
 #endif
 
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim or
+ * map_reduce_over_dim for each output element. Automatically
+ * calculates appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_output_index(
+    const Tensor& in,
+    executorch::aten::optional<int64_t> dim,
+    const Tensor& out,
+    const Func& func) {
+  const int64_t reduction_size = get_reduced_dim_product(in, dim);
+  const auto grain_size = std::max(
+      static_cast<int64_t>(1),
+      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
 } // namespace executor
 } // namespace torch
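As a standalone illustration of the grain-size heuristic in the helper above: each parallel task is sized to cover roughly GRAIN_SIZE input elements, so the larger the reduction behind each output element, the fewer output indices a task receives, clamped to at least one. The kGrainSize constant below is a placeholder chosen for the example, not executorch's actual executorch::extension::internal::GRAIN_SIZE:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Placeholder work target per task; the real constant is
      // executorch::extension::internal::GRAIN_SIZE from
      // thread_parallel_interface.h.
      constexpr int64_t kGrainSize = 32768;
      for (const int64_t reduction_size : {8, 1024, 100000}) {
        // Same arithmetic as the helper: at least one output index per task.
        const int64_t grain_size =
            std::max(static_cast<int64_t>(1), kGrainSize / reduction_size);
        std::cout << "reduction_size=" << reduction_size
                  << " -> output indices per task=" << grain_size << "\n";
      }
      return 0;
    }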
kernels/portable/cpu/util/targets.bzl: 3 additions, 0 deletions
@@ -314,6 +314,9 @@ def define_common_targets():
             "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
             "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
         ],
+        exported_deps = [
+            "//executorch/runtime/kernel:thread_parallel_interface",
+        ],
         exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
         visibility = [
             "//executorch/extension/llm/custom_ops/...",
(additional build file; name not captured in this view): 0 additions, 1 deletion
@@ -284,7 +284,6 @@ ATEN_OPS = (
         name = "op_argmin",
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",
-            "//executorch/runtime/kernel:thread_parallel_interface",
         ],
     ),
     op_target(