
Commit 1db6c86

Manually register einsum on xla (#8801)
1 parent 7aea058 commit 1db6c86

File tree: 5 files changed (+99, -58)

  test/run_tests.sh
  test/test_manual_xla_registration.py
  torch_xla/csrc/init_python_bindings.cpp
  torch_xla/csrc/xla_manual_registration.cpp
  torch_xla/distributed/spmd/xla_sharding.py

test/run_tests.sh (+2, -1)

@@ -249,11 +249,12 @@ function run_xla_op_tests3 {
   run_test "$CDIR/test_persistent_cache.py"
   run_test "$CDIR/test_devices.py"
   run_device_detection_test "$CDIR/test_gpu_device_detection.py"
+  run_test "$CDIR/test_manual_xla_registration.py"
   # NOTE: this line below is testing export and don't care about GPU
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
   run_test "$CDIR/test_pallas.py"
   run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"
-
+
   # Test examples
   run_test "$CDIR/../examples/scan/scan_examples.py"

test/test_manual_xla_registration.py (new file, +52)

@@ -0,0 +1,52 @@
+import sys
+import unittest
+import torch
+from torch import Tensor
+from torch.library import custom_op
+import torch_xla
+
+
+@custom_op(
+    "xla::custom_einsum",
+    schema="(str function, Tensor input, Tensor weight) -> Tensor",
+    mutates_args=())
+def custom_einsum(function: str, input: Tensor, weight: Tensor):
+  return torch.einsum(function, input, weight)
+
+
+def is_einsum_lowered(func):
+  X = torch.zeros(3, 5, requires_grad=False, device='xla')
+  Y = torch.zeros(5, 7, requires_grad=False, device='xla')
+
+  out = func(X, Y)
+  ir = torch_xla._XLAC._get_xla_tensors_text([out])
+  return ir
+
+
+class OperationLowered(unittest.TestCase):
+
+  def test_einsum_lowered(self):
+    for f in [torch.einsum, custom_einsum]:
+      with self.subTest(f=f):
+        ir = is_einsum_lowered(lambda a, b: f('...n,mn->...m', a, b))
+
+        self.assertIn("einsum", ir,
+                      "Expected einsum to be in ir from it being lowered")
+        self.assertNotIn(
+            "permute", ir,
+            "Expected no permute to be in ir from it being lowered")
+
+  def test_einsum_not_lowered(self):
+    # 'ab,bc->ab' won't be lowered because it cannot be backpropagated
+    ir = is_einsum_lowered(lambda a, b: torch.einsum('ab,bc->ab', a, b))
+
+    self.assertNotIn(
+        "einsum", ir,
+        "Expected no einsum to be in ir from it not being lowered")
+    self.assertIn("permute", ir,
+                  "Expected permute to be in ir from it not being lowered")
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
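
The comment in test_einsum_not_lowered points at the criterion behind the lowering: an equation is only kept as a single einsum node when its backward can also be written as einsums. For the linear-layer equation exercised above, those backward einsums are exactly the ones the xla_sharding.py hunks later in this commit use; here is a small plain-PyTorch sketch (no XLA device needed, variable names chosen here for illustration) checking them against autograd:

import torch

# Forward: '...n,mn->...m' is a linear layer without bias, i.e. out = x @ w.T
x = torch.randn(4, 5, requires_grad=True)
w = torch.randn(7, 5, requires_grad=True)
out = torch.einsum('...n,mn->...m', x, w)
out.backward(torch.ones_like(out))

# Both gradients are themselves einsums over the same operands, mirroring
# _einsum_linear_backward_operation in xla_sharding.py.
grad_out = torch.ones(4, 7)
grad_x = torch.einsum('...m,mn->...n', grad_out, w)  # d(out)/d(x)
grad_w = torch.einsum('...m,...n->mn', grad_out, x)  # d(out)/d(w)

assert torch.allclose(grad_x, x.grad)
assert torch.allclose(grad_w, w.grad)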

torch_xla/csrc/init_python_bindings.cpp (-11)

@@ -1867,17 +1867,6 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_optimization_barrier_",
         [](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
 
-  // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
-  // decomposed when inside a custom op. This C++ op is an escape hatch to call
-  // XLA einsum without going through torch.einsum. We should remove this
-  // operation when the linked bug is fixed.
-  m.def("_xla_einsum",
-        [](const std::string& equation, const std::vector<at::Tensor>& inputs) {
-          std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
-          XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
-          return bridge::AtenFromXlaTensor(output);
-        });
-
   // Creates a placeholder tensor that does not hold any device buffer.
   // This is primarily useful for staging out the HLO of a user computation.
   // Accessing the value of the tensor will panic.
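
For any code that called the removed private binding, the replacement is the public torch.einsum call, which the xla_sharding.py hunks below also switch to. A minimal sketch of the migration (requires a torch_xla install with a configured PJRT device; tensor names are illustrative):

import torch
import torch_xla  # importing torch_xla makes the 'xla' device and its lowerings available

x = torch.randn(3, 5, device='xla')
w = torch.randn(7, 5, device='xla')

# Before this commit: product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (x, w))
# After this commit, the public API is lowered to XLA's einsum directly:
product = torch.einsum('...n,mn->...m', x, w)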

torch_xla/csrc/xla_manual_registration.cpp (+7)

@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <torch/library.h>
 
+#include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_fallback.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/debug_util.h"
@@ -49,5 +50,11 @@ TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
   m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
 }
 
+// Register generated XLANativeFunctions::einsum as aten::einsum for XLA key.
+// This utilizes the implementation from `xla/torch_xla/csrc/aten_xla_type.cpp`.
+TORCH_LIBRARY_IMPL(aten, XLA, m) {
+  m.impl("aten::einsum", TORCH_FN(XLANativeFunctions::einsum));
+}
+
 }  // namespace manual
 }  // namespace torch_xla
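
One way to confirm that the TORCH_LIBRARY_IMPL block above took effect is to inspect the dispatcher's registration state from Python. This sketch relies on the private debugging helper torch._C._dispatch_dump (an assumption about the environment, not part of this commit):

import torch
import torch_xla  # loads the C++ extension, which runs the TORCH_LIBRARY_IMPL registrations

# Dump the registration state for aten::einsum; after this change an XLA
# entry pointing at the manually registered kernel should be listed.
print(torch._C._dispatch_dump("aten::einsum"))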

torch_xla/distributed/spmd/xla_sharding.py (+38, -46)

@@ -676,11 +676,7 @@ def apply(self, t: torch.Tensor):
 def _einsum_linear_forward(input: Tensor, weight: Tensor,
                            bias: Optional[Tensor]):
   with xp.Trace('einsum_linear_forward'):
-    # TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
-    # decomposed when inside a custom op. This C++ op is an escape hatch to call
-    # XLA einsum without going through torch.einsum. We should remove this
-    # _einsum escape hatch when the linked bug is fixed.
-    product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+    product = torch.einsum('...n,mn->...m', input, weight)
     if bias is not None:
       return product + bias
     return product
@@ -695,6 +691,31 @@ def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
   return product
 
 
+def _einsum_linear_backward_operation(grad_output: Tensor, input: Tensor,
+                                      weight: Tensor, bias: Optional[Tensor],
+                                      needs_input_grad_input: bool,
+                                      needs_input_grad_weight: bool,
+                                      needs_input_grad_bias: bool):
+  grad_input = grad_weight = grad_bias = None
+
+  if needs_input_grad_input:
+    grad_input = torch.einsum('...m,mn->...n', grad_output, weight).clone()
+  else:
+    grad_input = None
+
+  if needs_input_grad_weight:
+    grad_weight = torch.einsum('...m,...n->mn', grad_output, input).clone()
+  else:
+    grad_weight = None
+
+  if bias is not None and needs_input_grad_bias:
+    grad_bias = torch.einsum('...m->m', grad_output).clone()
+  else:
+    grad_bias = None
+
+  return grad_input, grad_weight, grad_bias
+
+
 @custom_op(
     "xla::einsum_linear_backward",
     schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
@@ -705,26 +726,10 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
                             needs_input_grad_weight: bool,
                             needs_input_grad_bias: bool):
   with xp.Trace('einsum_linear_backward'):
-    grad_input = grad_weight = grad_bias = None
-
-    if needs_input_grad_input:
-      grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
-                                               (grad_output, weight))
-    else:
-      grad_input = None
-
-    if needs_input_grad_weight:
-      grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
-                                                (grad_output, input))
-    else:
-      grad_weight = None
-
-    if bias is not None and needs_input_grad_bias:
-      grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
-    else:
-      grad_bias = None
-
-    return grad_input, grad_weight, grad_bias
+    return _einsum_linear_backward_operation(grad_output, input, weight, bias,
+                                             needs_input_grad_input,
+                                             needs_input_grad_weight,
+                                             needs_input_grad_bias)
 
 
 @_einsum_linear_backward.register_fake
@@ -733,24 +738,11 @@ def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
                                  needs_input_grad_input: bool,
                                  needs_input_grad_weight: bool,
                                  needs_input_grad_bias: bool):
-  grad_input = grad_weight = grad_bias = None
-
-  if needs_input_grad_input:
-    grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
-  else:
-    grad_input = None
 
-  if needs_input_grad_weight:
-    grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
-  else:
-    grad_weight = None
-
-  if bias is not None and needs_input_grad_bias:
-    grad_bias = torch.einsum('...m->m', grad_output)
-  else:
-    grad_bias = None
-
-  return grad_input, grad_weight, grad_bias
+  return _einsum_linear_backward_operation(grad_output, input, weight, bias,
+                                           needs_input_grad_input,
+                                           needs_input_grad_weight,
+                                           needs_input_grad_bias)
 
 
 # Now define the XLAPatchedLinear function that uses the custom ops
@@ -765,8 +757,8 @@ class XLAPatchedLinear(torch.autograd.Function):
   autocast context, when autocast is enabled.
   torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').
 
-  References:
-  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
+  References:
+  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
   [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
 
   TODO (alanwaketan): Let's patch it on the dispatcher level.
@@ -1260,8 +1252,8 @@ class MarkShardingFunction(torch.autograd.Function):
   Usage:
   new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))
 
-  This is required to guide GSPMD sharding propagation better during the
-  backward pass as during complicated workloads the compiler can introduce extra
+  This is required to guide GSPMD sharding propagation better during the
+  backward pass as during complicated workloads the compiler can introduce extra
   collectives that can hurt performance.
   """
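
As a reference point for the refactor above, the '...n,mn->...m' equation used by the forward path plus the optional bias is exactly torch.nn.functional.linear. A short CPU-only sketch (tensor names chosen here for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5)  # (..., in_features)
w = torch.randn(7, 5)     # (out_features, in_features)
b = torch.randn(7)

# einsum-based linear forward, as in _einsum_linear_forward
product = torch.einsum('...n,mn->...m', x, w) + b
assert torch.allclose(product, F.linear(x, w, b), atol=1e-6)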
