Skip to content

Commit 3c7edf1

Browse files
leslie-fang-intel and pytorchmergebot
authored and committed
[Inductor][CPP] Fix int8 cvt half (pytorch#136353)
Fix the correctness issue reported in pytorch/ao#884. The current implementation for converting between `Half/BFloat16` and `int8/uint8` incorrectly assumes that 1/4 of the int8/uint8 vector lane maps to 1/2 of the Half/BFloat16 vector lane. This assumption has led to accuracy issues since the full bit-width vectorization of the Half data type was introduced. When converting between int8 weights and the half data type, the generated code is as follows: ``` #include "/tmp/torchinductor_leslie/xw/cxww3s7wxrujoyxna7mlcjktid2uu6nntixqwm542xfkd756gl3x.h" extern "C" void kernel(const int8_t* in_ptr0, half* out_ptr0) { { for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2048L); x0+=static_cast<int64_t>(32L)) { auto tmp0 = at::vec::Vectorized<int8_t>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32)); auto tmp1 = at::vec::convert<half>(tmp0); tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32)); } } } ``` In this PR, we address the issue by changing the implementation to convert 1/2 of the int8/uint8 vector lane into a full vector lane of Half/BFloat16. **Test Plan** * AO: `python test/integration/test_integration.py -k test_int8_weight_only_quant_subclass_api` * `python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_convert_int8_to_half_vec` * Due to the CPP backend legalization pass, we are unable to create a unit test to simulate the conversion from `Half` to `int8`. Instead, we rely on a C++ test case. * `./build/bin/vec_test_all_types_AVX512 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"` * `./build/bin/vec_test_all_types_AVX2 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"` Pull Request resolved: pytorch#136353 Approved by: https://github.com/jgong5, https://github.com/jerryzh168
1 parent 8225e77 commit 3c7edf1

File tree

4 files changed

+142
-4
lines changed

4 files changed

+142
-4
lines changed

aten/src/ATen/cpu/vec/vec256/vec256_convert.h

+40-2
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,27 @@ struct VecConvert<
208208
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
209209
void>> {
210210
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
211-
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
212-
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
211+
VectorizedN<float, 2> tmp_fp32 = VecConvert<float, 2, src_t, 1>::apply(src);
212+
return VecConvert<dst_t, 1, float, 2>::apply(tmp_fp32);
213+
}
214+
};
215+
216+
// 256-bit specialization: narrows two float vectors (VectorizedN<float, 2>)
// into a single 8-bit integer vector (VectorizedN<dst_t, 1>).
// Selected only when dst_t satisfies is_8bit_integer_v.
template <typename dst_t>
217+
struct VecConvert<
218+
dst_t,
219+
1,
220+
float,
221+
2,
222+
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
223+
void>> {
224+
// Converts src[0] and src[1] separately, then merges both results into one
// register so that the bytes from src[1] directly follow those from src[0].
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
225+
// NOTE(review): the merge below presumes convert_float_to_int8 leaves its
// converted bytes in the low 64 bits of the returned register - confirm
// against that helper.
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
226+
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
227+
// lane2 = low 128 bits of vec2 (the float casts only reinterpret the bits).
__m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2));
228+
// combined = vec1 with its upper 128-bit lane replaced by lane2, so the low
// bytes of vec2 now sit at bits [191:128].
__m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1);
229+
// Move bits [191:128] of combined into bits [127:64] of result. The control
// 0b11011000 selects the 64-bit elements in the order 0, 2, 1, 3, i.e. it
// swaps the two middle quadwords.
230+
__m256i result = _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000);
231+
return at::vec::Vectorized<dst_t>(result);
213232
}
214233
};
215234

@@ -226,6 +245,25 @@ struct VecConvert<
226245
}
227246
};
228247

248+
// 256-bit specialization: widens a single 8-bit integer vector
// (VectorizedN<src_t, 1>) into two float vectors (VectorizedN<float, 2>).
// Selected only when src_t satisfies is_8bit_integer_v.
template <typename src_t>
249+
struct VecConvert<
250+
float,
251+
2,
252+
src_t,
253+
1,
254+
typename std::enable_if_t<is_8bit_integer_v<src_t>,
255+
void>> {
256+
// Returns {convert(src[0]), convert(src2)}, where src2 carries bits [127:64]
// of src[0] in its low 64 bits.
static inline VectorizedN<float, 2> apply(const VectorizedN<src_t, 1>& src) {
257+
// Move bits [127:64] of src[0] into bits [191:128] of shuffled. The control
// 0b11011000 selects the 64-bit elements in the order 0, 2, 1, 3.
258+
__m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000);
259+
// Take the upper 128-bit lane of shuffled and place it in the low 128 bits of
// a 256-bit register; the upper 128 bits produced by _mm256_castsi128_si256
// are undefined.
// NOTE(review): presumably convert_int8_to_float reads only the low 64 bits
// of its argument, so the undefined part is never used - confirm.
__m256i src2 = _mm256_castsi128_si256(
260+
_mm_castps_si128(
261+
_mm256_extractf128_ps(_mm256_castsi256_ps(shuffled), 1) // Extract the upper (second) 128-bit lane
262+
)
263+
);
264+
return VectorizedN<float, 2>(convert_int8_to_float<src_t>(src[0]), convert_int8_to_float<src_t>(src2));
265+
}
266+
};
229267

230268
template <typename dst_t>
231269
struct VecConvert<

aten/src/ATen/cpu/vec/vec512/vec512_convert.h

+37-2
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,25 @@ struct VecConvert<
209209
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
210210
void>> {
211211
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
212-
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
213-
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
212+
VectorizedN<float, 2> tmp_fp32 = VecConvert<float, 2, src_t, 1>::apply(src);
213+
return VecConvert<dst_t, 1, float, 2>::apply(tmp_fp32);
214+
}
215+
};
216+
217+
// 512-bit specialization: narrows two float vectors (VectorizedN<float, 2>)
// into a single 8-bit integer vector (VectorizedN<dst_t, 1>).
// Selected only when dst_t satisfies is_8bit_integer_v.
template <typename dst_t>
218+
struct VecConvert<
219+
dst_t,
220+
1,
221+
float,
222+
2,
223+
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
224+
void>> {
225+
// Converts src[0] and src[1] separately, then places the low 128 bits of the
// second result into bits [255:128] of the first.
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
226+
// NOTE(review): the merge below presumes convert_float_to_int8 leaves its
// converted bytes in the low 128 bits of the returned register - confirm
// against that helper.
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
227+
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
228+
// lane2 = low 128 bits of vec2 (the float casts only reinterpret the bits).
__m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2));
229+
__m512 result = _mm512_insertf32x4(_mm512_castsi512_ps(vec1), lane2, 1); // Insert lane2 into the second 128-bit lane (bits [255:128])
230+
return at::vec::Vectorized<dst_t>(_mm512_castps_si512(result));
214231
}
215232
};
216233

@@ -227,6 +244,24 @@ struct VecConvert<
227244
}
228245
};
229246

247+
// 512-bit specialization: widens a single 8-bit integer vector
// (VectorizedN<src_t, 1>) into two float vectors (VectorizedN<float, 2>).
// Selected only when src_t satisfies is_8bit_integer_v.
template <typename src_t>
248+
struct VecConvert<
249+
float,
250+
2,
251+
src_t,
252+
1,
253+
typename std::enable_if_t<is_8bit_integer_v<src_t>,
254+
void>> {
255+
// Returns {convert(src[0]), convert(src2)}, where src2 carries bits [255:128]
// of src[0] in its low 128 bits.
static inline VectorizedN<float, 2> apply(const VectorizedN<src_t, 1>& src) {
256+
// The upper bits produced by _mm512_castsi128_si512 are undefined.
// NOTE(review): presumably convert_int8_to_float reads only the low 128 bits
// of its argument, so the undefined part is never used - confirm.
__m512i src2 = _mm512_castsi128_si512(
257+
_mm_castps_si128(
258+
_mm512_extractf32x4_ps(_mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane (bits [255:128])
259+
)
260+
);
261+
return VectorizedN<float, 2>(convert_int8_to_float<src_t>(src[0]), convert_int8_to_float<src_t>(src2));
262+
}
263+
};
264+
230265
template <typename src_t>
231266
struct VecConvert<
232267
float,

aten/src/ATen/test/vec_test_all_types.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ namespace {
7171
template <typename T>
7272
class VecConvertTests : public ::testing::Test {};
7373
template <typename T>
74+
class VecConvertTestsReducedFloat : public ::testing::Test {};
75+
template <typename T>
7476
class VecMaskTests : public ::testing::Test {};
7577
using RealFloatTestedTypes = ::testing::Types<vfloat, vdouble>;
7678
using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
@@ -121,6 +123,7 @@ namespace {
121123
TYPED_TEST_SUITE(FunctionalTests, RealFloatIntTestedTypes);
122124
TYPED_TEST_SUITE(FunctionalTestsReducedFloat, ReducedFloatTestedTypes);
123125
TYPED_TEST_SUITE(VecConvertTests, RealFloatIntTestedTypes);
126+
TYPED_TEST_SUITE(VecConvertTestsReducedFloat, ReducedFloatTestedTypes);
124127
TYPED_TEST_SUITE(VecMaskTests, RealFloatIntTestedTypes);
125128
TYPED_TEST(Memory, UnAlignedLoadStore) {
126129
using vec = TypeParam;
@@ -1663,6 +1666,48 @@ namespace {
16631666
TEST_CONVERT_TO(double);
16641667
#undef TEST_CONVERT_TO
16651668
}
1669+
// Checks that at::vec::convert from the suite's vector type to int8_t and
// uint8_t matches an element-wise static_cast of the same inputs.
TYPED_TEST(VecConvertTestsReducedFloat, ConvertReduced) {
1670+
using vec = TypeParam;
1671+
using src_t = UholdType<TypeParam>;
1672+
constexpr auto N = vec::size();
1673+
// TEST_CONVERT_TO(dst_t):
//   1. Fills x[N] from ValueGen<src_t> constructed with (low, 100, seed),
//      where low is -100 for a signed dst_t and 0 otherwise.
//   2. Builds ref[N] with static_cast<dst_t>(x[i]).
//   3. Converts with at::vec::convert<dst_t>(x_vec), stores
//      num_dst_elements = min(N, Vectorized<dst_t>::size()) elements and
//      compares them against ref.
//   4. Converts again with the explicit VectorizedN form, using
//      dst_n = N / num_dst_elements destination vectors, stores N elements
//      and compares all of them against ref.
// Each failure message prints the seed so the run can be reproduced.
// No comments are placed inside the macro body because every line there ends
// in a line continuation.
#define TEST_CONVERT_TO(dst_t) \
1674+
do { \
1675+
CACHE_ALIGN src_t x[N]; \
1676+
CACHE_ALIGN dst_t y[N]; \
1677+
CACHE_ALIGN dst_t ref[N]; \
1678+
auto seed = TestSeed(); \
1679+
auto low = std::is_signed_v<dst_t> ? src_t(-100.0) : src_t(0); \
1680+
ValueGen<src_t> generator(low, src_t(100), seed); \
1681+
for (const auto i : c10::irange(N)) { \
1682+
x[i] = generator.get(); \
1683+
} \
1684+
for (const auto i : c10::irange(N)) { \
1685+
ref[i] = static_cast<dst_t>(x[i]); \
1686+
} \
1687+
auto x_vec = vec::loadu(x); \
1688+
auto y_vec = at::vec::convert<dst_t>(x_vec); \
1689+
constexpr int num_dst_elements = \
1690+
std::min(N, at::vec::Vectorized<dst_t>::size()); \
1691+
y_vec.store(y, num_dst_elements); \
1692+
for (const auto i : c10::irange(num_dst_elements)) { \
1693+
ASSERT_EQ(y[i], ref[i]) \
1694+
<< "Failure Details:\nTest Seed to reproduce: " << seed \
1695+
<< " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \
1696+
} \
1697+
constexpr int dst_n = N / num_dst_elements; \
1698+
auto y_vec_n = at::vec::convert<dst_t, dst_n, src_t, 1>( \
1699+
at::vec::VectorizedN<src_t, 1>(x_vec)); \
1700+
y_vec_n.store(y, N); \
1701+
for (const auto i : c10::irange(N)) { \
1702+
ASSERT_EQ(y[i], ref[i]) \
1703+
<< "Failure Details:\nTest Seed to reproduce: " << seed \
1704+
<< " x[" << i << "]=" << x[i] << " dst_t=" #dst_t; \
1705+
} \
1706+
} while (0)
1707+
// Run the check for both 8-bit destination types.
TEST_CONVERT_TO(int8_t);
1708+
TEST_CONVERT_TO(uint8_t);
1709+
#undef TEST_CONVERT_TO
1710+
}
16661711
#endif
16671712
TYPED_TEST(VecMaskTests, MaskedLoad) {
16681713
using vec = TypeParam;

test/inductor/test_cpu_repro.py

+20
Original file line numberDiff line numberDiff line change
@@ -3734,6 +3734,26 @@ def fn(x):
37343734
# TODO(jgong5): change to 1 with vectorized uint64 load
37353735
assert metrics.generated_cpp_vec_kernel_count == 0
37363736

3737+
# Checks x.to(dst_dtype) from int8/uint8 to bfloat16/half for every SIMD
# bit width reported by cpu_vec_isa.valid_vec_isa_list(), and expects exactly
# one vectorized C++ kernel to be generated for each combination.
def test_convert_int8_to_half_vec(self):
3738+
src_dtypes = [torch.int8, torch.uint8]
3739+
dst_dtypes = [torch.bfloat16, torch.half]
3740+
_simd_lens = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()]
3741+
# Iterate over every (source dtype, destination dtype, SIMD width) combination.
for src_dtype, dst_dtype, _simd_len in itertools.product(
3742+
src_dtypes, dst_dtypes, _simd_lens
3743+
):
3744+
3745+
def fn(x):
3746+
return x.to(dst_dtype)
3747+
3748+
# uint8 cannot hold negative values, so its lower bound is 0.
low = 0 if src_dtype == torch.uint8 else -100
3749+
3750+
# 32x32 input with integer values in [low, 100).
x = torch.randint(low, 100, (32, 32), dtype=src_dtype)
3751+
with config.patch({"cpp.simdlen": _simd_len}):
3752+
# Reset dynamo and the metrics so the kernel count below reflects
# only this combination.
torch._dynamo.reset()
3753+
metrics.reset()
3754+
self.common(fn, (x,))
3755+
check_metrics_vec_kernel_count(1)
3756+
37373757
def test_convert_int32_to_int64_vec(self):
37383758
def fn(x):
37393759
return x.to(torch.int64)

0 commit comments

Comments
 (0)