Commit 15790ef: Add optimization barrier
1 parent 7dbb584

8 files changed (+42, -0 lines)

torch_xla/core/xla_model.py

Lines changed: 11 additions & 0 deletions
@@ -1024,3 +1024,14 @@ def get_memory_info(device):
       memory in KB) keys.
   """
   return torch_xla._XLAC._xla_memory_info(str(device))
+
+
+def optimization_barrier(tensors):
+  """Blocks the XLA compiler from moving computations across this barrier. The
+  common use case is to keep the XLA common-subexpression-elimination pass from
+  undoing gradient checkpointing.
+
+  Args:
+    tensors (torch.Tensor): `torch.Tensor`s to add the barrier to.
+  """
+  return torch_xla._XLAC._xla_optimization_barrier(tensors)
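
For context, a minimal usage sketch of the new API. This is illustrative only: it assumes an XLA device is available and, since the `_xla_optimization_barrier` binding added below takes a single tensor, it passes one tensor at a time.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device, requires_grad=True)

# Pin a recomputed activation behind the barrier so the XLA CSE pass
# cannot fold it back into the original forward computation.
y = torch.sin(x)
y = xm.optimization_barrier(y)
y.sum().backward()

Without the barrier, CSE can deduplicate the recomputation against the value already produced in the forward pass, defeating the memory savings that checkpointing is meant to provide.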

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -276,6 +276,13 @@ std::pair<at::Tensor, std::shared_ptr<ir::Value>> CollectivePermute(
       std::make_shared<ir::Value>(new_token));
 }
 
+at::Tensor OptimizationBarrier(const at::Tensor& input) {
+  at::Tensor result = bridge::AtenFromXlaTensor(
+      XLATensor::optimization_barrier(bridge::GetXlaTensor(input)));
+  return torch::autograd::make_variable(
+      result, /*requires_grad=*/input.requires_grad());
+}
+
 void SyncTensors(const std::vector<at::Tensor>& tensors,
                  const std::vector<std::string>& devices, bool wait,
                  bool sync_xla_data) {
@@ -1028,6 +1035,8 @@ void InitXlaModuleBindings(py::module m) {
     }
     return new_token;
   });
+  m.def("_xla_optimization_barrier",
+        [](const at::Tensor& input) { return OptimizationBarrier(input); });
   m.def("_xla_set_default_device", [](const std::string& device) {
     return SetCurrentThreadDevice(device);
   });
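
A quick sanity check of the binding's autograd behavior; a sketch assuming an XLA device is reachable, calling the `_xla_optimization_barrier` binding registered above directly:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device(), requires_grad=True)
out = torch_xla._XLAC._xla_optimization_barrier(t)
# make_variable re-wraps the result, carrying over requires_grad from the input.
assert out.requires_grad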

torch_xla/csrc/ops/ops.cpp

Lines changed: 12 additions & 0 deletions
@@ -25,6 +25,7 @@
 #include "torch_xla/csrc/ops/permute.h"
 #include "torch_xla/csrc/ops/softmax_backward.h"
 #include "torch_xla/csrc/ops/sum.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
@@ -1031,6 +1032,17 @@ NodePtr Softplus(const Value& input, const Value& beta,
                  std::move(lower_fn));
 }
 
+NodePtr OptimizationBarrier(const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_output = xla::OptimizationBarrier(xla_input);
+    return node.ReturnOp(xla_output, loctx);
+  };
+
+  return GenericOp(xla_optimization_barrier, {input}, input.xla_shape(),
+                   std::move(lower_fn));
+}
+
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla
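
One way to see this lowering in action is to dump the pending IR for a tensor that has the barrier applied; a sketch assuming the standard `_get_xla_tensors_text` debug helper:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device())
t = xm.optimization_barrier(t)
# The graph dump should contain a node with the registered op kind
# xla::optimization_barrier, which the lower_fn above maps to
# xla::OptimizationBarrier in the emitted XLA computation.
print(torch_xla._XLAC._get_xla_tensors_text([t]))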

torch_xla/csrc/ops/ops.h

Lines changed: 2 additions & 0 deletions
@@ -247,6 +247,8 @@ NodePtr SLogDet(const Value& input);
 
 NodePtr Softplus(const Value& input, const Value& beta, const Value& threshold);
 
+NodePtr OptimizationBarrier(const Value& input);
+
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla

torch_xla/csrc/ops/xla_ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ const OpKindWrapper xla_get_dimensions_size("xla::xla_get_dimensions_size");
 const OpKindWrapper xla_moving_average("xla::moving_average");
 const OpKindWrapper xla_nms("xla::nms");
 const OpKindWrapper xla_not_supported("xla::not_supported");
+const OpKindWrapper xla_optimization_barrier("xla::optimization_barrier");
 const OpKindWrapper xla_reduce_scatter("xla::reduce_scatter");
 const OpKindWrapper xla_replication_pad("xla::replication_pad");
 const OpKindWrapper xla_replication_pad_backward(

torch_xla/csrc/ops/xla_ops.h

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ extern const OpKindWrapper xla_get_dimensions_size;
 extern const OpKindWrapper xla_moving_average;
 extern const OpKindWrapper xla_nms;
 extern const OpKindWrapper xla_not_supported;
+extern const OpKindWrapper xla_optimization_barrier;
 extern const OpKindWrapper xla_reduce_scatter;
 extern const OpKindWrapper xla_replication_pad;
 extern const OpKindWrapper xla_replication_pad_backward;

torch_xla/csrc/tensor.h

Lines changed: 2 additions & 0 deletions
@@ -918,6 +918,8 @@ class XLATensor {
   static XLATensor not_supported(std::string description, xla::Shape shape,
                                  const Device& device);
 
+  static XLATensor optimization_barrier(const XLATensor& input);
+
   // Permute the dimensions of this tensor according to the given permutation.
   static XLATensor permute(const XLATensor& input,
                            absl::Span<const int64_t> dims);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 4 additions & 0 deletions
@@ -2182,6 +2182,10 @@ XLATensor XLATensor::not_supported(std::string description, xla::Shape shape,
                                    device);
 }
 
+XLATensor XLATensor::optimization_barrier(const XLATensor& input) {
+  return input.CreateFrom(ir::ops::OptimizationBarrier(input.GetIrValue()));
+}
+
 XLATensor XLATensor::permute(const XLATensor& input,
                              absl::Span<const int64_t> dims) {
   auto input_shape = input.shape();
