Skip to content

Commit cf19c0c

Browse files
committed
Make optimization_barrier take vector and add test for correctness
1 parent 15790ef commit cf19c0c

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

test/test_operations.py

+13
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,19 @@ def test_masked_select_shape(self):
748748
self.assertEqual(x_dim0_shape.item(), 3)
749749

750750

751+
class TestOptimizationBarrier(XlaTestCase):

  def test_optimization_barrier_correctness(self):
    """optimization_barrier must be a functional no-op on its inputs.

    Passes a list of tensors through xm.optimization_barrier and checks
    the returned tensors are numerically interchangeable with the inputs.
    """
    device = xm.xla_device()
    # optimization_barrier is only lowered on TPU; use skipTest so the
    # runner reports a skip instead of a silent (bare-return) pass.
    if xm.xla_device_hw(device) != 'TPU':
      self.skipTest('only test optimization_barrier on TPU')
    x = torch.randn(5, 5, device=device)
    y = torch.randn(5, 5, device=device)
    (x1, y1) = xm.optimization_barrier([x, y])
    self.assertEqual(x + y, x1 + y1)
751764
class TestDataType(XlaTestCase):
752765

753766
def test_mixed_dtype_tuple(self):

torch_xla/core/xla_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,6 @@ def optimization_barrier(tensors):
10321032
the gradient checkpointing.
10331033
10341034
Args:
1035-
tensors (torch.Tensor): `torch.Tensor`s to add barrier to.
1035+
tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
10361036
"""
10371037
return torch_xla._XLAC._xla_optimization_barrier(tensors)

torch_xla/csrc/init_python_bindings.cpp

+14-7
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,17 @@ std::pair<at::Tensor, std::shared_ptr<ir::Value>> CollectivePermute(
276276
std::make_shared<ir::Value>(new_token));
277277
}
278278

279-
at::Tensor OptimizationBarrier(const at::Tensor& input) {
280-
at::Tensor result = bridge::AtenFromXlaTensor(
281-
XLATensor::optimization_barrier(bridge::GetXlaTensor(input)));
282-
return torch::autograd::make_variable(
283-
result, /*requires_grad=*/input.requires_grad());
279+
std::vector<at::Tensor> OptimizationBarrier(
280+
const std::vector<at::Tensor>& tensors) {
281+
std::vector<at::Tensor> result;
282+
result.reserve(tensors.size());
283+
for (auto& tensor : tensors) {
284+
result.push_back(torch::autograd::make_variable(
285+
bridge::AtenFromXlaTensor(
286+
XLATensor::optimization_barrier(bridge::GetXlaTensor(tensor))),
287+
/*requires_grad=*/tensor.requires_grad()));
288+
}
289+
return result;
284290
}
285291

286292
void SyncTensors(const std::vector<at::Tensor>& tensors,
@@ -1035,8 +1041,9 @@ void InitXlaModuleBindings(py::module m) {
10351041
}
10361042
return new_token;
10371043
});
1038-
m.def("_xla_optimization_barrier",
1039-
[](const at::Tensor& input) { return OptimizationBarrier(input); });
1044+
m.def("_xla_optimization_barrier", [](const std::vector<at::Tensor>& inputs) {
1045+
return OptimizationBarrier(inputs);
1046+
});
10401047
m.def("_xla_set_default_device", [](const std::string& device) {
10411048
return SetCurrentThreadDevice(device);
10421049
});

0 commit comments

Comments (0)