Commit e8a42b6

EikanWang authored and cyyever committed
[Inductor] Enable Inductor to support BF16 atomic_add (#96620)
Pull Request resolved: pytorch/pytorch#96620 Approved by: https://github.com/jansel, https://github.com/jgong5
1 parent 6db9be8 · commit e8a42b6

2 files changed: +54 −1 lines

test/inductor/test_torchinductor.py
Lines changed: 42 additions & 0 deletions

```diff
@@ -6141,6 +6141,48 @@ def test_cpu_vec_cosim(self):
         union = {*cpp_vec_op_list, *diff}
         self.assertTrue(set(cpp_op_list).issubset(union))
 
+    def test_atomic_add_bf16(self):
+        def fn(test_args):
+            res = torch.gather(**test_args)
+            return res
+
+        input_tensor_for_ref = torch.tensor(
+            [[3.0, -5.0]], dtype=torch.bfloat16, requires_grad=True
+        )
+        input_tensor_for_opt = torch.tensor(
+            [[3.0, -5.0]], dtype=torch.bfloat16, requires_grad=True
+        )
+
+        test_args_for_ref = {
+            "input": input_tensor_for_ref,
+            "dim": 1,
+            "index": torch.tensor([[1]]),
+        }
+        test_args_for_opt = {
+            "input": input_tensor_for_opt,
+            "dim": 1,
+            "index": torch.tensor([[1]]),
+        }
+
+        opt_fn = torch.compile(fn)
+
+        ref_fwd = fn(test_args_for_ref)
+        res_fwd = opt_fn(test_args_for_opt)
+        self.assertEqual(res_fwd, ref_fwd)
+
+        torch.manual_seed(1)
+        bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=torch.bfloat16)
+        torch.manual_seed(1)
+        bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=torch.bfloat16)
+        self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt)
+
+        ref_fwd.backward(bwd_tensor_for_ref)
+        res_fwd.backward(bwd_tensor_for_opt)
+
+        ref_grad = test_args_for_ref["input"].grad
+        res_grad = test_args_for_opt["input"].grad
+        self.assertEqual(ref_grad, res_grad)
+
     @unittest.skipIf(
         not codecache.valid_vec_isa_list(), "Does not support vectorization"
     )
```
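Why a gather test exercises atomic_add: the backward of torch.gather is a scatter-style index add, and Inductor's C++ backend emits atomic_add for it because distinct output positions can map to the same input slot. The sketch below is illustrative only, not part of the diff: it shows that collision pattern with the same CAS-on-integer-view loop the header already used for float. The helper name atomic_add_f32 and the two-thread driver are inventions of this demo.

```cpp
// Illustrative sketch only: two threads accumulate into the same float slot,
// the collision pattern a scatter-add backward can produce. The CAS-on-
// integer-view trick mirrors what cpp_prefix.h does for float (and, after
// this commit, for bfloat16); assumes std::atomic<uint32_t> is the same size
// as float, as the header does.
#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <thread>

void atomic_add_f32(float *addr, float offset) {
  // View the float's storage as a 32-bit integer, as AsIntegerType<float> does.
  auto *atomic_addr = reinterpret_cast<std::atomic<uint32_t> *>(addr);
  uint32_t expected = atomic_addr->load();
  uint32_t desired;
  do {
    float val;
    std::memcpy(&val, &expected, sizeof(val));     // current value
    const float sum = val + offset;                // plain FP add
    std::memcpy(&desired, &sum, sizeof(desired));  // re-encode as bits
    // compare_exchange_weak refreshes `expected` with the latest bits on failure.
  } while (!atomic_addr->compare_exchange_weak(expected, desired));
}

int main() {
  float out[2] = {0.0f, 0.0f};
  // Both threads target index 1, like the test's index = [[1]].
  auto worker = [&] {
    for (int i = 0; i < 1000; ++i) atomic_add_f32(&out[1], 1.0f);
  };
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();
  std::cout << out[1] << "\n";  // 2000: no lost updates
}
```

With a plain non-atomic read-add-write, the two threads could interleave and drop updates; in the test above that would surface as a wrong input.grad.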

torch/_inductor/codegen/cpp_prefix.h
Lines changed: 12 additions & 1 deletion

```diff
@@ -38,6 +38,17 @@ float randn_cpu(uint32_t seed, uint32_t offset) {
 template <typename T> struct AsIntegerType { typedef T type; };
 template <> struct AsIntegerType<float> { typedef uint32_t type; };
 template <> struct AsIntegerType<double> { typedef uint64_t type; };
+template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };
+
+template <typename T>
+inline T fetch_value(volatile T *addr) {
+  return *addr;
+}
+
+template <>
+inline bfloat16 fetch_value<bfloat16>(volatile bfloat16 *addr) {
+  return bfloat16(addr->x);
+}
 
 template <typename T> void atomic_add(volatile T *addr, T offset) {
   typedef typename AsIntegerType<T>::type alt_type;
@@ -51,7 +62,7 @@ template <typename T> void atomic_add(volatile T *addr, T offset) {
 
   std::atomic<alt_type> *atomic_addr = (std::atomic<alt_type> *)addr;
   do {
-    T val = *addr;
+    T val = fetch_value(addr);
     reinterpret_cast<T *>(&expected)[0] = val;
     reinterpret_cast<T *>(&desired)[0] = val + offset;
   } while (!atomic_addr->compare_exchange_weak(expected, desired,
```
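Put together, the mechanism is small enough to compile on its own. The sketch below reproduces the diff's AsIntegerType/fetch_value approach around a minimal bfloat16 stand-in; the only property it borrows from PyTorch's c10::BFloat16 is that the raw bits live in a .x field. The truncating float conversions, operator+, and main() driver are assumptions of this demo, not part of the commit.

```cpp
// Standalone sketch of the commit's approach, built around a minimal
// bfloat16 stand-in (raw bits in `.x`, as in c10::BFloat16). Not the
// actual PyTorch header.
#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>

struct bfloat16 {
  uint16_t x = 0;
  bfloat16() = default;
  explicit bfloat16(uint16_t bits) : x(bits) {}
  // Truncating conversions: enough for exactly-representable demo values.
  static bfloat16 from_float(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return bfloat16(static_cast<uint16_t>(u >> 16));
  }
  float to_float() const {
    const uint32_t u = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
  return bfloat16::from_float(a.to_float() + b.to_float());
}

// The commit's mapping: bf16 values are CAS'd through a 16-bit integer view.
template <typename T> struct AsIntegerType { typedef T type; };
template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };

// Generic path dereferences the volatile pointer; the bf16 specialization
// rebuilds the value from its raw bits instead.
template <typename T> inline T fetch_value(volatile T *addr) { return *addr; }
template <>
inline bfloat16 fetch_value<bfloat16>(volatile bfloat16 *addr) {
  return bfloat16(addr->x);
}

template <typename T> void atomic_add(volatile T *addr, T offset) {
  typedef typename AsIntegerType<T>::type alt_type;
  static_assert(sizeof(std::atomic<alt_type>) == sizeof(T),
                "integer view must match the value's storage");
  alt_type expected;
  alt_type desired;
  auto *atomic_addr = (std::atomic<alt_type> *)addr;
  do {
    const T val = fetch_value(addr);
    std::memcpy(&expected, &val, sizeof(expected));  // header uses reinterpret_cast stores
    const T sum = val + offset;
    std::memcpy(&desired, &sum, sizeof(desired));
  } while (!atomic_addr->compare_exchange_weak(expected, desired));
}

int main() {
  bfloat16 acc = bfloat16::from_float(3.0f);      // the test's 3.0
  atomic_add(&acc, bfloat16::from_float(-5.0f));  // gradient-style add
  std::cout << acc.to_float() << "\n";            // -2
}
```

CAS through the integer view works because compare_exchange compares raw bits, so the loop retries exactly when another thread changed the slot. The fetch_value specialization exists, presumably, because `*addr` compiles for built-in float/double but not for a class type like bfloat16, whose copy constructor cannot bind to a volatile operand; reading the .x bits directly sidesteps that.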
