Skip to content

Commit e2aace1

Browse files
committed
correct linter
1 parent 4180ef9 commit e2aace1

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -704,14 +704,12 @@ def _einsum_linear_backward(grad_output: Tensor, input: Tensor, weight: Tensor,
704704
grad_input = grad_weight = grad_bias = None
705705

706706
if needs_input_grad_input:
707-
grad_input = torch.einsum('...m,mn->...n',
708-
(grad_output, weight))
707+
grad_input = torch.einsum('...m,mn->...n', (grad_output, weight))
709708
else:
710709
grad_input = None
711710

712711
if needs_input_grad_weight:
713-
grad_weight = torch.einsum('...m,...n->mn',
714-
(grad_output, input))
712+
grad_weight = torch.einsum('...m,...n->mn', (grad_output, input))
715713
else:
716714
grad_weight = None
717715

0 commit comments

Comments
 (0)