Skip to content

Commit e4e01e4

Browse files
committed
Only pin all_reduce layout
1 parent 48185ab commit e4e01e4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torch_xla/core/xla_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def all_reduce(reduce_type,
593593
return results[0] if isinstance(inputs, torch.Tensor) else results
594594

595595

596-
def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
596+
def all_gather(value, dim=0, groups=None, output=None, pin_layout=False):
597597
"""Performs an all-gather operation along a given dimension.
598598
599599
Args:
@@ -645,7 +645,7 @@ def all_to_all(value,
645645
concat_dimension,
646646
split_count,
647647
groups=None,
648-
pin_layout=True):
648+
pin_layout=False):
649649
"""Performs an XLA `AllToAll()` operation on the input tensor.
650650
651651
See: https://www.tensorflow.org/xla/operation_semantics#alltoall
@@ -709,7 +709,7 @@ def reduce_scatter(reduce_type,
709709
shard_count,
710710
groups=None,
711711
output=None,
712-
pin_layout=True):
712+
pin_layout=False):
713713
"""Performs a XLA `ReduceScatter()` operation on the input tensor.
714714
715715
See: https://www.tensorflow.org/xla/operation_semantics#reducescatter

0 commit comments

Comments
 (0)