Skip to content

Commit a318396

Browse files
swolchok and DannyYuyang-quic
authored and committed
Use parallel_for_each_reduce_over_dim_output_index for {map_,}reduce_over_dim ops (pytorch#9141)
1 parent 6915026 commit a318396

File tree

5 files changed

+101
-79
lines changed

5 files changed

+101
-79
lines changed

kernels/portable/cpu/op_any.cpp

+20-16
Original file line numberDiff line numberDiff line change
@@ -144,22 +144,26 @@ Tensor& any_out(
144144
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
145145
ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
146146
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
147-
for (const auto out_ix : c10::irange(out.numel())) {
148-
CTYPE_OUT any = false;
149-
if (in.numel() > 0) {
150-
std::tuple<CTYPE_OUT, long> acc =
151-
map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
152-
[](CTYPE_IN v) { return static_cast<bool>(v); },
153-
[](bool outv, long, bool acc, long) {
154-
return std::tuple<bool, long>{acc || outv, 0};
155-
},
156-
in,
157-
dim,
158-
out_ix);
159-
any = std::get<0>(acc);
160-
}
161-
out_data[out_ix] = any;
162-
}
147+
const bool success = parallel_for_each_reduce_over_dim_output_index(
148+
in, dim, out, [&](const auto begin, const auto end) {
149+
for (const auto out_ix : c10::irange(begin, end)) {
150+
CTYPE_OUT any = false;
151+
if (in.numel() > 0) {
152+
std::tuple<CTYPE_OUT, long> acc =
153+
map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
154+
[](CTYPE_IN v) { return static_cast<bool>(v); },
155+
[](bool outv, long, bool acc, long) {
156+
return std::tuple<bool, long>{acc || outv, 0};
157+
},
158+
in,
159+
dim,
160+
out_ix);
161+
any = std::get<0>(acc);
162+
}
163+
out_data[out_ix] = any;
164+
}
165+
});
166+
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
163167
});
164168
});
165169

kernels/portable/cpu/op_argmax.cpp

+21-17
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,27 @@ Tensor& argmax_out(
4747
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
4848
long* out_data = out.mutable_data_ptr<long>();
4949

50-
for (const auto out_ix : c10::irange(out.numel())) {
51-
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
52-
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
53-
// the below condition as written is equivalent to
54-
// !isnan(accval) && (isnan(v) || v > acc_val). See
55-
// argument in op_argmin.cpp.
56-
if (!std::isnan(acc_val) && !(v <= acc_val)) {
57-
acc_val = v;
58-
acc_ix = ix;
59-
}
60-
return std::tuple<CTYPE, long>{acc_val, acc_ix};
61-
},
62-
in,
63-
dim,
64-
out_ix);
65-
out_data[out_ix] = std::get<1>(acc);
66-
}
50+
const bool success = parallel_for_each_reduce_over_dim_output_index(
51+
in, dim, out, [&](const auto begin, const auto end) {
52+
for (const auto out_ix : c10::irange(begin, end)) {
53+
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
54+
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
55+
// the below condition as written is equivalent to
56+
// !isnan(accval) && (isnan(v) || v > acc_val). See
57+
// argument in op_argmin.cpp.
58+
if (!std::isnan(acc_val) && !(v <= acc_val)) {
59+
acc_val = v;
60+
acc_ix = ix;
61+
}
62+
return std::tuple<CTYPE, long>{acc_val, acc_ix};
63+
},
64+
in,
65+
dim,
66+
out_ix);
67+
out_data[out_ix] = std::get<1>(acc);
68+
}
69+
});
70+
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
6771
});
6872

6973
return out;

kernels/portable/cpu/op_max.cpp

+20-15
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,26 @@ std::tuple<Tensor&, Tensor&> max_out(
8383
CTYPE* max_data = max.mutable_data_ptr<CTYPE>();
8484
long* max_indices_data = max_indices.mutable_data_ptr<long>();
8585

86-
for (const auto out_ix : c10::irange(max.numel())) {
87-
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
88-
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
89-
if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) {
90-
acc_val = v;
91-
acc_ix = ix;
92-
}
93-
return std::tuple<CTYPE, long>{acc_val, acc_ix};
94-
},
95-
in,
96-
dim,
97-
out_ix);
98-
max_data[out_ix] = std::get<0>(acc);
99-
max_indices_data[out_ix] = std::get<1>(acc);
100-
}
86+
const bool success = parallel_for_each_reduce_over_dim_output_index(
87+
in, dim, max, [&](const auto begin, const auto end) {
88+
for (const auto out_ix : c10::irange(begin, end)) {
89+
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
90+
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
91+
if (!std::isnan(acc_val) &&
92+
(std::isnan(v) || v > acc_val)) {
93+
acc_val = v;
94+
acc_ix = ix;
95+
}
96+
return std::tuple<CTYPE, long>{acc_val, acc_ix};
97+
},
98+
in,
99+
dim,
100+
out_ix);
101+
max_data[out_ix] = std::get<0>(acc);
102+
max_indices_data[out_ix] = std::get<1>(acc);
103+
}
104+
});
105+
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
101106
});
102107

103108
return {max, max_indices};

kernels/portable/cpu/op_min.cpp

+20-15
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,26 @@ std::tuple<Tensor&, Tensor&> min_out(
8383
CTYPE* min_data = min.mutable_data_ptr<CTYPE>();
8484
long* min_indices_data = min_indices.mutable_data_ptr<long>();
8585

86-
for (const auto out_ix : c10::irange(min.numel())) {
87-
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
88-
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
89-
if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
90-
acc_val = v;
91-
acc_ix = ix;
92-
}
93-
return std::tuple<CTYPE, long>{acc_val, acc_ix};
94-
},
95-
in,
96-
dim,
97-
out_ix);
98-
min_data[out_ix] = std::get<0>(acc);
99-
min_indices_data[out_ix] = std::get<1>(acc);
100-
}
86+
const bool success = parallel_for_each_reduce_over_dim_output_index(
87+
in, dim, min, [&](const auto begin, const auto end) {
88+
for (const auto out_ix : c10::irange(begin, end)) {
89+
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
90+
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
91+
if (!std::isnan(acc_val) &&
92+
(std::isnan(v) || v < acc_val)) {
93+
acc_val = v;
94+
acc_ix = ix;
95+
}
96+
return std::tuple<CTYPE, long>{acc_val, acc_ix};
97+
},
98+
in,
99+
dim,
100+
out_ix);
101+
min_data[out_ix] = std::get<0>(acc);
102+
min_indices_data[out_ix] = std::get<1>(acc);
103+
}
104+
});
105+
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
101106
});
102107

103108
return {min, min_indices};

kernels/portable/cpu/op_prod.cpp

+20-16
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,26 @@ Tensor& prod_int_out(
7777
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
7878
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
7979
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
80-
for (const auto out_ix : c10::irange(out.numel())) {
81-
CTYPE_OUT prod = 1;
82-
if (in.numel() > 0) {
83-
std::tuple<CTYPE_OUT, long> acc =
84-
map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
85-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
86-
[](CTYPE_OUT outv, long, CTYPE_OUT acc, long) {
87-
return std::tuple<CTYPE_OUT, long>{acc * outv, 0};
88-
},
89-
in,
90-
dim,
91-
out_ix);
92-
prod = std::get<0>(acc);
93-
}
94-
out_data[out_ix] = prod;
95-
}
80+
const bool success = parallel_for_each_reduce_over_dim_output_index(
81+
in, dim, out, [&](const auto begin, const auto end) {
82+
for (const auto out_ix : c10::irange(begin, end)) {
83+
CTYPE_OUT prod = 1;
84+
if (in.numel() > 0) {
85+
std::tuple<CTYPE_OUT, long> acc =
86+
map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
87+
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
88+
[](CTYPE_OUT outv, long, CTYPE_OUT acc, long) {
89+
return std::tuple<CTYPE_OUT, long>{acc * outv, 0};
90+
},
91+
in,
92+
dim,
93+
out_ix);
94+
prod = std::get<0>(acc);
95+
}
96+
out_data[out_ix] = prod;
97+
}
98+
});
99+
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
96100
});
97101
});
98102

0 commit comments

Comments (0)