Skip to content

Commit 88d934e

Browse files
authored
Fix special scalar handling for addcdiv and addcmul (#3953)
* Fix special scalar handling for addcdiv and addcmul
* Address CR comments
1 parent 2d87716 commit 88d934e

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

test/test_operations_hlo.py

+15
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def test_expand(self):
3535
hlo_text = torch_xla._XLAC._get_xla_tensors_text([b])
3636
assert 'aten::expand' in hlo_text
3737

38+
def test_special_scalars_addcdiv_addcmul(self):
  # Regression test: a plain Python float `value` passed to addcdiv/addcmul
  # must be lowered as an inline prim::Constant in the target dtype (f32),
  # not uploaded as f64 device data that would widen the result.
  device = xm.xla_device()
  lhs = torch.rand(5, 5).to(device)
  tensor1 = torch.rand(5, 5).to(device)
  tensor2 = torch.rand(5, 5).to(device)
  for fused_op in (torch.addcdiv, torch.addcmul):
    result = fused_op(lhs, tensor1, tensor2, value=1.0)
    hlo_text = torch_xla._XLAC._get_xla_tensors_text([result])
    lines = hlo_text.split('\n')
    # Line 1 of the IR dump holds the scalar; line 5 holds the root op.
    scalar_line = lines[1]
    root_line = lines[5]
    assert 'prim::Constant()' in scalar_line
    assert 'xla::device_data()' not in scalar_line
    assert 'f32' in root_line
    assert 'f64' not in root_line
52+
3853

3954
if __name__ == '__main__':
4055
torch.set_default_tensor_type('torch.FloatTensor')

torch_xla/csrc/init_python_bindings.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,9 @@ absl::flat_hash_map<std::string, absl::variant<int>> ConvertDictToMap(
746746
void MapXlaEnvVarsToLazy() {
747747
static bool wants_frames = xla::sys_util::GetEnvBool("XLA_IR_DEBUG", false);
748748
FLAGS_torch_lazy_ir_debug = wants_frames;
749+
static bool no_scalars =
750+
xla::sys_util::GetEnvBool("XLA_NO_SPECIAL_SCALARS", false);
751+
FLAGS_torch_lazy_handle_special_scalars = !no_scalars;
749752
}
750753

751754
std::string GetPyTypeString(py::handle obj) {

torch_xla/csrc/xla_lower_util.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
10141014

10151015
// Lowers aten::addcdiv: input + value * (t1 / t2).
xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
                        xla::XlaOp val) {
  // Cast the scalar multiplier to t1's element type first so a host-side
  // double scalar cannot promote the whole result to f64.
  val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
  xla::XlaOp quotient = XlaHelpers::PromotedDiv(t1, t2);
  xla::XlaOp scaled = XlaHelpers::PromotedMul(quotient, val);
  return XlaHelpers::PromotedAdd(input, scaled);
}

0 commit comments

Comments
 (0)