import collections
from collections.abc import Generator, MutableMapping
import math
- import os
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
import torch
+ from torch import Tensor
+ from torch.library import custom_op
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla._internal.utils as _utils
from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
import torch_xla.runtime as xr
+ import torch_xla.debug.profiler as xp

import numpy as np
import functools
@@ -663,17 +665,106 @@ def apply(self, t: torch.Tensor):
    mark_sharding(t, self.mesh, self.partition_spec)


+ ### Linear layer implementation backed by einsum.
+
+
+ # A custom forward op that uses einsum internally
+ @custom_op(
+     "xla::einsum_linear_forward",
+     schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor",
+     mutates_args=())
+ def _einsum_linear_forward(input: Tensor, weight: Tensor,
+                            bias: Optional[Tensor]):
+   with xp.Trace('einsum_linear_forward'):
+     # TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
+     # decomposed when inside a custom op. This C++ op is an escape hatch to call
+     # XLA einsum without going through torch.einsum. We should remove this
+     # _einsum escape hatch when the linked bug is fixed.
+     product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+     if bias is not None:
+       return product + bias
+     return product
+
+
+ @_einsum_linear_forward.register_fake
+ def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
+                                 bias: Optional[Tensor]):
+   product = torch.einsum('...n,mn->...m', input, weight)
+   if bias is not None:
+     return product + bias
+   return product
+
+
+ @custom_op(
+     "xla::einsum_linear_backward",
+     schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
+     mutates_args=())
+ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
+                             bias: Optional[Tensor],
+                             needs_input_grad_input: bool,
+                             needs_input_grad_weight: bool,
+                             needs_input_grad_bias: bool):
+   with xp.Trace('einsum_linear_backward'):
+     grad_input = grad_weight = grad_bias = None
+
+     if needs_input_grad_input:
+       grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
+                                                (grad_output, weight))
+     else:
+       grad_input = None
+
+     if needs_input_grad_weight:
+       grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
+                                                 (grad_output, input))
+     else:
+       grad_weight = None
+
+     if bias is not None and needs_input_grad_bias:
+       grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
+     else:
+       grad_bias = None
+
+     return grad_input, grad_weight, grad_bias
+
+
+ @_einsum_linear_backward.register_fake
+ def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
+                                  weight: Tensor, bias: Optional[Tensor],
+                                  needs_input_grad_input: bool,
+                                  needs_input_grad_weight: bool,
+                                  needs_input_grad_bias: bool):
+   grad_input = grad_weight = grad_bias = None
+
+   if needs_input_grad_input:
+     grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
+   else:
+     grad_input = None
+
+   if needs_input_grad_weight:
+     grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
+   else:
+     grad_weight = None
+
+   if bias is not None and needs_input_grad_bias:
+     grad_bias = torch.einsum('...m->m', grad_output)
+   else:
+     grad_bias = None
+
+   return grad_input, grad_weight, grad_bias
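
# NOTE (editorial usage sketch, not part of this diff): the two custom ops above
# are registered under the `torch.ops.xla` namespace, so they can be invoked
# directly on XLA tensors. The tensor shapes below are illustrative assumptions.
def _example_einsum_linear_forward_call():
  device = xm.xla_device()
  x = torch.randn(4, 16, device=device)   # (batch, in_features)
  w = torch.randn(32, 16, device=device)  # (out_features, in_features)
  b = torch.randn(32, device=device)
  # Equivalent to torch.nn.functional.linear(x, w, b), but kept as an einsum so
  # the XLA compiler can propagate sharding annotations through the matmul.
  return torch.ops.xla.einsum_linear_forward(x, w, b)  # shape (4, 32)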
+
+
+ # Now define the XLAPatchedLinear function that uses the custom ops
class XLAPatchedLinear(torch.autograd.Function):
  """
  A patched version of `torch.nn.functional.linear` that uses einsum instead
  of torch.matmul which will flatten the tensors to 2D and collide the sharded
  dimensions. The torch.matmul default behavior makes it very hard for XLA compiler
  to propagate the sharding annotation.

-   Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within
+   Autocast decorators @custom_fwd and @custom_bwd used as per autocast docs [1] to bring this class/layer within
  autocast context, when autocast is enabled.
  torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').
-
+
  References:
  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
  [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
@@ -683,35 +774,71 @@ class XLAPatchedLinear(torch.autograd.Function):

  @staticmethod
  @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
-   def forward(ctx, input, weight, bias=None):
-     # bias is an optional argument
+   def forward(ctx,
+               input: Tensor,
+               weight: Tensor,
+               bias: Optional[Tensor] = None):
    ctx.save_for_backward(input, weight, bias)
-     with torch.no_grad():
-       product = torch.einsum('...n,mn->...m', input, weight)
-       if bias is None:
-         return product
-       return product + bias
+     # Call our custom forward op. By wrapping the einsum in custom ops,
+     # AOTAutograd won't decompose the einsum.
+     return torch.ops.xla.einsum_linear_forward(input, weight, bias)

  @staticmethod
  @custom_bwd(device_type='xla')
-   def backward(ctx, grad_output):
+   def backward(ctx, grad_output: Tensor):
    input, weight, bias = ctx.saved_tensors
-     grad_input = grad_weight = grad_bias = None
+     needs_input_grad_input = ctx.needs_input_grad[0]
+     needs_input_grad_weight = ctx.needs_input_grad[1]
+     needs_input_grad_bias = False
+     if bias is not None:
+       needs_input_grad_bias = ctx.needs_input_grad[2]

-     if ctx.needs_input_grad[0]:
-       grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
-     if ctx.needs_input_grad[1]:
-       grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
-     if bias is not None and ctx.needs_input_grad[2]:
-       grad_bias = torch.einsum('...m->m', grad_output)
-
-     return grad_input, grad_weight, grad_bias
+     # Call our custom backward op with the boolean flags
+     grad_input, grad_weight, grad_bias = torch.ops.xla.einsum_linear_backward(
+         grad_output, input, weight, bias, needs_input_grad_input,
+         needs_input_grad_weight, needs_input_grad_bias)
+     return grad_input, grad_weight, grad_bias, None


def xla_patched_nn_linear_forward(m, input):
  return XLAPatchedLinear.apply(input, m.weight, m.bias)
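
# NOTE (editorial usage sketch, not part of this diff): XLAPatchedLinear.apply
# is a drop-in for torch.nn.functional.linear. The mesh argument and the (0, 1)
# partition spec below are illustrative assumptions; per the docstring above,
# running this under torch.autocast(device_type='xla') would have @custom_fwd
# cast the inputs to the autocast dtype.
def _example_patched_linear(mesh):
  device = xm.xla_device()
  x = torch.randn(8, 128, device=device)
  w = torch.randn(256, 128, device=device, requires_grad=True)
  # Shard the weight over the assumed 2D device mesh; the einsum-based forward
  # keeps this annotation visible to XLA's sharding propagation.
  mark_sharding(w, mesh, (0, 1))
  return XLAPatchedLinear.apply(x, w, None)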


+ class EinsumLinear(torch.nn.Linear):
+   """
+   A `torch.nn.Linear` subclass implemented with `einsum`.
+   """
+
+   def __init__(self, *args, **kwargs):
+     super().__init__(*args, **kwargs)
+
+   def forward(self, input):
+     t = xla_patched_nn_linear_forward(self, input)
+     assert isinstance(t, torch.Tensor)
+     return t
+
+
+ def apply_xla_patch_to_nn_linear(module: torch.nn.Module):
+   """
+   Recursively replace `nn.Linear` layers with `EinsumLinear` in the module.
+
+   Without this patch, an `nn.Linear` module in PyTorch/XLA will lower to reshapes
+   and transposes instead of einsum, thus compromising sharding propagation.
+   """
+   for name, child in module.named_children():
+     if isinstance(child,
+                   torch.nn.Linear) and not isinstance(child, EinsumLinear):
+       einsum_linear = EinsumLinear(
+           child.in_features, child.out_features, bias=child.bias is not None)
+       einsum_linear.load_state_dict(
+           child.state_dict(), strict=True, assign=True)
+       setattr(module, name, einsum_linear)
+     else:
+       apply_xla_patch_to_nn_linear(child)
+
+   return module
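
# NOTE (editorial usage sketch, not part of this diff): patch every nn.Linear
# in a model so its matmuls lower to einsum and keep sharding annotations. The
# toy model below is an illustrative assumption.
def _example_apply_patch():
  model = torch.nn.Sequential(
      torch.nn.Linear(128, 256),
      torch.nn.ReLU(),
      torch.nn.Linear(256, 10),
  )
  model = apply_xla_patch_to_nn_linear(model)
  # Every nn.Linear child is now an EinsumLinear holding the original weights.
  assert all(
      isinstance(m, EinsumLinear)
      for m in model.modules()
      if isinstance(m, torch.nn.Linear))
  return model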
+
+
def apply_backward_optimization_barrier(m: torch.nn.Module):
  """
  Register a full backward hook that applies an optimization barrier to the given module.