Skip to content

Commit 77a4fc6

Browse files
committed
Update
[ghstack-poisoned]
1 parent 85451ea commit 77a4fc6

21 files changed: +201 −118 lines

Diff for: kernels/portable/cpu/op_add.cpp

+12-8
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,19 @@ Tensor& add_out(
5252

5353
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
5454
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56-
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
55+
utils::apply_bitensor_elementwise_fn<
56+
CTYPE_COMPUTE,
57+
op_name,
58+
utils::SupportedTensorDtypes::REALHBBF16>(
59+
[val_alpha](const auto val_a, const auto val_b) {
5760
return val_a + val_alpha * val_b;
5861
},
5962
ctx,
6063
a,
6164
utils::SupportedTensorDtypes::REALHBBF16,
6265
b,
6366
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
67+
out);
6668
});
6769

6870
return out;
@@ -100,17 +102,19 @@ Tensor& add_scalar_out(
100102
static constexpr const char op_name[] = "add.Scalar_out";
101103

102104
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
103-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
104-
[b, alpha](const CTYPE_COMPUTE val_a) {
105+
utils::apply_unitensor_elementwise_fn<
106+
CTYPE_COMPUTE,
107+
op_name,
108+
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
109+
[b, alpha](const auto val_a) {
105110
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106111
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
107112
return val_a + val_alpha * val_b;
108113
},
109114
ctx,
110115
a,
111116
utils::SupportedTensorDtypes::REALHBBF16,
112-
out,
113-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
117+
out);
114118
});
115119

116120
return out;

Diff for: kernels/portable/cpu/op_addmm.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,19 @@ Tensor& addmm_out(
8888
n,
8989
p);
9090

91-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
92-
[alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
91+
utils::apply_bitensor_elementwise_fn<
92+
CTYPE,
93+
op_name,
94+
utils::SupportedTensorDtypes::REALHBF16>(
95+
[alpha_val, beta_val](const auto val_a, const auto val_b) {
9396
return val_a * alpha_val + val_b * beta_val;
9497
},
9598
ctx,
9699
out,
97100
utils::SupportedTensorDtypes::REALHBF16,
98101
in,
99102
utils::SupportedTensorDtypes::REALHBF16,
100-
out,
101-
utils::SupportedTensorDtypes::REALHBF16);
103+
out);
102104
}
103105
});
104106

Diff for: kernels/portable/cpu/op_atan2.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,19 @@ Tensor& atan2_out(
5555
static constexpr const char op_name[] = "atan2.out";
5656

5757
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
59-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
58+
utils::apply_bitensor_elementwise_fn<
59+
CTYPE_COMPUTE,
60+
op_name,
61+
utils::SupportedTensorDtypes::FLOATHBF16>(
62+
[](const auto val_a, const auto val_b) {
6063
return std::atan2(val_a, val_b);
6164
},
6265
ctx,
6366
a,
6467
utils::SupportedTensorDtypes::REALHBBF16,
6568
b,
6669
utils::SupportedTensorDtypes::REALHBBF16,
67-
out,
68-
utils::SupportedTensorDtypes::FLOATHBF16);
70+
out);
6971
});
7072

7173
return out;

Diff for: kernels/portable/cpu/op_clamp.cpp

+12-6
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,12 @@ Tensor& clamp_out(
134134
static constexpr const char op_name[] = "clamp.out";
135135

136136
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
137-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
137+
utils::apply_unitensor_elementwise_fn<
138+
CTYPE_COMPUTE,
139+
op_name,
140+
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
138141
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
142+
// TODO: rewrite this to be vectorization-capable.
139143
CTYPE_COMPUTE val_out = val_in;
140144
if (has_min) {
141145
val_out = utils::max_override(
@@ -150,8 +154,7 @@ Tensor& clamp_out(
150154
ctx,
151155
in,
152156
utils::SupportedTensorDtypes::REALHBBF16,
153-
out,
154-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
157+
out);
155158
});
156159

157160
return out;
@@ -210,11 +213,15 @@ Tensor& clamp_tensor_out(
210213
static constexpr const char op_name[] = "clamp.Tensor_out";
211214

212215
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
213-
utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
216+
utils::apply_tritensor_elementwise_fn<
217+
CTYPE_COMPUTE,
218+
op_name,
219+
utils::SupportedTensorDtypes::REALHBBF16>(
214220
[has_min, has_max](
215221
const CTYPE_COMPUTE val_in,
216222
const CTYPE_COMPUTE val_min,
217223
const CTYPE_COMPUTE val_max) {
224+
// TODO: rewrite this to be vectorization-capable.
218225
CTYPE_COMPUTE val_out = val_in;
219226
if (has_min) {
220227
val_out = utils::max_override(val_out, val_min);
@@ -231,8 +238,7 @@ Tensor& clamp_tensor_out(
231238
utils::SupportedTensorDtypes::REALHBBF16,
232239
max,
233240
utils::SupportedTensorDtypes::REALHBBF16,
234-
out,
235-
utils::SupportedTensorDtypes::REALHBBF16);
241+
out);
236242
});
237243

238244
return out;

Diff for: kernels/portable/cpu/op_copy.cpp

+12-8
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,17 @@ Tensor& copy_out(
4747
static constexpr const char op_name[] = "copy.out";
4848

4949
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
50+
utils::apply_bitensor_elementwise_fn<
51+
CTYPE,
52+
op_name,
53+
utils::SupportedTensorDtypes::REALHBBF16>(
54+
[](ET_UNUSED const auto _, const auto val_src) { return val_src; },
5255
ctx,
5356
in,
5457
utils::SupportedTensorDtypes::REALHBBF16,
5558
src,
5659
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
60+
out);
5961
});
6062

6163
return out;
@@ -80,15 +82,17 @@ Tensor& copy_(
8082
static constexpr const char op_name[] = "copy_";
8183

8284
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85+
utils::apply_bitensor_elementwise_fn<
86+
CTYPE,
87+
op_name,
88+
utils::SupportedTensorDtypes::REALHBBF16>(
89+
[](ET_UNUSED const auto _, const auto val_src) { return val_src; },
8590
ctx,
8691
in,
8792
utils::SupportedTensorDtypes::REALHBBF16,
8893
src,
8994
utils::SupportedTensorDtypes::REALHBBF16,
90-
in,
91-
utils::SupportedTensorDtypes::REALHBBF16);
95+
in);
9296
});
9397

9498
return in;

Diff for: kernels/portable/cpu/op_div.cpp

+18-13
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ Tensor& div_out(
5858
static constexpr const char op_name[] = "div.out";
5959

6060
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
61-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
62-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
63-
return val_a / val_b;
64-
},
61+
utils::apply_bitensor_elementwise_fn<
62+
CTYPE_COMPUTE,
63+
op_name,
64+
utils::SupportedTensorDtypes::FLOATHBF16>(
65+
[](const auto val_a, const auto val_b) { return val_a / val_b; },
6566
ctx,
6667
a,
6768
utils::SupportedTensorDtypes::REALHBBF16,
6869
b,
6970
utils::SupportedTensorDtypes::REALHBBF16,
70-
out,
71-
utils::SupportedTensorDtypes::FLOATHBF16);
71+
out);
7272
});
7373

7474
return out;
@@ -122,9 +122,13 @@ Tensor& div_out_mode(
122122
bool div_by_zero_error = false;
123123

124124
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
125-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
125+
utils::apply_bitensor_elementwise_fn<
126+
CTYPE_COMPUTE,
127+
op_name,
128+
utils::SupportedTensorDtypes::REALHBF16>(
126129
[mode_is_trunc, &div_by_zero_error](
127130
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
131+
// TODO: rewrite this to be vectorization-capable.
128132
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
129133
if (val_b == 0) {
130134
div_by_zero_error = true;
@@ -146,8 +150,7 @@ Tensor& div_out_mode(
146150
utils::SupportedTensorDtypes::REALHBBF16,
147151
b,
148152
utils::SupportedTensorDtypes::REALHBBF16,
149-
out,
150-
utils::SupportedTensorDtypes::REALHBF16);
153+
out);
151154
});
152155

153156
ET_KERNEL_CHECK_MSG(
@@ -188,13 +191,15 @@ Tensor& div_scalar_out(
188191

189192
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
190193
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
191-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
192-
[val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
194+
utils::apply_unitensor_elementwise_fn<
195+
CTYPE_COMPUTE,
196+
op_name,
197+
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
198+
[val_b](const auto val_a) { return val_a / val_b; },
193199
ctx,
194200
a,
195201
utils::SupportedTensorDtypes::REALHBBF16,
196-
out,
197-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
202+
out);
198203
});
199204

200205
return out;

Diff for: kernels/portable/cpu/op_elu.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,20 @@ Tensor& elu_out(
4444
ET_EXTRACT_SCALAR(scale, math_scale);
4545
ET_EXTRACT_SCALAR(input_scale, math_input_scale);
4646
const auto negcoef = math_alpha * math_scale;
47-
utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
48-
[negcoef, math_scale, math_input_scale](auto x) {
47+
utils::apply_unitensor_elementwise_fn<
48+
CTYPE,
49+
op_name,
50+
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
51+
[negcoef, math_scale, math_input_scale](const auto x) {
52+
// TODO: rewrite this to be vectorization-capable.
4953
return MathT(x) <= MathT(0)
5054
? std::expm1(MathT(x) * math_input_scale) * negcoef
5155
: MathT(x) * math_scale;
5256
},
5357
ctx,
5458
in,
5559
utils::SupportedTensorDtypes::FLOATHBF16,
56-
out,
57-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
60+
out);
5861
});
5962
return out;
6063
}

Diff for: kernels/portable/cpu/op_floor_divide.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,13 @@ Tensor& floor_divide_out(
5353
bool div_by_zero_error = false;
5454

5555
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56+
utils::apply_bitensor_elementwise_fn<
57+
CTYPE_COMPUTE,
58+
op_name,
59+
utils::SupportedTensorDtypes::REALHBF16>(
5760
[&div_by_zero_error](
5861
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
62+
// TODO: rewrite this to be vectorization-capable.
5963
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
6064
if (val_b == 0) {
6165
div_by_zero_error = true;
@@ -69,8 +73,7 @@ Tensor& floor_divide_out(
6973
utils::SupportedTensorDtypes::REALHBBF16,
7074
b,
7175
utils::SupportedTensorDtypes::REALHBBF16,
72-
out,
73-
utils::SupportedTensorDtypes::REALHBF16);
76+
out);
7477
});
7578

7679
ET_KERNEL_CHECK_MSG(

Diff for: kernels/portable/cpu/op_fmod.cpp

+12-6
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,13 @@ Tensor& fmod_Tensor_out(
5555
bool div_by_zero_error = false;
5656

5757
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
58+
utils::apply_bitensor_elementwise_fn<
59+
CTYPE_COMPUTE,
60+
op_name,
61+
utils::SupportedTensorDtypes::REALHBF16>(
5962
[&div_by_zero_error](
6063
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
64+
// TODO: rewrite this to be vectorization-capable.
6165
CTYPE_COMPUTE value = 0;
6266
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
6367
if (val_b == 0) {
@@ -73,8 +77,7 @@ Tensor& fmod_Tensor_out(
7377
utils::SupportedTensorDtypes::REALHBBF16,
7478
b,
7579
utils::SupportedTensorDtypes::REALHBBF16,
76-
out,
77-
utils::SupportedTensorDtypes::REALHBF16);
80+
out);
7881
});
7982

8083
ET_KERNEL_CHECK_MSG(
@@ -131,16 +134,19 @@ Tensor& fmod_Scalar_out(
131134

132135
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
133136
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
134-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
137+
utils::apply_unitensor_elementwise_fn<
138+
CTYPE_COMPUTE,
139+
op_name,
140+
utils::SupportedTensorDtypes::REALHBF16>(
135141
[val_b](const CTYPE_COMPUTE val_a) {
142+
// TODO: rewrite this to be vectorization-capable.
136143
CTYPE_COMPUTE value = std::fmod(val_a, val_b);
137144
return value;
138145
},
139146
ctx,
140147
a,
141148
utils::SupportedTensorDtypes::REALHBBF16,
142-
out,
143-
utils::SupportedTensorDtypes::REALHBF16);
149+
out);
144150
});
145151

146152
return out;

Diff for: kernels/portable/cpu/op_maximum.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ Tensor& maximum_out(
4545
static constexpr const char op_name[] = "maximum.out";
4646

4747
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
48+
utils::apply_bitensor_elementwise_fn<
49+
CTYPE_COMPUTE,
50+
op_name,
51+
utils::SupportedTensorDtypes::REALHBBF16>(
4952
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
5053
return utils::max_override(val_a, val_b);
5154
},
@@ -54,8 +57,7 @@ Tensor& maximum_out(
5457
utils::SupportedTensorDtypes::REALHBBF16,
5558
b,
5659
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
60+
out);
5961
});
6062

6163
return out;

0 commit comments

Comments (0)