Commit a54b0ed

tengyifei authored and pgmoka committed
Introduce apply_xla_patch_to_nn_linear and test that in a scan (#8739)
1 parent 53ed842 commit a54b0ed

File tree

4 files changed, +294 -20 lines changed
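
A minimal usage sketch of the API introduced in this commit, combined with scan_layers the same way the new test does. This is illustrative only and assumes an SPMD-enabled PyTorch/XLA setup; the model and shapes are made up for the example:

    import torch
    import torch.nn as nn
    import torch_xla
    from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
    from torch_xla.experimental.scan_layers import scan_layers

    # A stack of identical layers whose forward pass scan traces once.
    layers = nn.Sequential(*[nn.Linear(128, 128) for _ in range(4)]).to('xla')
    # Swap nn.Linear for the einsum-backed EinsumLinear so that sharding
    # constraints on the leading input dims survive the lowering.
    layers = apply_xla_patch_to_nn_linear(layers)

    x = torch.randn(3, 4, 5, 128, device='xla')
    y = scan_layers(layers, x)
    torch_xla.sync()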

test/scan/test_scan_spmd.py

Lines changed: 83 additions & 0 deletions

@@ -1,9 +1,14 @@
+from copy import deepcopy
 import sys
+import re
 import unittest
 
 import torch
 import torch_xla
+import torch.nn as nn
+from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
 from torch_xla.experimental.scan import scan
+from torch_xla.experimental.scan_layers import scan_layers
 from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
 import torch_xla.runtime as xr
 
@@ -54,6 +59,84 @@ def fn(carry, x):
         f'devices=[1,{N}]0,',
         torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
 
+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_xla_patched_linear(self):
+    """
+    When we use scan to trace `XLAPatchedLinear` layers, the lowered HLO should
+    consist of einsum instead of reshapes and transposes. This is important for
+    sharding constraint propagation.
+    """
+
+    # Create a model with a few linear layers.
+    class MyModel(nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(*[nn.Linear(128, 128) for _ in range(4)])
+        self.use_scan = True
+
+      def forward(self, x: torch.Tensor):
+        if self.use_scan:
+          return scan_layers(self.layers, x)
+        else:
+          return self.layers(x)
+
+    model = MyModel().to('xla')
+    # High dimensional input whose last dim is the contraction dim.
+    torch_xla.manual_seed(42)
+    x = torch.randn((3, 4, 5, 128), device='xla')
+    torch_xla.sync()
+
+    # If we trace the `nn.Linear` without applying the einsum patch, the lowered
+    # HLO will contain a `dot` operation where the input is flattened to 2D:
+    # the `3, 4, 5, 128` shape is flattened to `60, 128`. This destroys any sharding
+    # constraint applied to the first 3 dims.
+    self.check_dots_in_model(
+        model, x, expect_pattern=r'%dot\.\d+ = f32\[60,128\]')
+
+    # Once we patch the `nn.Linear` modules to use `einsum` and ensure that einsum is
+    # lowered without getting unnecessarily decomposed, the HLO should contain a
+    # `dot` operation that preserves the high dimensional structure. In turn, the
+    # compiler will be able to preserve the sharding constraints on those dimensions.
+    model = apply_xla_patch_to_nn_linear(model)
+    self.check_dots_in_model(
+        model, x, expect_pattern=r'%dot\.\d+ = f32\[3,4,5,128\]')
+
+    # Finally, test the numerics against an eager CPU impl.
+    x = x.bfloat16()
+    model = model.bfloat16()
+    model_cpu = MyModel().bfloat16()
+    model_cpu.load_state_dict(model.state_dict())
+    model_cpu.to('cpu')
+    model_cpu.use_scan = False
+    torch_xla.sync()
+    y_cpu = model_cpu(x.cpu())
+    y_xla = model(x)
+
+    torch_xla.sync()
+    torch.testing.assert_close(y_cpu, y_xla.cpu())
+
+  def check_dots_in_model(self, model, x, expect_pattern):
+    # Trace the model to get the HLO.
+    y = model(x)
+    hlo_text: str = torch_xla._XLAC._get_xla_tensors_hlo([y])
+
+    count = self.count_regex(hlo_text, expect_pattern)
+    assert count == 0 or count == 1, f"count = {count}"
+
+    if count == 1:
+      # This is the expected case.
+      pass
+    else:
+      raise RuntimeError(
+          f"""Expected `nn.Linear` lowering to contain `{expect_pattern}`. Full HLO:
+{hlo_text}
+""")
+
+  def count_regex(self, hlo_text, regex_str):
+    return len(re.findall(regex_str, hlo_text))
+
 
 if __name__ == '__main__':
   test = unittest.main()
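
As context for the two expect_pattern values above: 60 = 3 * 4 * 5, i.e. the unpatched lowering collapses the three leading batch dims of the (3, 4, 5, 128) input into a single dim before the dot. A minimal CPU-only sketch of the shape difference (illustrative, not part of the commit):

    import torch

    x = torch.randn(3, 4, 5, 128)
    w = torch.randn(128, 128)

    # Flattened path: leading dims are collapsed to 60 = 3 * 4 * 5 before the
    # matmul, which is what the f32[60,128] dot in the unpatched HLO reflects.
    y_flat = (x.reshape(-1, 128) @ w.t()).reshape(3, 4, 5, 128)

    # Einsum path: the contraction keeps the leading dims, matching the
    # f32[3,4,5,128] dot expected after the patch.
    y_einsum = torch.einsum('...n,mn->...m', x, w)

    torch.testing.assert_close(y_flat, y_einsum)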

test/spmd/test_xla_sharding.py

Lines changed: 54 additions & 0 deletions

@@ -1421,6 +1421,60 @@ def test_fallback(self):
     self.assertIn("Data Shape: c64[2048,8]\n OpSharding: {replicated}",
                   torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
 
+  def test_xla_patched_linear(self):
+    """
+    Test the numerical accuracy of XLAPatchedLinear.
+    """
+
+    from torch_xla.distributed.spmd.xla_sharding import XLAPatchedLinear
+    import torch_xla.runtime
+    import torch.nn.functional as F
+
+    with torch_xla.runtime.xla_device():
+      torch_xla.manual_seed(42)
+      x0 = torch.randn(2, 3, requires_grad=True)
+      w0 = torch.randn(4, 3, requires_grad=True)
+      b0 = torch.randn(4, requires_grad=True)
+      torch_xla.sync()
+
+      # Run `XLAPatchedLinear`.
+
+      x = x0.clone().detach().requires_grad_()
+      w = w0.clone().detach().requires_grad_()
+      b = b0.clone().detach().requires_grad_()
+
+      y = XLAPatchedLinear.apply(x, w, b)
+      assert y is not None
+      loss = y.sum()
+      loss.backward()
+      torch_xla.sync()
+
+      assert x.grad is not None
+      assert w.grad is not None
+      assert b.grad is not None
+      y1, xg1, wg1, bg1 = y.clone().detach(), x.grad.clone().detach(
+      ), w.grad.clone().detach(), b.grad.clone().detach()
+
+      # Compare with `F.linear`.
+
+      x = x0.clone().detach().requires_grad_()
+      w = w0.clone().detach().requires_grad_()
+      b = b0.clone().detach().requires_grad_()
+
+      y = F.linear(x, w, b)
+      loss = y.sum()
+      loss.backward()
+
+      assert x.grad is not None
+      assert w.grad is not None
+      assert b.grad is not None
+      y2, xg2, wg2, bg2 = y.clone().detach(), x.grad.clone().detach(
+      ), w.grad.clone().detach(), b.grad.clone().detach()
+      torch.testing.assert_close(y1, y2)
+      torch.testing.assert_close(xg1, xg2)
+      torch.testing.assert_close(wg1, wg2)
+      torch.testing.assert_close(bg1, bg2)
+
 
 if __name__ == '__main__':
   test = unittest.main()
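
The equivalence this test relies on, written out as a small CPU-only sketch (illustrative): F.linear computes x @ w.T + b, and the patched path expresses the same contraction as einsum('...n,mn->...m', x, w) + b:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3)
    w = torch.randn(4, 3)
    b = torch.randn(4)

    # '...n,mn->...m' contracts the last dim of x with the last dim of w,
    # which is exactly F.linear's x @ w.T + b.
    y_linear = F.linear(x, w, b)
    y_einsum = torch.einsum('...n,mn->...m', x, w) + b

    torch.testing.assert_close(y_linear, y_einsum)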

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions

@@ -1848,6 +1848,16 @@ void InitXlaModuleBindings(py::module m) {
   });
   m.def("_xla_optimization_barrier_",
         [](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
+  // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
+  // decomposed when inside a custom op. This C++ op is an escape hatch to call
+  // XLA einsum without going through torch.einsum. We should remove this
+  // operation when the linked bug is fixed.
+  m.def("_xla_einsum",
+        [](const std::string& equation, const std::vector<at::Tensor>& inputs) {
+          std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
+          XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
+          return bridge::AtenFromXlaTensor(output);
+        });
   m.def("_xla_set_default_device", [](const std::string& device) {
     return SetCurrentThreadDevice(device);
   });
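
On the Python side this binding is reached as torch_xla._XLAC._xla_einsum(equation, tensors), exactly as the patch in xla_sharding.py below uses it. A minimal sketch (illustrative; requires an XLA device):

    import torch
    import torch_xla

    x = torch.randn(3, 4, 5, 128, device='xla')
    w = torch.randn(128, 128, device='xla')

    # Calls XLA's einsum lowering directly, sidestepping the torch.einsum
    # decomposition issue described in the TODO above.
    y = torch_xla._XLAC._xla_einsum('...n,mn->...m', (x, w))
    torch_xla.sync()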

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 147 additions & 20 deletions

@@ -1,15 +1,17 @@
 import collections
 from collections.abc import Generator, MutableMapping
 import math
-import os
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 import torch
+from torch import Tensor
+from torch.library import custom_op
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla._internal.utils as _utils
 from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
 import torch_xla.runtime as xr
+import torch_xla.debug.profiler as xp
 
 import numpy as np
 import functools
@@ -663,17 +665,106 @@ def apply(self, t: torch.Tensor):
     mark_sharding(t, self.mesh, self.partition_spec)
 
 
+### Linear layer implementation backed by einsum.
+
+
+# A custom forward op that uses einsum internally
+@custom_op(
+    "xla::einsum_linear_forward",
+    schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor",
+    mutates_args=())
+def _einsum_linear_forward(input: Tensor, weight: Tensor,
+                           bias: Optional[Tensor]):
+  with xp.Trace('einsum_linear_forward'):
+    # TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
+    # decomposed when inside a custom op. This C++ op is an escape hatch to call
+    # XLA einsum without going through torch.einsum. We should remove this
+    # _einsum escape hatch when the linked bug is fixed.
+    product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+    if bias is not None:
+      return product + bias
+    return product
+
+
+@_einsum_linear_forward.register_fake
+def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
+                                bias: Optional[Tensor]):
+  product = torch.einsum('...n,mn->...m', input, weight)
+  if bias is not None:
+    return product + bias
+  return product
+
+
+@custom_op(
+    "xla::einsum_linear_backward",
+    schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
+    mutates_args=())
+def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
+                            bias: Optional[Tensor],
+                            needs_input_grad_input: bool,
+                            needs_input_grad_weight: bool,
+                            needs_input_grad_bias: bool):
+  with xp.Trace('einsum_linear_backward'):
+    grad_input = grad_weight = grad_bias = None
+
+    if needs_input_grad_input:
+      grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
+                                               (grad_output, weight))
+    else:
+      grad_input = None
+
+    if needs_input_grad_weight:
+      grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
+                                                (grad_output, input))
+    else:
+      grad_weight = None
+
+    if bias is not None and needs_input_grad_bias:
+      grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
+    else:
+      grad_bias = None
+
+    return grad_input, grad_weight, grad_bias
+
+
+@_einsum_linear_backward.register_fake
+def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
+                                 weight: Tensor, bias: Optional[Tensor],
+                                 needs_input_grad_input: bool,
+                                 needs_input_grad_weight: bool,
+                                 needs_input_grad_bias: bool):
+  grad_input = grad_weight = grad_bias = None
+
+  if needs_input_grad_input:
+    grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
+  else:
+    grad_input = None
+
+  if needs_input_grad_weight:
+    grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
+  else:
+    grad_weight = None
+
+  if bias is not None and needs_input_grad_bias:
+    grad_bias = torch.einsum('...m->m', grad_output)
+  else:
+    grad_bias = None
+
+  return grad_input, grad_weight, grad_bias
+
+
+# Now define the XLAPatchedLinear function that uses the custom ops
 class XLAPatchedLinear(torch.autograd.Function):
   """
   A patched version of `torch.nn.functional.linear` that uses einsum instead
   of torch.matmul which will flatten the tensors to 2D and collide the sharded
   dimensions. The torch.matmul default behavior makes it very hard for XLA compiler
   to propagate the sharding annotation.
 
-  Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within 
+  Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within
   autocast context, when autocast is enabled.
   torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').
-
+
   References:
   [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
   [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
683774

684775
@staticmethod
685776
@custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
686-
def forward(ctx, input, weight, bias=None):
687-
# bias is an optional argument
777+
def forward(ctx,
778+
input: Tensor,
779+
weight: Tensor,
780+
bias: Optional[Tensor] = None):
688781
ctx.save_for_backward(input, weight, bias)
689-
with torch.no_grad():
690-
product = torch.einsum('...n,mn->...m', input, weight)
691-
if bias is None:
692-
return product
693-
return product + bias
782+
# Call our custom forward op. By wrapping the einsum in custom ops,
783+
# AOTAutograd won't decompose the einsum.
784+
return torch.ops.xla.einsum_linear_forward(input, weight, bias)
694785

695786
@staticmethod
696787
@custom_bwd(device_type='xla')
697-
def backward(ctx, grad_output):
788+
def backward(ctx, grad_output: Tensor):
698789
input, weight, bias = ctx.saved_tensors
699-
grad_input = grad_weight = grad_bias = None
790+
needs_input_grad_input = ctx.needs_input_grad[0]
791+
needs_input_grad_weight = ctx.needs_input_grad[1]
792+
needs_input_grad_bias = False
793+
if bias is not None:
794+
needs_input_grad_bias = ctx.needs_input_grad[2]
700795

701-
if ctx.needs_input_grad[0]:
702-
grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
703-
if ctx.needs_input_grad[1]:
704-
grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
705-
if bias is not None and ctx.needs_input_grad[2]:
706-
grad_bias = torch.einsum('...m->m', grad_output)
707-
708-
return grad_input, grad_weight, grad_bias
796+
# Call our custom backward op with the boolean flags
797+
grad_input, grad_weight, grad_bias = torch.ops.xla.einsum_linear_backward(
798+
grad_output, input, weight, bias, needs_input_grad_input,
799+
needs_input_grad_weight, needs_input_grad_bias)
800+
return grad_input, grad_weight, grad_bias, None
709801

710802

711803
def xla_patched_nn_linear_forward(m, input):
712804
return XLAPatchedLinear.apply(input, m.weight, m.bias)
713805

714806

807+
class EinsumLinear(torch.nn.Linear):
808+
"""
809+
A `torch.nn.Linear` subclass implemented with `einsum`.
810+
"""
811+
812+
def __init__(self, *args, **kwargs):
813+
super().__init__(*args, **kwargs)
814+
815+
def forward(self, input):
816+
t = xla_patched_nn_linear_forward(self, input)
817+
assert isinstance(t, torch.Tensor)
818+
return t
819+
820+
821+
def apply_xla_patch_to_nn_linear(module: torch.nn.Module):
822+
"""
823+
Recursively replace `nn.Linear` layers with `EinsumLinear` in the module.
824+
825+
Without this patch, an `nn.Linear` module in PyTorch/XLA will lower to reshapes
826+
and transposes instead of einsum, thus compromising sharding propagation.
827+
"""
828+
for name, child in module.named_children():
829+
if isinstance(child,
830+
torch.nn.Linear) and not isinstance(child, EinsumLinear):
831+
einsum_linear = EinsumLinear(
832+
child.in_features, child.out_features, bias=child.bias is not None)
833+
einsum_linear.load_state_dict(
834+
child.state_dict(), strict=True, assign=True)
835+
setattr(module, name, einsum_linear)
836+
else:
837+
apply_xla_patch_to_nn_linear(child)
838+
839+
return module
840+
841+
715842
def apply_backward_optimization_barrier(m: torch.nn.Module):
716843
"""
717844
Register a full backward hook that apply an optimization barrier to the given module.
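
A small sanity-check sketch of the new helper (illustrative only; no forward pass is run, so no XLA device is needed beyond having torch_xla installed): after patching, each nn.Linear child is swapped for an EinsumLinear carrying the same parameters.

    import torch.nn as nn
    from torch_xla.distributed.spmd.xla_sharding import (
        EinsumLinear, apply_xla_patch_to_nn_linear)

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    original_weight = model[0].weight.detach().clone()

    model = apply_xla_patch_to_nn_linear(model)

    # Linear children are replaced; other modules are left untouched.
    assert isinstance(model[0], EinsumLinear)
    assert isinstance(model[1], nn.ReLU)
    assert isinstance(model[2], EinsumLinear)
    # Weights are carried over via load_state_dict(..., assign=True).
    assert (model[0].weight == original_weight).all()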
