Skip to content

Commit dabb920

Browse files
Update
[ghstack-poisoned]
1 parent 5ccef96 commit dabb920

File tree

3 files changed

+6
-9
lines changed

3 files changed

+6
-9
lines changed

aten/src/ATen/cpu/vec/vec256/vec256_convert.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ struct VecConvert<
226226
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
227227
__m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2));
228228
__m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1);
229-
// Shuffle [191:128] bit from combined to [127:64] bit of result
229+
// Shuffle bits [191:128] of combined into bits [127:64] of result
230230
__m256i result = _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000);
231231
return at::vec::Vectorized<dst_t>(result);
232232
}

aten/src/ATen/cpu/vec/vec512/vec512_convert.h

+2-7
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,8 @@ struct VecConvert<
225225
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 2>& src) {
226226
at::vec::Vectorized<dst_t> vec1 = convert_float_to_int8<dst_t>(src[0]);
227227
at::vec::Vectorized<dst_t> vec2 = convert_float_to_int8<dst_t>(src[1]);
228-
__m128 lane1 = _mm512_extractf32x4_ps(_mm512_castsi512_ps(vec1), 0);
229-
__m128 lane2 = _mm512_extractf32x4_ps(_mm512_castsi512_ps(vec2), 0);
230-
__m512 result = _mm512_setzero_ps();
231-
232-
// Insert the extracted lanes into the result vector
233-
result = _mm512_insertf32x4(result, lane1, 0); // Insert lane1 into the first 128-bit lane
234-
result = _mm512_insertf32x4(result, lane2, 1); // Insert lane2 into the second 128-bit lane
228+
__m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2));
229+
__m512 result = _mm512_insertf32x4(_mm512_castsi512_ps(vec1), lane2, 1); // Insert lane2 into the second 128-bit lane
235230
return at::vec::Vectorized<dst_t>(_mm512_castps_si512(result));
236231
}
237232
};

test/inductor/test_cpu_repro.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3747,7 +3747,9 @@ def test_convert_int8_to_half_vec(self):
37473747
def fn(x):
37483748
return x.to(dst_dtype)
37493749

3750-
x = torch.randint(0, 100, (32, 32), dtype=src_dtype)
3750+
low = 0 if src_dtype == torch.uint8 else -100
3751+
3752+
x = torch.randint(low, 100, (32, 32), dtype=src_dtype)
37513753
with config.patch({"cpp.simdlen": _simd_len}):
37523754
torch._dynamo.reset()
37533755
metrics.reset()

0 commit comments

Comments
 (0)