
Commit 017ccd0

Initial changes to manually register einsum for XLA
1 parent 9b61c1a commit 017ccd0

3 files changed: +15 -21 lines changed


torch_xla/csrc/init_python_bindings.cpp

Lines changed: 0 additions & 11 deletions
@@ -1867,17 +1867,6 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_optimization_barrier_",
         [](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
 
-  // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
-  // decomposed when inside a custom op. This C++ op is an escape hatch to call
-  // XLA einsum without going through torch.einsum. We should remove this
-  // operation when the linked bug is fixed.
-  m.def("_xla_einsum",
-        [](const std::string& equation, const std::vector<at::Tensor>& inputs) {
-          std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
-          XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
-          return bridge::AtenFromXlaTensor(output);
-        });
-
   // Creates a placeholder tensor that does not hold any device buffer.
   // This is primarily useful for staging out the HLO of a user computation.
   // Accessing the value of the tensor will panic.
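For context, the binding removed above was reachable from Python as torch_xla._XLAC._xla_einsum(equation, tensors); the xla_sharding.py call sites further down are migrated off it in this commit. A minimal before/after sketch, assuming an XLA device is available (tensor names and shapes are illustrative, not from the commit):

# Sketch only: contrasts the removed escape hatch with the plain torch.einsum path.
# On builds that include this commit, the _xla_einsum binding no longer exists.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(8, 3, device=device)  # illustrative input of shape [..., n]
w = torch.randn(5, 3, device=device)  # illustrative weight of shape [m, n]

# Before this commit (escape hatch that bypassed torch.einsum decomposition):
#   product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (x, w))

# After this commit, plain torch.einsum dispatches to the XLA einsum kernel:
product = torch.einsum('...n,mn->...m', x, w)  # shape [8, 5]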

torch_xla/csrc/xla_manual_registration.cpp

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <torch/library.h>
 
+#include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_fallback.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/debug_util.h"
@@ -49,5 +50,11 @@ TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
   m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
 }
 
+// Register generated XLANativeFunctions::einsum as aten::einsum for XLA key.
+// This utilizes the implementation from `xla/torch_xla/csrc/aten_xla_type.cpp`.
+TORCH_LIBRARY_IMPL(aten, XLA, m) {
+  m.impl("aten::einsum", TORCH_FN(XLANativeFunctions::einsum));
+}
+
 }  // namespace manual
 }  // namespace torch_xla
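The added block registers the codegen'd XLANativeFunctions::einsum kernel for the XLA dispatch key, so aten::einsum dispatches to it directly rather than being decomposed. One hedged way to check the effect from Python is to dump the lazy IR for a torch.einsum result and look for an einsum node; the sketch below assumes a torch_xla install and uses _get_xla_tensors_text, a debug helper used in the project's tests (exact IR spelling may vary by version):

# Sketch: inspect the lazy IR of a torch.einsum result on an XLA device.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 3, device=device)
b = torch.randn(5, 3, device=device)

out = torch.einsum('...n,mn->...m', a, b)
ir_text = torch_xla._XLAC._get_xla_tensors_text([out])

# With aten::einsum registered for the XLA key, the IR is expected to contain an
# einsum op rather than a decomposition into permutes/reshapes/matmuls.
print('einsum' in ir_text, out.shape)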

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 8 additions & 10 deletions
@@ -680,7 +680,7 @@ def _einsum_linear_forward(input: Tensor, weight: Tensor,
   # decomposed when inside a custom op. This C++ op is an escape hatch to call
   # XLA einsum without going through torch.einsum. We should remove this
   # _einsum escape hatch when the linked bug is fixed.
-  product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+  product = torch.einsum('...n,mn->...m', (input, weight))
   if bias is not None:
     return product + bias
   return product
@@ -708,19 +708,17 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
   grad_input = grad_weight = grad_bias = None
 
   if needs_input_grad_input:
-    grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
-                                             (grad_output, weight))
+    grad_input = torch.einsum('...m,mn->...n', (grad_output, weight))
   else:
     grad_input = None
 
   if needs_input_grad_weight:
-    grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
-                                              (grad_output, input))
+    grad_weight = torch.einsum('...m,...n->mn', (grad_output, input))
   else:
     grad_weight = None
 
   if bias is not None and needs_input_grad_bias:
-    grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
+    grad_bias = torch.einsum('...m->m', (grad_output,))
   else:
     grad_bias = None
@@ -765,8 +763,8 @@ class XLAPatchedLinear(torch.autograd.Function):
   autocast context, when autocast is enabled.
   torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').
 
-  References:
-  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
+  References:
+  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
   [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
 
   TODO (alanwaketan): Let's patch it on the dispatcher level.
@@ -1260,8 +1258,8 @@ class MarkShardingFunction(torch.autograd.Function):
   Usage:
   new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))
 
-  This is required to guide GSPMD sharding propagation better during the
-  backward pass as during complicated workloads the compiler can introduce extra
+  This is required to guide GSPMD sharding propagation better during the
+  backward pass as during complicated workloads the compiler can introduce extra
   collectives that can hurt performance.
   """
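The equations used in _einsum_linear_forward and _einsum_linear_backward above implement a linear layer and its gradients via einsum. Below is a minimal, self-contained sketch of that pattern (illustrative only: the class name is made up, and the real XLAPatchedLinear additionally handles custom-op wrapping, autocast, and the needs_input_grad checks omitted here):

# Minimal sketch of an einsum-based linear with a hand-written backward,
# mirroring the equations from xla_sharding.py. Not the actual implementation.
import torch


class EinsumLinear(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, weight, bias=None):
    ctx.save_for_backward(input, weight, bias)
    product = torch.einsum('...n,mn->...m', input, weight)  # input @ weight.T
    if bias is not None:
      return product + bias
    return product

  @staticmethod
  def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
    grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
    grad_bias = torch.einsum('...m->m', grad_output) if bias is not None else None
    return grad_input, grad_weight, grad_bias

On an XLA device these torch.einsum calls are intended to hit the kernel registered in xla_manual_registration.cpp instead of being decomposed inside the custom op.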
