Skip to content

Commit 7375cf5

Browse files
Fix _native_batch_norm_legit_no_stats_out
Differential Revision: D66104138. Pull Request resolved: #6929.
1 parent 0070680 commit 7375cf5

File tree

5 files changed

+280
-28
lines changed

5 files changed

+280
-28
lines changed

kernels/portable/cpu/op_native_batch_norm.cpp

+119-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <tuple>
1111

1212
#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
13+
#include <executorch/kernels/portable/cpu/vec_ops.h>
1314
#include <executorch/runtime/kernel/kernel_includes.h>
1415
#include <executorch/runtime/platform/assert.h>
1516

@@ -18,6 +19,7 @@ namespace executor {
1819
namespace native {
1920

2021
using Tensor = exec_aten::Tensor;
22+
using SizesType = exec_aten::SizesType;
2123

2224
std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
2325
KernelRuntimeContext& ctx,
@@ -184,27 +186,131 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_stats_out(
184186
Tensor& mean_out,
185187
Tensor& invstd_out) {
186188
(void)ctx;
187-
(void)in;
188-
(void)weight;
189-
(void)bias;
190-
(void)momentum;
191-
(void)eps;
189+
(void)training;
192190

193191
std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, invstd_out);
194192

195-
ET_KERNEL_CHECK_MSG(
193+
ET_KERNEL_CHECK(
196194
ctx,
197-
training == false,
195+
check_batch_norm_args(
196+
in,
197+
weight,
198+
bias,
199+
exec_aten::optional<Tensor>(),
200+
exec_aten::optional<Tensor>(),
201+
momentum,
202+
eps,
203+
out,
204+
mean_out,
205+
invstd_out),
198206
InvalidArgument,
199-
ret_val,
200-
"Portable kernels only support inference mode!");
207+
ret_val);
201208

202-
ET_KERNEL_CHECK_MSG(
209+
ET_KERNEL_CHECK(
203210
ctx,
204-
training == true,
211+
is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size()),
205212
InvalidArgument,
206-
ret_val,
207-
"running_mean & running_var must be provided during inference!");
213+
ret_val);
214+
215+
ET_KERNEL_CHECK(
216+
ctx,
217+
tensors_have_same_dim_order(in, out, mean_out, invstd_out),
218+
InvalidArgument,
219+
ret_val);
220+
221+
if (weight.has_value()) {
222+
ET_KERNEL_CHECK(
223+
ctx,
224+
tensors_have_same_dim_order(in, weight.value()),
225+
InvalidArgument,
226+
ret_val);
227+
}
228+
229+
if (bias.has_value()) {
230+
ET_KERNEL_CHECK(
231+
ctx,
232+
tensors_have_same_dim_order(in, bias.value()),
233+
InvalidArgument,
234+
ret_val);
235+
}
236+
237+
ET_KERNEL_CHECK(ctx, in.dim() >= 2, InvalidArgument, ret_val);
238+
239+
size_t N = in.size(0);
240+
size_t C = in.size(1);
241+
size_t inner = getTrailingDims(in, 1);
242+
size_t elements_per_channel = N * inner;
243+
244+
ET_KERNEL_CHECK(
245+
ctx,
246+
resize_tensor(out, in.sizes()) == Error::Ok,
247+
InvalidArgument,
248+
ret_val);
249+
250+
ET_KERNEL_CHECK(
251+
ctx,
252+
resize_tensor(mean_out, {static_cast<SizesType>(C)}) == Error::Ok,
253+
InvalidArgument,
254+
ret_val);
255+
256+
ET_KERNEL_CHECK(
257+
ctx,
258+
resize_tensor(invstd_out, {static_cast<SizesType>(C)}) == Error::Ok,
259+
InvalidArgument,
260+
ret_val);
261+
262+
constexpr auto name = "_native_batch_norm_legit.no_stats_out";
263+
264+
ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
265+
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
266+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
267+
CTYPE* mean_data = mean_out.mutable_data_ptr<CTYPE>();
268+
CTYPE* invstd_data = invstd_out.mutable_data_ptr<CTYPE>();
269+
270+
// Compute sum and sum of squares for each channel
271+
for (size_t b = 0; b < N; ++b) {
272+
const CTYPE* b_in_data = in_data + b * C * inner;
273+
for (size_t c = 0; c < C; ++c) {
274+
const CTYPE* x = b_in_data + c * inner;
275+
276+
CTYPE sum = reduce_add(x, inner);
277+
CTYPE sq_sum = vec_powerf(x, inner);
278+
279+
mean_data[c] += sum;
280+
invstd_data[c] += sq_sum;
281+
}
282+
}
283+
284+
// Compute mean and invstd for each channel
285+
for (size_t c = 0; c < C; ++c) {
286+
CTYPE mean = mean_data[c] / elements_per_channel;
287+
// Var[x] = E[x^2] - E[x]^2
288+
CTYPE var = invstd_data[c] / elements_per_channel - mean * mean;
289+
CTYPE invstd = 1.0 / std::sqrt(var + eps);
290+
mean_data[c] = mean;
291+
invstd_data[c] = invstd;
292+
}
293+
294+
for (size_t i = 0; i < N; ++i) {
295+
for (size_t c = 0; c < C; ++c) {
296+
CTYPE mean = mean_data[c];
297+
CTYPE invstd = invstd_data[c];
298+
CTYPE weight_val = 1;
299+
if (weight.has_value()) {
300+
weight_val = weight.value().const_data_ptr<CTYPE>()[c];
301+
}
302+
CTYPE bias_val = 0;
303+
if (bias.has_value()) {
304+
bias_val = bias.value().const_data_ptr<CTYPE>()[c];
305+
}
306+
for (size_t j = 0; j < inner; ++j) {
307+
*out_data = (*in_data - mean) * invstd * weight_val + bias_val;
308+
out_data++;
309+
in_data++;
310+
}
311+
}
312+
}
313+
});
208314

209315
return ret_val;
210316
}

kernels/portable/cpu/util/normalization_ops_util.cpp

+23-13
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,35 @@ bool check_batch_norm_args(
1919
const Tensor& in,
2020
const exec_aten::optional<Tensor>& weight,
2121
const exec_aten::optional<Tensor>& bias,
22-
const Tensor& running_mean,
23-
const Tensor& running_var,
22+
const exec_aten::optional<Tensor>& running_mean,
23+
const exec_aten::optional<Tensor>& running_var,
2424
double momentum,
2525
double eps,
2626
Tensor& out,
2727
Tensor& mean_out,
2828
Tensor& var_out) {
2929
// All tensors must be the same dtype
30-
ET_LOG_AND_RETURN_IF_FALSE(
31-
tensors_have_same_dtype(in, running_mean, running_var));
32-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
33-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
34-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out));
3530
if (weight.has_value()) {
3631
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
3732
}
3833
if (bias.has_value()) {
3934
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
4035
}
36+
if (running_mean.has_value()) {
37+
ET_LOG_AND_RETURN_IF_FALSE(
38+
tensors_have_same_dtype(in, running_mean.value()));
39+
}
40+
if (running_var.has_value()) {
41+
ET_LOG_AND_RETURN_IF_FALSE(
42+
tensors_have_same_dtype(in, running_var.value()));
43+
}
44+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
45+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
46+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out));
4147

4248
size_t C_dim = in.dim() >= 1 ? 1 : 0;
4349
// All parameter tensors must be of dim 1 and have length equal to the
4450
// channels dim of in
45-
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean, 1));
46-
ET_LOG_AND_RETURN_IF_FALSE(
47-
tensors_have_same_size_at_dims(running_mean, 0, in, C_dim));
48-
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var, 1));
49-
ET_LOG_AND_RETURN_IF_FALSE(
50-
tensors_have_same_size_at_dims(running_var, 0, in, C_dim));
5151
if (weight.has_value()) {
5252
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1));
5353
ET_LOG_AND_RETURN_IF_FALSE(
@@ -58,6 +58,16 @@ bool check_batch_norm_args(
5858
ET_LOG_AND_RETURN_IF_FALSE(
5959
tensors_have_same_size_at_dims(bias.value(), 0, in, C_dim));
6060
}
61+
if (running_mean.has_value()) {
62+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean.value(), 1));
63+
ET_LOG_AND_RETURN_IF_FALSE(
64+
tensors_have_same_size_at_dims(running_mean.value(), 0, in, C_dim));
65+
}
66+
if (running_var.has_value()) {
67+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var.value(), 1));
68+
ET_LOG_AND_RETURN_IF_FALSE(
69+
tensors_have_same_size_at_dims(running_var.value(), 0, in, C_dim));
70+
}
6171

6272
return true;
6373
}

kernels/portable/cpu/util/normalization_ops_util.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ bool check_batch_norm_args(
1717
const Tensor& in,
1818
const exec_aten::optional<Tensor>& weight,
1919
const exec_aten::optional<Tensor>& bias,
20-
const Tensor& running_mean,
21-
const Tensor& running_var,
20+
const exec_aten::optional<Tensor>& running_mean,
21+
const exec_aten::optional<Tensor>& running_var,
2222
double momentum,
2323
double eps,
2424
Tensor& out,

kernels/test/op_native_batch_norm_test.cpp

+135
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,33 @@ class OpNativeBatchNormLegitOutTest : public OperatorTest {
7878
}
7979
};
8080

81+
class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
82+
protected:
83+
::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
84+
op_native_batch_norm_legit_no_stats_out(
85+
const exec_aten::Tensor& input,
86+
const exec_aten::optional<exec_aten::Tensor>& weight,
87+
const exec_aten::optional<exec_aten::Tensor>& bias,
88+
bool training,
89+
double momentum,
90+
double eps,
91+
exec_aten::Tensor& out0,
92+
exec_aten::Tensor& out1,
93+
exec_aten::Tensor& out2) {
94+
return torch::executor::aten::_native_batch_norm_legit_outf(
95+
context_,
96+
input,
97+
weight,
98+
bias,
99+
training,
100+
momentum,
101+
eps,
102+
out0,
103+
out1,
104+
out2);
105+
}
106+
};
107+
81108
TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
82109
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
83110

@@ -949,3 +976,111 @@ TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
949976
EXPECT_TENSOR_CLOSE(out1, out1_expected);
950977
EXPECT_TENSOR_CLOSE(out2, out2_expected);
951978
}
979+
980+
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
981+
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
982+
983+
exec_aten::Tensor input =
984+
tfFloat.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
985+
exec_aten::optional<exec_aten::Tensor> weight =
986+
exec_aten::optional<exec_aten::Tensor>();
987+
exec_aten::optional<exec_aten::Tensor> bias =
988+
exec_aten::optional<exec_aten::Tensor>();
989+
bool training = true;
990+
double momentum = 1e-3;
991+
double eps = 1e-5;
992+
exec_aten::Tensor out0 = tfFloat.zeros({3, 4});
993+
exec_aten::Tensor out1 = tfFloat.zeros({4});
994+
exec_aten::Tensor out2 = tfFloat.zeros({4});
995+
exec_aten::Tensor out0_expected = tfFloat.make(
996+
{3, 4},
997+
{-0.98058063,
998+
-1.03422451,
999+
-1.06904495,
1000+
-1.09332705,
1001+
-0.39223224,
1002+
-0.31822300,
1003+
-0.26726127,
1004+
-0.23017406,
1005+
1.37281299,
1006+
1.35244739,
1007+
1.33630610,
1008+
1.32350123});
1009+
exec_aten::Tensor out1_expected =
1010+
tfFloat.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
1011+
exec_aten::Tensor out2_expected =
1012+
tfFloat.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
1013+
op_native_batch_norm_legit_no_stats_out(
1014+
input, weight, bias, training, momentum, eps, out0, out1, out2);
1015+
EXPECT_TENSOR_CLOSE(out0, out0_expected);
1016+
EXPECT_TENSOR_CLOSE(out1, out1_expected);
1017+
EXPECT_TENSOR_CLOSE(out2, out2_expected);
1018+
}
1019+
1020+
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {
1021+
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
1022+
1023+
exec_aten::Tensor input = tfFloat.make(
1024+
{2, 3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121,
1025+
144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529});
1026+
exec_aten::optional<exec_aten::Tensor> weight =
1027+
exec_aten::optional<exec_aten::Tensor>();
1028+
exec_aten::optional<exec_aten::Tensor> bias =
1029+
exec_aten::optional<exec_aten::Tensor>();
1030+
bool training = true;
1031+
double momentum = 1e-3;
1032+
double eps = 1e-5;
1033+
exec_aten::Tensor out0 = tfFloat.zeros({2, 3, 4});
1034+
exec_aten::Tensor out1 = tfFloat.zeros({3});
1035+
exec_aten::Tensor out2 = tfFloat.zeros({3});
1036+
exec_aten::Tensor out0_expected = tfFloat.make(
1037+
{2, 3, 4},
1038+
{-1.01045656, -0.99964952, -0.96722847, -0.91319335, -1.08850884,
1039+
-1.02468753, -0.94668359, -0.85449719, -1.12558389, -1.03595889,
1040+
-0.93578988, -0.82507670, 0.54575467, 0.81593025, 1.10771990,
1041+
1.42112350, 0.61339414, 0.84740579, 1.09560001, 1.35797679,
1042+
0.64582670, 0.86198103, 1.08867943, 1.32592189});
1043+
exec_aten::Tensor out1_expected = tfFloat.make({3}, {93.5, 169.5, 277.5});
1044+
exec_aten::Tensor out2_expected =
1045+
tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206});
1046+
op_native_batch_norm_legit_no_stats_out(
1047+
input, weight, bias, training, momentum, eps, out0, out1, out2);
1048+
EXPECT_TENSOR_CLOSE(out0, out0_expected);
1049+
EXPECT_TENSOR_CLOSE(out1, out1_expected);
1050+
EXPECT_TENSOR_CLOSE(out2, out2_expected);
1051+
}
1052+
1053+
TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest4D) {
1054+
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
1055+
1056+
exec_aten::Tensor input =
1057+
tfFloat.make({2, 3, 2, 2}, {0, 1, 4, 9, 16, 25, 36, 49,
1058+
64, 81, 100, 121, 144, 169, 196, 225,
1059+
256, 289, 324, 361, 400, 441, 484, 529});
1060+
exec_aten::optional<exec_aten::Tensor> weight =
1061+
exec_aten::optional<exec_aten::Tensor>(
1062+
tfFloat.make({3}, {1.1, 0.7, 0.3}));
1063+
exec_aten::optional<exec_aten::Tensor> bias =
1064+
exec_aten::optional<exec_aten::Tensor>(
1065+
tfFloat.make({3}, {1.7, 2.2, 3.3}));
1066+
bool training = true;
1067+
double momentum = 1e-3;
1068+
double eps = 1e-5;
1069+
exec_aten::Tensor out0 = tfFloat.zeros({2, 3, 2, 2});
1070+
exec_aten::Tensor out1 = tfFloat.zeros({3});
1071+
exec_aten::Tensor out2 = tfFloat.zeros({3});
1072+
exec_aten::Tensor out0_expected = tfFloat.make(
1073+
{2, 3, 2, 2},
1074+
{0.58849782, 0.60038555, 0.63604873, 0.69548732, 1.43804383, 1.48271883,
1075+
1.53732157, 1.60185206, 2.96232486, 2.98921227, 3.01926303, 3.05247688,
1076+
2.30033016, 2.59752321, 2.91849184, 3.26323581, 2.62937593, 2.79318404,
1077+
2.96691990, 3.15058374, 3.49374819, 3.55859423, 3.62660384, 3.69777656});
1078+
exec_aten::Tensor out1_expected = tfFloat.make({3}, {93.5, 169.5, 277.5});
1079+
exec_aten::Tensor out2_expected =
1080+
tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206});
1081+
op_native_batch_norm_legit_no_stats_out(
1082+
input, weight, bias, training, momentum, eps, out0, out1, out2);
1083+
EXPECT_TENSOR_CLOSE(out0, out0_expected);
1084+
EXPECT_TENSOR_CLOSE(out1, out1_expected);
1085+
EXPECT_TENSOR_CLOSE(out2, out2_expected);
1086+
}

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,7 @@ ATEN_OPS = (
867867
op_target(
868868
name = "op_native_batch_norm",
869869
deps = [
870+
":vec_ops",
870871
"//executorch/kernels/portable/cpu/util:normalization_ops_util",
871872
],
872873
),

0 commit comments

Comments
 (0)