Introduce apply_xla_patch_to_nn_linear and test that in a scan #8739

Merged 1 commit on Feb 25, 2025
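
A minimal usage sketch of the new API, pieced together from the test added in test/scan/test_scan_spmd.py below; the SPMD mesh setup, axis name, and tensor shapes are illustrative assumptions, not part of the PR.

# Minimal sketch based on test/scan/test_scan_spmd.py in this PR.
# The SPMD setup, axis name, and shapes are illustrative assumptions.
import torch
import torch.nn as nn
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import set_global_mesh, get_1d_mesh
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
from torch_xla.experimental.scan_layers import scan_layers

xr.use_spmd()
set_global_mesh(get_1d_mesh("data"))

# A stack of homogeneous linear layers, as in the test below.
layers = nn.Sequential(*[nn.Linear(128, 128) for _ in range(4)]).to('xla')

# Replace each nn.Linear child with an einsum-backed EinsumLinear so the
# lowered HLO keeps a high-dimensional dot instead of reshapes/transposes.
layers = apply_xla_patch_to_nn_linear(layers)

# High-dimensional input whose last dim is the contraction dim.
x = torch.randn((3, 4, 5, 128), device='xla')
y = scan_layers(layers, x)  # trace the layers with scan
torch_xla.sync()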
83 changes: 83 additions & 0 deletions test/scan/test_scan_spmd.py
@@ -1,9 +1,14 @@
from copy import deepcopy
import sys
import re
import unittest

import torch
import torch_xla
import torch.nn as nn
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
from torch_xla.experimental.scan import scan
from torch_xla.experimental.scan_layers import scan_layers
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
import torch_xla.runtime as xr

@@ -54,6 +59,84 @@ def fn(carry, x):
f'devices=[1,{N}]0,',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required")
def test_scan_xla_patched_linear(self):
"""
When we use scan to trace `XLAPatchedLinear` layers, the lowered HLO should
consist of einsum instead of reshapes and transposes. This is important for
sharding constraint propagation.
"""

# Create a model with a few linear layers.
class MyModel(nn.Module):

def __init__(self):
super().__init__()
self.layers = nn.Sequential(*[nn.Linear(128, 128) for _ in range(4)])
self.use_scan = True

def forward(self, x: torch.Tensor):
if self.use_scan:
return scan_layers(self.layers, x)
else:
return self.layers(x)

model = MyModel().to('xla')
# High dimensional input whose last dim is the contraction dim.
torch_xla.manual_seed(42)
x = torch.randn((3, 4, 5, 128), device='xla')
torch_xla.sync()

# If we trace the `nn.Linear` without applying the einsum patch, the lowered
# HLO will contain a `dot` operation where the input is flattened to 2D:
# the `3, 4, 5, 128` shape is flattened to `60, 128`. This destroys any sharding
# constraint applied to the first 3 dims.
self.check_dots_in_model(
model, x, expect_pattern=r'%dot\.\d+ = f32\[60,128\]')

# Once we patch the `nn.Linear` modules to use `einsum` and ensure that einsum is
# lowered without getting unnecessarily decomposed, the HLO should contain a
# `dot` operation that preserves the high dimensional structure. In turn, the
# compiler will be able to preserve the sharding constraints on those dimensions.
model = apply_xla_patch_to_nn_linear(model)
self.check_dots_in_model(
model, x, expect_pattern=r'%dot\.\d+ = f32\[3,4,5,128\]')

# Finally, test the numerics against an eager CPU impl.
x = x.bfloat16()
model = model.bfloat16()
model_cpu = MyModel().bfloat16()
model_cpu.load_state_dict(model.state_dict())
model_cpu.to('cpu')
model_cpu.use_scan = False
torch_xla.sync()
y_cpu = model_cpu(x.cpu())
y_xla = model(x)

torch_xla.sync()
torch.testing.assert_close(y_cpu, y_xla.cpu())

def check_dots_in_model(self, model, x, expect_pattern):
# Trace the model to get the HLO.
y = model(x)
hlo_text: str = torch_xla._XLAC._get_xla_tensors_hlo([y])

count = self.count_regex(hlo_text, expect_pattern)
assert count == 0 or count == 1, f"count = {count}"

if count == 1:
# This is the expected case.
pass
else:
raise RuntimeError(
f"""Expected `nn.Linear` lowering to contain `{expect_pattern}`. Full HLO:
{hlo_text}
""")

def count_regex(self, hlo_text, regex_str):
return len(re.findall(regex_str, hlo_text))


if __name__ == '__main__':
test = unittest.main()
54 changes: 54 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1421,6 +1421,60 @@ def test_fallback(self):
self.assertIn("Data Shape: c64[2048,8]\n OpSharding: {replicated}",
torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))

def test_xla_patched_linear(self):
"""
Test the numerical accuracy of XLAPatchedLinear.
"""

from torch_xla.distributed.spmd.xla_sharding import XLAPatchedLinear
import torch_xla.runtime
import torch.nn.functional as F

with torch_xla.runtime.xla_device():
torch_xla.manual_seed(42)
x0 = torch.randn(2, 3, requires_grad=True)
w0 = torch.randn(4, 3, requires_grad=True)
b0 = torch.randn(4, requires_grad=True)
torch_xla.sync()

# Run `XLAPatchedLinear`.

x = x0.clone().detach().requires_grad_()
w = w0.clone().detach().requires_grad_()
b = b0.clone().detach().requires_grad_()

y = XLAPatchedLinear.apply(x, w, b)
assert y is not None
loss = y.sum()
loss.backward()
torch_xla.sync()

assert x.grad is not None
assert w.grad is not None
assert b.grad is not None
y1, xg1, wg1, bg1 = y.clone().detach(), x.grad.clone().detach(
), w.grad.clone().detach(), b.grad.clone().detach()

# Compare with `F.linear`.

x = x0.clone().detach().requires_grad_()
w = w0.clone().detach().requires_grad_()
b = b0.clone().detach().requires_grad_()

y = F.linear(x, w, b)
loss = y.sum()
loss.backward()

assert x.grad is not None
assert w.grad is not None
assert b.grad is not None
y2, xg2, wg2, bg2 = y.clone().detach(), x.grad.clone().detach(
), w.grad.clone().detach(), b.grad.clone().detach()
torch.testing.assert_close(y1, y2)
torch.testing.assert_close(xg1, xg2)
torch.testing.assert_close(wg1, wg2)
torch.testing.assert_close(bg1, bg2)


if __name__ == '__main__':
test = unittest.main()
10 changes: 10 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1848,6 +1848,16 @@ void InitXlaModuleBindings(py::module m) {
});
m.def("_xla_optimization_barrier_",
[](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });
// TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
// decomposed when inside a custom op. This C++ op is an escape hatch to call
// XLA einsum without going through torch.einsum. We should remove this
// operation when the linked bug is fixed.
m.def("_xla_einsum",
[](const std::string& equation, const std::vector<at::Tensor>& inputs) {
std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(inputs);
XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
return bridge::AtenFromXlaTensor(output);
});
m.def("_xla_set_default_device", [](const std::string& device) {
return SetCurrentThreadDevice(device);
});
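For reference, a minimal sketch of how the new `_xla_einsum` binding is invoked from Python; it mirrors the call made by the custom ops in xla_sharding.py below, and the tensor shapes here are illustrative assumptions.

# Minimal sketch of calling the new binding directly; shapes are illustrative.
import torch
import torch_xla

a = torch.randn(3, 4, 5, 128, device='xla')
w = torch.randn(128, 128, device='xla')

# Same calling convention as the custom ops below: an einsum equation string
# plus a tuple of input tensors, returning a single XLA tensor.
out = torch_xla._XLAC._xla_einsum('...n,mn->...m', (a, w))
torch_xla.sync()
print(out.shape)  # torch.Size([3, 4, 5, 128])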
167 changes: 147 additions & 20 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -1,15 +1,17 @@
import collections
from collections.abc import Generator, MutableMapping
import math
import os
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
import torch
from torch import Tensor
from torch.library import custom_op
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla._internal.utils as _utils
from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp

import numpy as np
import functools
@@ -663,17 +665,106 @@ def apply(self, t: torch.Tensor):
mark_sharding(t, self.mesh, self.partition_spec)


### Linear layer implementation backed by einsum.


# A custom forward op that uses einsum internally
@custom_op(
"xla::einsum_linear_forward",
schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor",
mutates_args=())
def _einsum_linear_forward(input: Tensor, weight: Tensor,
bias: Optional[Tensor]):
with xp.Trace('einsum_linear_forward'):
# TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
# decomposed when inside a custom op. This C++ op is an escape hatch to call
# XLA einsum without going through torch.einsum. We should remove this
# _einsum escape hatch when the linked bug is fixed.
product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
if bias is not None:
return product + bias
return product


@_einsum_linear_forward.register_fake
def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
bias: Optional[Tensor]):
product = torch.einsum('...n,mn->...m', input, weight)
if bias is not None:
return product + bias
return product


@custom_op(
"xla::einsum_linear_backward",
schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
mutates_args=())
def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
bias: Optional[Tensor],
needs_input_grad_input: bool,
needs_input_grad_weight: bool,
needs_input_grad_bias: bool):
with xp.Trace('einsum_linear_backward'):
grad_input = grad_weight = grad_bias = None

if needs_input_grad_input:
grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
(grad_output, weight))
else:
grad_input = None

if needs_input_grad_weight:
grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
(grad_output, input))
else:
grad_weight = None

if bias is not None and needs_input_grad_bias:
grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
else:
grad_bias = None

return grad_input, grad_weight, grad_bias


@_einsum_linear_backward.register_fake
def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
weight: Tensor, bias: Optional[Tensor],
needs_input_grad_input: bool,
needs_input_grad_weight: bool,
needs_input_grad_bias: bool):
grad_input = grad_weight = grad_bias = None

if needs_input_grad_input:
grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
else:
grad_input = None

if needs_input_grad_weight:
grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
else:
grad_weight = None

if bias is not None and needs_input_grad_bias:
grad_bias = torch.einsum('...m->m', grad_output)
else:
grad_bias = None

return grad_input, grad_weight, grad_bias


# Now define the XLAPatchedLinear function that uses the custom ops
class XLAPatchedLinear(torch.autograd.Function):
"""
A patched version of `torch.nn.functional.linear` that uses einsum instead
of torch.matmul, which flattens the tensors to 2D and collapses the sharded
dimensions. That default torch.matmul behavior makes it very hard for the XLA
compiler to propagate the sharding annotation.

Autocast decorators @custom_fwd and @custom_bwd are used as per the autocast docs [1] to bring this class/layer within
autocast context, when autocast is enabled.
torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').

References:
[1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
[2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
@@ -683,35 +774,71 @@ class XLAPatchedLinear(torch.autograd.Function):

@staticmethod
@custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
def forward(ctx, input, weight, bias=None):
# bias is an optional argument
def forward(ctx,
input: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None):
ctx.save_for_backward(input, weight, bias)
with torch.no_grad():
product = torch.einsum('...n,mn->...m', input, weight)
if bias is None:
return product
return product + bias
# Call our custom forward op. By wrapping the einsum in custom ops,
# AOTAutograd won't decompose the einsum.
return torch.ops.xla.einsum_linear_forward(input, weight, bias)

@staticmethod
@custom_bwd(device_type='xla')
def backward(ctx, grad_output):
def backward(ctx, grad_output: Tensor):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
needs_input_grad_input = ctx.needs_input_grad[0]
needs_input_grad_weight = ctx.needs_input_grad[1]
needs_input_grad_bias = False
if bias is not None:
needs_input_grad_bias = ctx.needs_input_grad[2]

if ctx.needs_input_grad[0]:
grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
if ctx.needs_input_grad[1]:
grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = torch.einsum('...m->m', grad_output)

return grad_input, grad_weight, grad_bias
# Call our custom backward op with the boolean flags
grad_input, grad_weight, grad_bias = torch.ops.xla.einsum_linear_backward(
grad_output, input, weight, bias, needs_input_grad_input,
needs_input_grad_weight, needs_input_grad_bias)
return grad_input, grad_weight, grad_bias, None


def xla_patched_nn_linear_forward(m, input):
return XLAPatchedLinear.apply(input, m.weight, m.bias)


class EinsumLinear(torch.nn.Linear):
"""
A `torch.nn.Linear` subclass implemented with `einsum`.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def forward(self, input):
t = xla_patched_nn_linear_forward(self, input)
assert isinstance(t, torch.Tensor)
return t


def apply_xla_patch_to_nn_linear(module: torch.nn.Module):
"""
Recursively replace `nn.Linear` layers with `EinsumLinear` in the module.

Without this patch, an `nn.Linear` module in PyTorch/XLA will lower to reshapes
and transposes instead of einsum, thus compromising sharding propagation.
"""
for name, child in module.named_children():
if isinstance(child,
torch.nn.Linear) and not isinstance(child, EinsumLinear):
einsum_linear = EinsumLinear(
child.in_features, child.out_features, bias=child.bias is not None)
einsum_linear.load_state_dict(
child.state_dict(), strict=True, assign=True)
setattr(module, name, einsum_linear)
else:
apply_xla_patch_to_nn_linear(child)

return module


def apply_backward_optimization_barrier(m: torch.nn.Module):
"""
Register a full backward hook that applies an optimization barrier to the given module.
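The XLAPatchedLinear docstring above notes that @custom_fwd/@custom_bwd bring the patched linear under autocast. A minimal sketch of exercising it inside an XLA autocast region; the wrapper module, bfloat16 dtype, and shapes are illustrative assumptions.

# Minimal sketch; the wrapper module, bfloat16 dtype, and shapes are assumptions.
import torch
import torch.nn as nn
import torch_xla
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear

# apply_xla_patch_to_nn_linear replaces nn.Linear *children*, so wrap the
# layer in a container module.
model = apply_xla_patch_to_nn_linear(nn.Sequential(nn.Linear(128, 64))).to('xla')
x = torch.randn(16, 128, device='xla')

with torch.autocast('xla', dtype=torch.bfloat16):
  y = model(x)  # custom_fwd casts inputs to torch.get_autocast_dtype('xla')

y.sum().backward()  # routed through xla::einsum_linear_backward
torch_xla.sync()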