@@ -680,7 +680,7 @@ def _einsum_linear_forward(input: Tensor, weight: Tensor,
   # decomposed when inside a custom op. This C++ op is an escape hatch to call
   # XLA einsum without going through torch.einsum. We should remove this
   # _einsum escape hatch when the linked bug is fixed.
-  product = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
+  product = torch.einsum('...n,mn->...m', (input, weight))
   if bias is not None:
     return product + bias
   return product
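Note (illustrative, not part of the diff): the '...n,mn->...m' contraction above is just the usual linear-layer product, so it should agree with torch.nn.functional.linear. A minimal CPU sanity check with made-up shapes:

  import torch
  import torch.nn.functional as F

  x = torch.randn(2, 3, 5)   # (..., n)
  w = torch.randn(7, 5)      # (m, n)
  b = torch.randn(7)         # (m,)
  assert torch.allclose(torch.einsum('...n,mn->...m', x, w) + b,
                        F.linear(x, w, b), atol=1e-5)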
@@ -708,19 +708,17 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
   grad_input = grad_weight = grad_bias = None

   if needs_input_grad_input:
-    grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
-                                              (grad_output, weight))
+    grad_input = torch.einsum('...m,mn->...n', (grad_output, weight))
   else:
     grad_input = None

   if needs_input_grad_weight:
-    grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
-                                              (grad_output, input))
+    grad_weight = torch.einsum('...m,...n->mn', (grad_output, input))
   else:
     grad_weight = None

   if bias is not None and needs_input_grad_bias:
-    grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
+    grad_bias = torch.einsum('...m->m', (grad_output,))
   else:
     grad_bias = None

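Note (illustrative, not part of the diff): the three einsum equations above are the standard linear-layer gradients (grad_input contracts grad_output with the weight, grad_weight accumulates the outer product of grad_output and input over the batch dims, grad_bias sums grad_output over the batch dims). A CPU sanity check against autograd on F.linear, with made-up shapes:

  import torch
  import torch.nn.functional as F

  x = torch.randn(4, 3, 5, requires_grad=True)   # (..., n)
  w = torch.randn(7, 5, requires_grad=True)      # (m, n)
  b = torch.randn(7, requires_grad=True)         # (m,)
  g = torch.randn(4, 3, 7)                       # upstream grad, shape (..., m)

  F.linear(x, w, b).backward(g)
  assert torch.allclose(x.grad, torch.einsum('...m,mn->...n', g, w), atol=1e-4)
  assert torch.allclose(w.grad, torch.einsum('...m,...n->mn', g, x), atol=1e-4)
  assert torch.allclose(b.grad, torch.einsum('...m->m', g), atol=1e-4)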
@@ -765,8 +763,8 @@ class XLAPatchedLinear(torch.autograd.Function):
   autocast context, when autocast is enabled.
   torch.get_autocast_dtype() fetches datatype for ops run in autocast [2], with the specified device (here, 'xla').

-  References:
-  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
+  References:
+  [1] https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
   [2] https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/torch/amp/autocast_mode.py#L500

   TODO (alanwaketan): Let's patch it on the dispatcher level.
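For context, reference [1] above describes the custom_fwd/custom_bwd pattern for autocast-aware autograd Functions. Below is a minimal sketch of that pattern on the 'xla' device, assuming PyTorch >= 2.4 (where torch.amp.custom_fwd/custom_bwd accept device_type and torch.get_autocast_dtype takes a device string); it is an illustration, not the library's exact implementation:

  import torch
  from torch.amp import custom_fwd, custom_bwd

  class EinsumLinear(torch.autograd.Function):
    # Illustrative only: cast inputs to the XLA autocast dtype under autocast.

    @staticmethod
    @custom_fwd(device_type='xla', cast_inputs=torch.get_autocast_dtype('xla'))
    def forward(ctx, input, weight, bias=None):
      ctx.save_for_backward(input, weight, bias)
      product = torch.einsum('...n,mn->...m', input, weight)
      return product + bias if bias is not None else product

    @staticmethod
    @custom_bwd(device_type='xla')
    def backward(ctx, grad_output):
      input, weight, bias = ctx.saved_tensors
      grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
      grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
      grad_bias = torch.einsum('...m->m', grad_output) if bias is not None else None
      return grad_input, grad_weight, grad_bias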
@@ -1260,8 +1258,8 @@ class MarkShardingFunction(torch.autograd.Function):
   Usage:
   new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))

-  This is required to guide GSPMD sharding propagation better during the
-  backward pass as during complicated workloads the compiler can introduce extra
+  This is required to guide GSPMD sharding propagation better during the
+  backward pass as during complicated workloads the compiler can introduce extra
   collectives that can hurt performance.
   """

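For context on the Usage line above, here is a minimal sketch of building a mesh whose axes are named 'axis_1' and 'axis_2' with the torch_xla SPMD helpers; the device count, mesh shape, tensor shape, and the import path for MarkShardingFunction are illustrative assumptions, not prescribed by the diff:

  import numpy as np
  import torch
  import torch_xla.core.xla_model as xm
  import torch_xla.runtime as xr
  from torch_xla.distributed.spmd import Mesh
  from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction

  xr.use_spmd()
  num_devices = xr.global_runtime_device_count()
  # 2D mesh: shard along 'axis_1' across all devices, replicate along 'axis_2'.
  mesh = Mesh(np.arange(num_devices), (num_devices, 1), ('axis_1', 'axis_2'))

  tensor = torch.randn(8, 128, device=xm.xla_device())
  new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))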