diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py
index dc59a87a324c..b9c728332d03 100644
--- a/test/test_operations_hlo.py
+++ b/test/test_operations_hlo.py
@@ -35,6 +35,21 @@ def test_expand(self):
     hlo_text = torch_xla._XLAC._get_xla_tensors_text([b])
     assert 'aten::expand' in hlo_text
 
+  def test_special_scalars_addcdiv_addcmul(self):
+    a = torch.rand(5, 5).to(xm.xla_device())
+    b = torch.rand(5, 5).to(xm.xla_device())
+    c = torch.rand(5, 5).to(xm.xla_device())
+    for op in [torch.addcdiv, torch.addcmul]:
+      out = op(a, b, c, value=1.0)
+      hlo_text = torch_xla._XLAC._get_xla_tensors_text([out])
+      instructions = hlo_text.split('\n')
+      const_hlo = instructions[1]
+      root_hlo = instructions[5]
+      assert 'prim::Constant()' in const_hlo
+      assert 'xla::device_data()' not in const_hlo
+      assert 'f32' in root_hlo
+      assert 'f64' not in root_hlo
+
 
 if __name__ == '__main__':
   torch.set_default_tensor_type('torch.FloatTensor')
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 60365ff324f9..40fed350b00d 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -746,6 +746,9 @@ absl::flat_hash_map<std::string, std::variant<int, std::string>> ConvertDictToMap(
 void MapXlaEnvVarsToLazy() {
   static bool wants_frames = xla::sys_util::GetEnvBool("XLA_IR_DEBUG", false);
   FLAGS_torch_lazy_ir_debug = wants_frames;
+  static bool no_scalars =
+      xla::sys_util::GetEnvBool("XLA_NO_SPECIAL_SCALARS", false);
+  FLAGS_torch_lazy_handle_special_scalars = !no_scalars;
 }
 
 std::string GetPyTypeString(py::handle obj) {
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 01299aa08905..33d2c089e808 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1014,6 +1014,7 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
 
 xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
                         xla::XlaOp val) {
+  val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
   return XlaHelpers::PromotedAdd(
       input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
 }
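
For reviewers, a minimal repro sketch of the behavior the new test pins down, assuming a working torch_xla install; every API call below is taken from the patch itself:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    a = torch.rand(5, 5).to(xm.xla_device())
    b = torch.rand(5, 5).to(xm.xla_device())
    c = torch.rand(5, 5).to(xm.xla_device())

    # value=1.0 is a Python float, i.e. a double. With special-scalar handling
    # enabled (the default, since XLA_NO_SPECIAL_SCALARS is unset), the scalar
    # is embedded as an inline prim::Constant, and the MaybeConvertTo call added
    # in BuildAddcdiv casts it to the operand dtype (f32). Without that, the
    # scalar would arrive as an f64 xla::device_data and the Promoted* helpers
    # could widen the whole computation to f64.
    out = torch.addcdiv(a, b, c, value=1.0)
    print(torch_xla._XLAC._get_xla_tensors_text([out]))

Setting XLA_NO_SPECIAL_SCALARS=1 flips FLAGS_torch_lazy_handle_special_scalars back off, which restores the old behavior of uploading the scalar as device data.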