Skip to content

Commit 7d3dd38

Browse files
committed
Remove references to torch_xla._XLAC._xla_einsum from xla_sharding.py
1 parent e8939da commit 7d3dd38

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -704,19 +704,19 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
704704
grad_input = grad_weight = grad_bias = None
705705

706706
if needs_input_grad_input:
707-
grad_input = torch_xla._XLAC._xla_einsum('...m,mn->...n',
707+
grad_input = torch.einsum('...m,mn->...n',
708708
(grad_output, weight))
709709
else:
710710
grad_input = None
711711

712712
if needs_input_grad_weight:
713-
grad_weight = torch_xla._XLAC._xla_einsum('...m,...n->mn',
713+
grad_weight = torch.einsum('...m,...n->mn',
714714
(grad_output, input))
715715
else:
716716
grad_weight = None
717717

718718
if bias is not None and needs_input_grad_bias:
719-
grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output,))
719+
grad_bias = torch.einsum('...m->m', (grad_output,))
720720
else:
721721
grad_bias = None
722722

0 commit comments

Comments
 (0)