
Manually register einsum on xla #8801

Merged · 15 commits · Mar 7, 2025
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -249,11 +249,12 @@ function run_xla_op_tests3 {
run_test "$CDIR/test_persistent_cache.py"
run_test "$CDIR/test_devices.py"
run_device_detection_test "$CDIR/test_gpu_device_detection.py"
run_test "$CDIR/test_manual_xla_registration.py"
  # NOTE: the line below tests export and doesn't care about GPU
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
run_test "$CDIR/test_pallas.py"
run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"

# Test examples
run_test "$CDIR/../examples/scan/scan_examples.py"

52 changes: 52 additions & 0 deletions test/test_manual_xla_registration.py
@@ -0,0 +1,52 @@
import sys
import unittest
import torch
from torch import Tensor
from torch.library import custom_op
import torch_xla


@custom_op(
"xla::custom_einsum",
schema="(str function, Tensor input, Tensor weight) -> Tensor",
mutates_args=())
def custom_einsum(function: str, input: Tensor, weight: Tensor):
return torch.einsum(function, input, weight)


def is_einsum_lowered(func):
X = torch.zeros(3, 5, requires_grad=False, device='xla')
Y = torch.zeros(5, 7, requires_grad=False, device='xla')

out = func(X, Y)
ir = torch_xla._XLAC._get_xla_tensors_text([out])
return ir


class OperationLowered(unittest.TestCase):

def test_einsum_lowered(self):
for f in [torch.einsum, custom_einsum]:
with self.subTest(f=f):
ir = is_einsum_lowered(lambda a, b: f('...n,mn->...m', a, b))

self.assertIn("einsum", ir,
"Expected einsum to be in ir from it being lowered")
self.assertNotIn(
"permute", ir,
"Expected no permute to be in ir from it being lowered")

def test_einsum_not_lowered(self):
    # 'ab,bc->ab' won't be lowered because it cannot be backpropagated
ir = is_einsum_lowered(lambda a, b: torch.einsum('ab,bc->ab', a, b))

self.assertNotIn(
"einsum", ir,
"Expected no einsum to be in ir from it not being lowered")
self.assertIn("permute", ir,
"Expected permute to be in ir from it not being lowered")


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
11 changes: 0 additions & 11 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1867,17 +1867,6 @@ void InitXlaModuleBindings(py::module m) {
m.def("_xla_optimization_barrier_",
[](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });

// TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
// decomposed when inside a custom op. This C++ op is an escape hatch to call
// XLA einsum without going through torch.einsum. We should remove this
// operation when the linked bug is fixed.
m.def("_xla_einsum",
[](const std::string& equation, const std::vector<at::Tensor>& inputs) {
std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
return bridge::AtenFromXlaTensor(output);
});

// Creates a placeholder tensor that does not hold any device buffer.
// This is primarily useful for staging out the HLO of a user computation.
// Accessing the value of the tensor will panic.
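The removed `_xla_einsum` escape hatch changes the Python-side calling convention. A minimal before/after sketch, based on the call sites updated elsewhere in this diff (the equation and tensor shapes are illustrative only):

```python
# Sketch only: both snippets assume XLA tensors on an initialized XLA device.
import torch
import torch_xla

x = torch.zeros(3, 5, device='xla')  # (..., n)
w = torch.zeros(7, 5, device='xla')  # (m, n)

# Before this PR: call the private C++ binding directly, passing the operands
# as a tuple, to avoid the custom-op decomposition of torch.einsum.
# out = torch_xla._XLAC._xla_einsum('...n,mn->...m', (x, w))

# After this PR: plain torch.einsum dispatches to the XLA einsum lowering
# via the manual aten::einsum registration added below.
out = torch.einsum('...n,mn->...m', x, w)
```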
7 changes: 7 additions & 0 deletions torch_xla/csrc/xla_manual_registration.cpp
@@ -1,6 +1,7 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
@@ -49,5 +50,11 @@ TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

// Register generated XLANativeFunctions::einsum as aten::einsum for XLA key.
// This utilizes the implementation from `xla/torch_xla/csrc/aten_xla_type.cpp`.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
m.impl("aten::einsum", TORCH_FN(XLANativeFunctions::einsum));
}

} // namespace manual
} // namespace torch_xla
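With the registration in place, the new test above can be reproduced interactively: the lazy IR for a plain `torch.einsum` call should contain an `einsum` node rather than the `permute` chain left by the decomposition. A minimal sketch (shapes chosen to match the linear-layer equation used in `xla_sharding.py`):

```python
# Minimal check of the lowering, mirroring test_manual_xla_registration.py.
import torch
import torch_xla

x = torch.zeros(3, 5, device='xla')  # input: (batch, in_features)
w = torch.zeros(7, 5, device='xla')  # weight: (out_features, in_features)

out = torch.einsum('...n,mn->...m', x, w)
ir = torch_xla._XLAC._get_xla_tensors_text([out])
print('einsum' in ir)   # expected: True, einsum is lowered directly
print('permute' in ir)  # expected: False for this equation
```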
84 changes: 38 additions & 46 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -676,11 +676,7 @@ def apply(self, t: torch.Tensor):
def _einsum_linear_forward(input: Tensor, weight: Tensor,
bias: Optional[Tensor]):
with xp.Trace('einsum_linear_forward'):
# TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
# decomposed when inside a custom op. This C++ op is an escape hatch to call
# XLA einsum without going through torch.einsum. We should remove this
# _einsum escape hatch when the linked bug is fixed.
product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
product = torch.einsum('...n,mn->...m', input, weight)
if bias is not None:
return product + bias
return product
@@ -695,6 +691,31 @@ def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
return product


def _einsum_linear_backward_operation(grad_output: Tensor, input: Tensor,
weight: Tensor, bias: Optional[Tensor],
needs_input_grad_input: bool,
needs_input_grad_weight: bool,
needs_input_grad_bias: bool):
grad_input = grad_weight = grad_bias = None

if needs_input_grad_input:
grad_input = torch.einsum('...m,mn->...n', grad_output, weight).clone()
else:
grad_input = None

if needs_input_grad_weight:
grad_weight = torch.einsum('...m,...n->mn', grad_output, input).clone()
else:
grad_weight = None

if bias is not None and needs_input_grad_bias:
grad_bias = torch.einsum('...m->m', grad_output).clone()
else:
grad_bias = None

return grad_input, grad_weight, grad_bias


@custom_op(
"xla::einsum_linear_backward",
schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
@@ -705,26 +726,10 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
needs_input_grad_weight: bool,
needs_input_grad_bias: bool):
with xp.Trace('einsum_linear_backward'):
grad_input = grad_weight = grad_bias = None

if needs_input_grad_input:
grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
(grad_output, weight))
else:
grad_input = None

if needs_input_grad_weight:
grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
(grad_output, input))
else:
grad_weight = None

if bias is not None and needs_input_grad_bias:
grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
else:
grad_bias = None

return grad_input, grad_weight, grad_bias
return _einsum_linear_backward_operation(grad_output, input, weight, bias,
needs_input_grad_input,
needs_input_grad_weight,
needs_input_grad_bias)


@_einsum_linear_backward.register_fake
Expand All @@ -733,24 +738,11 @@ def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
needs_input_grad_input: bool,
needs_input_grad_weight: bool,
needs_input_grad_bias: bool):
grad_input = grad_weight = grad_bias = None

if needs_input_grad_input:
grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
else:
grad_input = None

if needs_input_grad_weight:
grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
else:
grad_weight = None

if bias is not None and needs_input_grad_bias:
grad_bias = torch.einsum('...m->m', grad_output)
else:
grad_bias = None

return grad_input, grad_weight, grad_bias
return _einsum_linear_backward_operation(grad_output, input, weight, bias,
needs_input_grad_input,
needs_input_grad_weight,
needs_input_grad_bias)


# Now define the XLAPatchedLinear function that uses the custom ops
@@ -765,8 +757,8 @@ class XLAPatchedLinear(torch.autograd.Function):
autocast context, when autocast is enabled.
torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').

References:
[1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
References:
[1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
[2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500

TODO (alanwaketan): Let's patch it on the dispatcher level.
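As an aside on the autocast note in the docstring above, here is a hedged sketch of how the 'xla' autocast dtype can be queried, assuming a PyTorch version in which these helpers accept a `device_type` argument (the wrapper function itself is hypothetical):

```python
# Illustrative only: pick the dtype that autocast has configured for XLA.
import torch

def _autocast_dtype_for_xla() -> torch.dtype:
    # torch.get_autocast_dtype returns the dtype configured for the given
    # device's autocast context (commonly torch.bfloat16 on XLA/TPU).
    if torch.is_autocast_enabled('xla'):
        return torch.get_autocast_dtype('xla')
    return torch.get_default_dtype()
```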
@@ -1260,8 +1252,8 @@ class MarkShardingFunction(torch.autograd.Function):
Usage:
new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))

This is required to guide GSPMD sharding propagation better during the
backward pass as during complicated workloads the compiler can introduce extra
This is required to guide GSPMD sharding propagation better during the
backward pass as during complicated workloads the compiler can introduce extra
collectives that can hurt performance.
"""
