
Commit e0c9312

Dark Knight (facebook-github-bot) authored and committed
Revert D62466496: Multisect successfully blamed "D62466496: [ExecuTorch] Build optimized kernels with bf16 support and gate usage at runtime" for one build failure
Summary:
This diff reverts D62466496 ("[ExecuTorch] Build optimized kernels with bf16 support and gate usage at runtime", by swolchok), which causes the following build failure.

Tests affected:
- [playground_mwa_all_for_perftest](https://www.internalfb.com/intern/test/844425060509872/)

Here's the Multisect link: https://www.internalfb.com/multisect/10105407

Here are the tasks that are relevant to this breakage:
- T191385168: 100+ CI signals unhealthy for mwa_import_android

The backout may land if someone accepts it. If this diff has been generated in error, you can Commandeer and Abandon it.

Reviewed By: sheepsword

Differential Revision: D62678457

fbshipit-source-id: 8a06dc283aa0ecabb1c75114166fa7b7184df989
1 parent 768f5c9 · commit e0c9312

File tree

- kernels/optimized/blas/BlasKernel.cpp
- kernels/optimized/lib_defs.bzl
- kernels/test/op_linear_test.cpp
- shim/xplat/executorch/build/env_interface.bzl

4 files changed (+28, -55 lines)

kernels/optimized/blas/BlasKernel.cpp (+23, -38)

@@ -10,7 +10,6 @@
 
 #ifdef __aarch64__
 #include <arm_neon.h>
-#include <cpuinfo.h>
 #endif
 
 using torch::executor::BFloat16;
@@ -81,37 +80,32 @@ f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
 }
 #endif
 
-template <bool useBfloat16Dot>
 static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
     const BFloat16* vec1,
     const BFloat16* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
     int registerPairIndex) {
 #ifdef __ARM_FEATURE_BF16
-  if (useBfloat16Dot) {
-    const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-    const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-    sum[registerPairIndex] =
-        f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
-  } else {
-#endif
-    const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-    const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-
-    sum[2 * registerPairIndex] = f32_fma_bf16(
-        sum[2 * registerPairIndex],
-        vget_low_u16(temp_vec1),
-        vget_low_u16(temp_vec2));
-    sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-        sum[2 * registerPairIndex + 1],
-        vget_high_u16(temp_vec1),
-        vget_high_u16(temp_vec2));
-#ifdef __ARM_FEATURE_BF16
-  }
+  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  sum[registerPairIndex] =
+      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+#else
+  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+  sum[2 * registerPairIndex] = f32_fma_bf16(
+      sum[2 * registerPairIndex],
+      vget_low_u16(temp_vec1),
+      vget_low_u16(temp_vec2));
+  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+      sum[2 * registerPairIndex + 1],
+      vget_high_u16(temp_vec1),
+      vget_high_u16(temp_vec2));
 #endif
 }
 
@@ -127,7 +121,7 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
   *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
 }
 
-template <typename T, bool useBfloat16Dot>
+template <typename T>
 float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
@@ -136,8 +130,7 @@ float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
     const auto* vec2_ = vec2 + j;
     utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
         [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
-          dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
-              vec1_, vec2_, sum, k);
+          dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
         });
   }
   auto reducedSum = reduce(sum);
@@ -164,15 +157,7 @@ float bf16_dot_with_fp32_arith(
     const BFloat16* vec1,
     const BFloat16* vec2,
     int64_t len) {
-#ifdef __ARM_FEATURE_BF16
-  if (cpuinfo_has_arm_bf16()) {
-    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
-  } else {
-#endif
-    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
-#ifdef __ARM_FEATURE_BF16
-  }
-#endif
+  return dot_with_fp32_arith(vec1, vec2, len);
 }
 #endif
 } // namespace internal
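
Note: the net effect of this file's hunks is to drop runtime gating of the bf16 fast path. Before the revert, a binary compiled with __ARM_FEATURE_BF16 still asked cpuinfo whether the running CPU actually implements the BF16 extension before taking the vld1q_bf16/f32_dot_bf16 path; after it, the choice is made purely at compile time via the #ifdef/#else above. The standalone sketch below only illustrates that runtime-dispatch pattern; it is not the ExecuTorch code, kernel_bf16 and kernel_fallback are hypothetical stand-ins, and it assumes the cpuinfo library is available to link against.

// Sketch of the runtime-dispatch pattern removed above (hypothetical stand-ins,
// cpuinfo assumed to be linked in).
#include <cstdio>

#ifdef __aarch64__
#include <cpuinfo.h>
#endif

float kernel_bf16() { return 1.0f; }      // stand-in for the vld1q_bf16 / f32_dot_bf16 path
float kernel_fallback() { return 0.0f; }  // stand-in for the uint16 load + f32_fma_bf16 path

float dispatch() {
#if defined(__aarch64__) && defined(__ARM_FEATURE_BF16)
  // Compile-time support alone is not enough in this pattern: the CPU the
  // binary runs on must also report the BF16 extension, hence the cpuinfo query.
  if (cpuinfo_initialize() && cpuinfo_has_arm_bf16()) {
    return kernel_bf16();
  }
#endif
  return kernel_fallback();
}

int main() {
  std::printf("selected kernel returned %.1f\n", dispatch());
  return 0;
}

The trade-off is visible in the rest of this commit: runtime dispatch keeps one binary correct on both bf16 and non-bf16 CPUs, but it brings along the cpuinfo dependency and the +bf16 compile flags that this revert backs out.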

kernels/optimized/lib_defs.bzl (-11)

@@ -129,14 +129,6 @@ def define_libs():
             ] if not runtime.is_oss else [],
             "DEFAULT": [],
         }),
-        fbandroid_platform_compiler_flags = [
-            (
-                "^android-arm64.*$",
-                [
-                    "-march=armv8+bf16",
-                ],
-            ),
-        ],
         fbandroid_platform_preprocessor_flags = [
             (
                 "^android-arm64.*$",
@@ -153,9 +145,6 @@
                 ],
            ),
        ],
-        fbobjc_compiler_flags = [
-            "-march=armv8+bf16",
-        ],
        fbobjc_exported_preprocessor_flags = [
            "-DET_BUILD_WITH_BLAS",
            "-DET_BUILD_FOR_APPLE",

kernels/test/op_linear_test.cpp (+4, -4)

@@ -43,16 +43,16 @@ class OpLinearOutTest : public OperatorTest {
       }
     }
 
-    // matmul gives 32 * 2 * 3 = 192
-    Tensor x = tf.full({3, 32}, 2);
-    Tensor y = tf.full({5, 32}, 3);
+    // matmul gives 4 * 2 * 3 = 24
+    Tensor x = tf.full({3, 4}, 2);
+    Tensor y = tf.full({5, 4}, 3);
 
     // Output shape should be (3, 5)
     Tensor out = tf.zeros({3, 5});
 
    op_linear_out(x, y, out);
 
-    Tensor expected = tf.full({3, 5}, 192);
+    Tensor expected = tf.full({3, 5}, 24);
 
    EXPECT_TENSOR_EQ(out, expected);
  }
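
Note: the updated expectation follows directly from the shapes and fill values: linear multiplies x by the transpose of the weight y, every element of x is 2, every element of y is 3, and the shared dimension shrinks from 32 to 4, so each of the 3x5 outputs becomes 4 * 2 * 3 = 24 instead of 32 * 2 * 3 = 192. A standalone check of that arithmetic with plain loops (not the ExecuTorch tensor API):

// x is 3x4 filled with 2, y is 5x4 filled with 3; linear computes x @ y^T,
// so every output element is the sum over k of 2 * 3, i.e. 4 * 2 * 3 = 24.
#include <cassert>

int main() {
  const int m = 3, n = 5, k = 4;
  float x[3][4], y[5][4], out[3][5];
  for (auto& row : x) for (auto& v : row) v = 2.0f;
  for (auto& row : y) for (auto& v : row) v = 3.0f;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int kk = 0; kk < k; ++kk) {
        acc += x[i][kk] * y[j][kk];  // row of x dotted with row of y (y transposed)
      }
      out[i][j] = acc;
    }
  }
  assert(out[0][0] == 24.0f && out[2][4] == 24.0f);
  return 0;
}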

shim/xplat/executorch/build/env_interface.bzl (+1, -2)

@@ -118,8 +118,7 @@ def _remove_platform_specific_args(kwargs):
     """
     keys = []
     for key in kwargs:
-        if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or
-                key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")):
+        if key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or key.startswith("fbobjc"):
             keys.append(key)
     for key in keys:
         kwargs.pop(key)
