Skip to content

Commit 4d24922

Browse files
committed
Fix special scalar handling for addcdiv and addcmul
1 parent 205ae57 commit 4d24922

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

test/test_operations_hlo.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def test_expand(self):
3535
hlo_text = torch_xla._XLAC._get_xla_tensors_text([b])
3636
assert 'aten::expand' in hlo_text
3737

38+
def test_special_scalars_addcdiv_addcmul(self):
39+
a = torch.rand(5, 5).to(xm.xla_device())
40+
b = torch.rand(5, 5).to(xm.xla_device())
41+
c = torch.rand(5, 5).to(xm.xla_device())
42+
for op in [torch.addcdiv, torch.addcmul]:
43+
out = op(a, b, c, value=1.0)
44+
hlo_text = torch_xla._XLAC._get_xla_tensors_text([out])
45+
instructions = hlo_text.split('\n')
46+
const_hlo = instructions[1]
47+
root_hlo = instructions[5]
48+
assert 'prim::Constant()' in const_hlo
49+
assert 'xla::device_data()' not in const_hlo
50+
assert 'f32' in root_hlo
51+
assert 'f64' not in root_hlo
52+
3853

3954
if __name__ == '__main__':
4055
torch.set_default_tensor_type('torch.FloatTensor')

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,10 @@ absl::flat_hash_map<std::string, absl::variant<int>> ConvertDictToMap(
741741
return map;
742742
}
743743

744+
// Override some upstream torch::lazy env vars for better performance.
745+
// Upstream lazy env vars defined in torch/csrc/lazy/core/config.h.
746+
void SetDefaultLazyEnvVars() { FLAGS_torch_lazy_handle_special_scalars = true; }
747+
744748
// Maps PT/XLA env vars to upstream torch::lazy env vars.
745749
// Upstream lazy env vars defined in torch/csrc/lazy/core/config.h.
746750
void MapXlaEnvVarsToLazy() {
@@ -1469,6 +1473,7 @@ void InitXlaModuleBindings(py::module m) {
14691473
});
14701474

14711475
m.def("_init_xla_lazy_backend", []() {
1476+
SetDefaultLazyEnvVars();
14721477
MapXlaEnvVarsToLazy();
14731478
InitXlaBackend();
14741479
});

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
10141014

10151015
xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
10161016
xla::XlaOp val) {
1017+
val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
10171018
return XlaHelpers::PromotedAdd(
10181019
input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
10191020
}

0 commit comments

Comments
 (0)