
Commit c53acea

Support einsum layers in a scan
In order to propagate sharding annotations under 2D sharding, linear layers should be implemented with einsum instead of transposes/reshapes. Additionally, they need to continue to work inside scan/scan_layers. For this to work we need three pieces:

- I added an `apply_xla_patch_to_nn_linear` function to replace the implementation of `nn.Linear` with einsum (calling `XLAPatchedLinear`).
- The `XLAPatchedLinear` implementation is wrapped in torch custom ops, because AOTAutograd (used by scan) decomposes all einsums into transposes/reshapes unless we use `@custom_op` to mark a function as opaque to AOTAutograd.
- Even after wrapping with `@custom_op`, the einsum is still decomposed into transposes/reshapes due to #8713, a bug/PyTorch limitation. To work around this, I added a `_xla_einsum` C++ function that builds an einsum directly from XLA tensors, skipping any PyTorch dispatcher complexity.

Also added a test that demonstrates how `nn.Linear` layers by default flatten any non-contracting dims, and how we can avoid that with `apply_xla_patch_to_nn_linear`.
1 parent 1acc987 commit c53acea
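
A condensed usage sketch of how the pieces fit together, adapted from the test added in this commit (it assumes an XLA device is available; `MyModel` is just the toy model from the test):

import torch
import torch.nn as nn
import torch_xla
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
from torch_xla.experimental.scan_layers import scan_layers

class MyModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(4)])

  def forward(self, x):
    # scan_layers traces one layer and reuses the trace for the rest.
    return scan_layers(self.layers, x)

# Swap every nn.Linear for the einsum-backed EinsumLinear, then run as usual.
model = apply_xla_patch_to_nn_linear(MyModel().to('xla'))
x = torch.ones((3, 4, 5, 128), device='xla')
y = model(x)
# The lowered HLO now keeps the (3, 4, 5) leading dims in the dot instead of
# flattening them to 60, so sharding annotations on those dims can propagate.
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))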

File tree

3 files changed (+221, -21 lines)

test/scan/test_scan_spmd.py

Lines changed: 63 additions & 0 deletions
@@ -1,9 +1,13 @@
 import sys
+import re
 import unittest

 import torch
 import torch_xla
+import torch.nn as nn
+from torch_xla.distributed.spmd.xla_sharding import xla_patched_nn_linear_forward, apply_xla_patch_to_nn_linear
 from torch_xla.experimental.scan import scan
+from torch_xla.experimental.scan_layers import scan_layers
 from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
 import torch_xla.runtime as xr

@@ -54,6 +58,65 @@ def fn(carry, x):
           f'devices=[1,{N}]0,',
           torch_xla._XLAC._get_xla_tensor_debug_info(tensor))

+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_xla_patched_linear(self):
+    """
+    When we use scan to trace `XLAPatchedLinear` layers, the lowered HLO should
+    consist of einsum instead of reshapes and transposes. This is important for
+    sharding constraint propagation.
+    """
+
+    # Create a model with a few linear layers.
+    class MyModel(nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(4)])
+
+      def forward(self, x: torch.Tensor):
+        return scan_layers(self.layers, x)
+
+    model = MyModel().to('xla')
+    # High dimensional input whose last dim is the contraction dim.
+    x = torch.ones((3, 4, 5, 128), device='xla')
+    torch_xla.sync()
+
+    # If we trace the `nn.Linear` without applying the einsum patch, the lowered
+    # HLO will contain a `dot` operation where the input is flattened to 2D:
+    # the `3, 4, 5, 128` shape is flattened to `60, 128`. This destroys any sharding
+    # constraint applied to the first 3 dims.
+    self.check_dots_in_model(
+        model, x, expect_pattern=r'%dot\.\d+ = f32\[60,128\]')
+
+    # Once we patch the `nn.Linear` modules to use `einsum` and ensure that einsum is
+    # lowered without getting unnecessarily decomposed, the HLO should contain a
+    # `dot` operation that preserves the high dimensional structure. In turn, the
+    # compiler will be able to preserve the sharding constraints on those dimensions.
+    model = apply_xla_patch_to_nn_linear(model)
+    self.check_dots_in_model(
+        model, x, expect_pattern=r'%dot\.\d+ = f32\[3,4,5,128\]')
+
+  def check_dots_in_model(self, model, x, expect_pattern):
+    # Trace the model to get the HLO.
+    y = model(x)
+    hlo_text: str = torch_xla._XLAC._get_xla_tensors_hlo([y])
+
+    count = self.count_regex(hlo_text, expect_pattern)
+    assert count == 0 or count == 1, f"count = {count}"
+
+    if count == 1:
+      # This is the expected case.
+      pass
+    else:
+      raise RuntimeError(
+          f"""Expected `nn.Linear` lowering to contain `{expect_pattern}`. Full HLO:
+{hlo_text}
+""")
+
+  def count_regex(self, hlo_text, regex_str):
+    return len(re.findall(regex_str, hlo_text))
+

 if __name__ == '__main__':
   test = unittest.main()

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
@@ -1848,6 +1848,16 @@ void InitXlaModuleBindings(py::module m) {
       });
   m.def("_xla_optimization_barrier_",
         [](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
+  // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
+  // decomposed when inside a custom op. This C++ op is an escape hatch to call
+  // XLA einsum without going through torch.einsum. We should remove this
+  // operation when the linked bug is fixed.
+  m.def("_xla_einsum",
+        [](const std::string& equation, const std::vector<at::Tensor>& inputs) {
+          std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
+          XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
+          return bridge::AtenFromXlaTensor(output);
+        });
   m.def("_xla_set_default_device", [](const std::string& device) {
     return SetCurrentThreadDevice(device);
   });
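
For reference, the new `_xla_einsum` binding takes an einsum equation plus a sequence of XLA tensors and returns an XLA-backed tensor. A minimal sketch of calling it from Python, mirroring how the patched linear layer below uses it (assumes an XLA device is available):

import torch
import torch_xla

inp = torch.randn(3, 4, 5, 128, device='xla')
weight = torch.randn(64, 128, device='xla')
# Builds the einsum directly on the XLA tensors, bypassing the PyTorch
# dispatcher, so it is not decomposed into transposes/reshapes (see #8713).
out = torch_xla._XLAC._xla_einsum('...n,mn->...m', (inp, weight))
print(out.shape)  # torch.Size([3, 4, 5, 64])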

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 148 additions & 21 deletions
@@ -1,15 +1,17 @@
 import collections
 from collections.abc import Generator, MutableMapping
 import math
-import os
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 import torch
+from torch import Tensor
+from torch.library import custom_op
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla._internal.utils as _utils
 from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
 import torch_xla.runtime as xr
+import torch_xla.debug.profiler as xp

 import numpy as np
 import functools

@@ -663,17 +665,106 @@ def apply(self, t: torch.Tensor):
     mark_sharding(t, self.mesh, self.partition_spec)


+### Linear layer implementation backed by einsum.
+
+
+# A custom forward op that uses einsum internally
+@custom_op(
+    "xla::einsum_linear_forward",
+    schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor",
+    mutates_args=())
+def _einsum_linear_forward(input: Tensor, weight: Tensor,
+                           bias: Optional[Tensor]):
+  with xp.Trace('einsum_linear_forward'):
+    # TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
+    # decomposed when inside a custom op. This C++ op is an escape hatch to call
+    # XLA einsum without going through torch.einsum. We should remove this
+    # _einsum escape hatch when the linked bug is fixed.
+    product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+    if bias is not None:
+      return product + bias
+    return product
+
+
+@_einsum_linear_forward.register_fake
+def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
+                                bias: Optional[Tensor]):
+  product = torch.einsum('...n,mn->...m', input, weight)
+  if bias is not None:
+    return product + bias
+  return product
+
+
+@custom_op(
+    "xla::einsum_linear_backward",
+    schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
+    mutates_args=())
+def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
+                            bias: Optional[Tensor],
+                            needs_input_grad_input: bool,
+                            needs_input_grad_weight: bool,
+                            needs_input_grad_bias: bool):
+  with xp.Trace('einsum_linear_backward'):
+    grad_input = grad_weight = grad_bias = None
+
+    if needs_input_grad_input:
+      grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
+                                               (grad_output, weight))
+    else:
+      grad_input = None
+
+    if needs_input_grad_weight:
+      grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
+                                                (grad_output, input))
+    else:
+      grad_weight = None
+
+    if bias is not None and needs_input_grad_bias:
+      grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
+    else:
+      grad_bias = None
+
+    return grad_input, grad_weight, grad_bias
+
+
+@_einsum_linear_backward.register_fake
+def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
+                                 weight: Tensor, bias: Optional[Tensor],
+                                 needs_input_grad_input: bool,
+                                 needs_input_grad_weight: bool,
+                                 needs_input_grad_bias: bool):
+  grad_input = grad_weight = grad_bias = None
+
+  if needs_input_grad_input:
+    grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
+  else:
+    grad_input = None
+
+  if needs_input_grad_weight:
+    grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
+  else:
+    grad_weight = None
+
+  if bias is not None and needs_input_grad_bias:
+    grad_bias = torch.einsum('...m->m', grad_output)
+  else:
+    grad_bias = None
+
+  return grad_input, grad_weight, grad_bias
+
+
+# Now define the XLAPatchedLinear function that uses the custom ops
 class XLAPatchedLinear(torch.autograd.Function):
   """
   A patched version of `torch.nn.functional.linear` that uses einsum instead
   of torch.matmul which will flatten the tensors to 2D and collide the sharded
   dimensions. The torch.matmul default behavior makes it very hard for XLA compiler
   to propagate the sharding annotation.

-  Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within
+  Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within
   autocast context, when autocast is enabled.
   torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').
-
+
   References:
   [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
   [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500

@@ -683,35 +774,71 @@ class XLAPatchedLinear(torch.autograd.Function):

   @staticmethod
   @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
-  def forward(ctx, input, weight, bias=None):
-    # bias is an optional argument
+  def forward(ctx,
+              input: Tensor,
+              weight: Tensor,
+              bias: Optional[Tensor] = None):
     ctx.save_for_backward(input, weight, bias)
-    with torch.no_grad():
-      product = torch.einsum('...n,mn->...m', input, weight)
-    if bias is None:
-      return product
-    return product + bias
+    # Call our custom forward op. By wrapping the einsum in custom ops,
+    # AOTAutograd won't decompose the einsum.
+    return torch.ops.xla.einsum_linear_forward(input, weight, bias)

   @staticmethod
-  @custom_bwd(device_type='xla')
-  def backward(ctx, grad_output):
+  @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
+  def backward(ctx, grad_output: Tensor):
     input, weight, bias = ctx.saved_tensors
-    grad_input = grad_weight = grad_bias = None
+    needs_input_grad_input = ctx.needs_input_grad[0]
+    needs_input_grad_weight = ctx.needs_input_grad[1]
+    needs_input_grad_bias = False
+    if bias is not None:
+      needs_input_grad_bias = ctx.needs_input_grad[2]

-    if ctx.needs_input_grad[0]:
-      grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
-    if ctx.needs_input_grad[1]:
-      grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
-    if bias is not None and ctx.needs_input_grad[2]:
-      grad_bias = torch.einsum('...m->m', grad_output)
-
-    return grad_input, grad_weight, grad_bias
+    # Call our custom backward op with the boolean flags
+    grad_input, grad_weight, grad_bias = torch.ops.xla.einsum_linear_backward(
+        grad_output, input, weight, bias, needs_input_grad_input,
+        needs_input_grad_weight, needs_input_grad_bias)
+    return grad_input, grad_weight, grad_bias, None


 def xla_patched_nn_linear_forward(m, input):
   return XLAPatchedLinear.apply(input, m.weight, m.bias)


+class EinsumLinear(torch.nn.Linear):
+  """
+  A `torch.nn.Linear` subclass implemented with `einsum`.
+  """
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+
+  def forward(self, input):
+    t = xla_patched_nn_linear_forward(self, input)
+    assert isinstance(t, torch.Tensor)
+    return t
+
+
+def apply_xla_patch_to_nn_linear(module: torch.nn.Module):
+  """
+  Recursively replace `nn.Linear` layers with `EinsumLinear` in the module.
+
+  Without this patch, an `nn.Linear` module in PyTorch/XLA will lower to reshapes
+  and transposes instead of einsum, thus compromising sharding propagation.
+  """
+  for name, child in module.named_children():
+    if isinstance(child,
+                  torch.nn.Linear) and not isinstance(child, EinsumLinear):
+      einsum_linear = EinsumLinear(
+          child.in_features, child.out_features, bias=child.bias is not None)
+      einsum_linear.load_state_dict(
+          child.state_dict(), strict=True, assign=True)
+      setattr(module, name, einsum_linear)
+    else:
+      apply_xla_patch_to_nn_linear(child)
+
+  return module
+
+
 def apply_backward_optimization_barrier(m: torch.nn.Module):
   """
   Register a full backward hook that apply an optimization barrier to the given module.
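
A minimal sanity-check sketch (not part of this commit; `TwoLayer` is a made-up example module): because `apply_xla_patch_to_nn_linear` reuses the original parameters via `load_state_dict(..., assign=True)` and `'...n,mn->...m'` plus bias computes the same contraction as `torch.nn.functional.linear`, the patched model should match the unpatched one numerically, assuming an XLA device is available.

import torch
import torch.nn as nn
import torch_xla
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear

class TwoLayer(nn.Module):

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(128, 64)
    self.fc2 = nn.Linear(64, 32)

  def forward(self, x):
    return self.fc2(self.fc1(x))

model = TwoLayer().to('xla')
x = torch.randn(3, 4, 128, device='xla')

with torch.no_grad():
  expected = model(x)
  # Replaces fc1/fc2 with EinsumLinear in place and returns the same module.
  patched = apply_xla_patch_to_nn_linear(model)
  actual = patched(x)
  torch_xla.sync()
  assert torch.allclose(expected.cpu(), actual.cpu(), atol=1e-5)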
