@@ -676,11 +676,7 @@ def apply(self, t: torch.Tensor):
 def _einsum_linear_forward(input: Tensor, weight: Tensor,
                            bias: Optional[Tensor]):
   with xp.Trace('einsum_linear_forward'):
-    # TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
-    # decomposed when inside a custom op. This C++ op is an escape hatch to call
-    # XLA einsum without going through torch.einsum. We should remove this
-    # _einsum escape hatch when the linked bug is fixed.
-    product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+    product = torch.einsum('...n,mn->...m', input, weight)
     if bias is not None:
       return product + bias
     return product
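Note on the new forward call: the equation '...n,mn->...m' is just the einsum form of a linear layer (input @ weight.T), so replacing the private _xla_einsum escape hatch with torch.einsum does not change the math. A standalone sketch (not part of the diff; shapes are illustrative) that checks this on plain CPU tensors:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8)  # input, shape (..., n)
w = torch.randn(16, 8)    # weight, shape (m, n)
b = torch.randn(16)       # bias, shape (m,)

# '...n,mn->...m' contracts the last dim of x with the last dim of w,
# i.e. x @ w.T, which is F.linear without the bias term.
product = torch.einsum('...n,mn->...m', x, w)
assert torch.allclose(product + b, F.linear(x, w, b), atol=1e-5)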
@@ -695,6 +691,31 @@ def _einsum_linear_forward_fake(input: Tensor, weight: Tensor,
   return product
 
 
+def _einsum_linear_backward_operation(grad_output: Tensor, input: Tensor,
+                                      weight: Tensor, bias: Optional[Tensor],
+                                      needs_input_grad_input: bool,
+                                      needs_input_grad_weight: bool,
+                                      needs_input_grad_bias: bool):
+  grad_input = grad_weight = grad_bias = None
+
+  if needs_input_grad_input:
+    grad_input = torch.einsum('...m,mn->...n', grad_output, weight).clone()
+  else:
+    grad_input = None
+
+  if needs_input_grad_weight:
+    grad_weight = torch.einsum('...m,...n->mn', grad_output, input).clone()
+  else:
+    grad_weight = None
+
+  if bias is not None and needs_input_grad_bias:
+    grad_bias = torch.einsum('...m->m', grad_output).clone()
+  else:
+    grad_bias = None
+
+  return grad_input, grad_weight, grad_bias
+
+
 @custom_op(
     "xla::einsum_linear_backward",
     schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)",
@@ -705,26 +726,10 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
                             needs_input_grad_weight: bool,
                             needs_input_grad_bias: bool):
   with xp.Trace('einsum_linear_backward'):
-    grad_input = grad_weight = grad_bias = None
-
-    if needs_input_grad_input:
-      grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
-                                               (grad_output, weight))
-    else:
-      grad_input = None
-
-    if needs_input_grad_weight:
-      grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
-                                                (grad_output, input))
-    else:
-      grad_weight = None
-
-    if bias is not None and needs_input_grad_bias:
-      grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
-    else:
-      grad_bias = None
-
-    return grad_input, grad_weight, grad_bias
+    return _einsum_linear_backward_operation(grad_output, input, weight, bias,
+                                             needs_input_grad_input,
+                                             needs_input_grad_weight,
+                                             needs_input_grad_bias)
 
 
 @_einsum_linear_backward.register_fake
@@ -733,24 +738,11 @@ def _einsum_linear_backward_fake(grad_output: Tensor, input: Tensor,
                                  needs_input_grad_input: bool,
                                  needs_input_grad_weight: bool,
                                  needs_input_grad_bias: bool):
-  grad_input = grad_weight = grad_bias = None
-
-  if needs_input_grad_input:
-    grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
-  else:
-    grad_input = None
 
-  if needs_input_grad_weight:
-    grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
-  else:
-    grad_weight = None
-
-  if bias is not None and needs_input_grad_bias:
-    grad_bias = torch.einsum('...m->m', grad_output)
-  else:
-    grad_bias = None
-
-  return grad_input, grad_weight, grad_bias
+  return _einsum_linear_backward_operation(grad_output, input, weight, bias,
+                                           needs_input_grad_input,
+                                           needs_input_grad_weight,
+                                           needs_input_grad_bias)
 
 
 # Now define the XLAPatchedLinear function that uses the custom ops
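Both the real backward op above and its register_fake counterpart now delegate to the same helper, so the gradient math lives in one place and the fake (tracing) implementation cannot drift from the real one. A minimal, self-contained sketch of that pattern, assuming torch.library.custom_op (PyTorch 2.4+) and using a hypothetical mylib::scaled_linear op rather than anything from this PR:

import torch
from torch import Tensor
from torch.library import custom_op


def _scaled_linear_operation(input: Tensor, weight: Tensor, scale: float) -> Tensor:
  # Shared math: runs on real tensors for the actual op and on FakeTensors
  # (shape/dtype propagation only) when the fake implementation is traced.
  return torch.einsum('...n,mn->...m', input, weight) * scale


@custom_op("mylib::scaled_linear", mutates_args=())
def scaled_linear(input: Tensor, weight: Tensor, scale: float) -> Tensor:
  return _scaled_linear_operation(input, weight, scale)


@scaled_linear.register_fake
def _scaled_linear_fake(input: Tensor, weight: Tensor, scale: float) -> Tensor:
  return _scaled_linear_operation(input, weight, scale)


# Callable directly or through the dispatcher, e.g.:
#   y = scaled_linear(torch.randn(4, 8), torch.randn(16, 8), 2.0)
#   y = torch.ops.mylib.scaled_linear(torch.randn(4, 8), torch.randn(16, 8), 2.0)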
@@ -765,8 +757,8 @@ class XLAPatchedLinear(torch.autograd.Function):
   autocast context, when autocast is enabled.
   torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').
 
-  References:
-  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
+  References:
+  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
   [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500
 
   TODO (alanwaketan): Let's patch it on the dispatcher level.
@@ -1260,8 +1252,8 @@ class MarkShardingFunction(torch.autograd.Function):
   Usage:
   new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))
 
-  This is required to guide GSPMD sharding propagation better during the
-  backward pass as during complicated workloads the compiler can introduce extra
+  This is required to guide GSPMD sharding propagation better during the
+  backward pass as during complicated workloads the compiler can introduce extra
   collectives that can hurt performance.
   """
 