Commit 48185ab

Make layout pinning optional for cross-core communication
1 parent 92d6ac2 commit 48185ab

14 files changed: +295 -174 lines

torch_xla/core/xla_model.py

Lines changed: 40 additions & 11 deletions
@@ -526,7 +526,12 @@ def _host_all_reduce(reduce_type, inputs, cctx, scale=None):
         REDUCE_SUM, inputs, token, 1.0, [])
 
 
-def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
+def all_reduce(reduce_type,
+               inputs,
+               scale=1.0,
+               groups=None,
+               cctx=None,
+               pin_layout=True):
   """Performs an inplace reduce operation on the input tensor(s).
 
   Args:
@@ -542,6 +547,11 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
       defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
       the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
       all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
 
   Returns:
     If a single `torch.Tensor` is passed, the return value is a `torch.Tensor`
@@ -562,12 +572,13 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
     token, devctx = _get_all_reduce_token()
     if isinstance(inputs, torch.Tensor):
       result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, token,
-                                               scale, cctx.intercore_group)
+                                               scale, cctx.intercore_group,
+                                               pin_layout)
       devctx.all_reduce_token = result[1]
       results = [result[0]]
     else:
       devctx.all_reduce_token = torch_xla._XLAC._xla_all_reduce_inplace(
-          reduce_type, inputs, token, scale, cctx.intercore_group)
+          reduce_type, inputs, token, scale, cctx.intercore_group, pin_layout)
       results = inputs
   else:
     if isinstance(inputs, torch.Tensor):
@@ -582,7 +593,7 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
   return results[0] if isinstance(inputs, torch.Tensor) else results
 
 
-def all_gather(value, dim=0, groups=None, output=None):
+def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   """Performs an all-gather operation along a given dimension.
 
   Args:
@@ -595,6 +606,11 @@ def all_gather(value, dim=0, groups=None, output=None):
       the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
       all the replicas in it.
     output (torch.Tensor): Optional output tensor.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
 
   Returns:
     A tensor which has, in the ``dim`` dimension, all the values from the
@@ -613,12 +629,13 @@ def all_gather(value, dim=0, groups=None, output=None):
   if output != None:
     # Call the out of place version of the all_gather
     new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
-                                                    shard_count, groups or [])
+                                                    shard_count, groups or [],
+                                                    pin_layout)
     devctx.all_reduce_token = new_token
     return output
 
   result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
-                                           groups or [])
+                                           groups or [], pin_layout)
   devctx.all_reduce_token = result[1]
   return result[0]
 
@@ -627,7 +644,8 @@ def all_to_all(value,
                split_dimension,
                concat_dimension,
                split_count,
-               groups=None):
+               groups=None,
+               pin_layout=True):
   """Performs an XLA `AllToAll()` operation on the input tensor.
 
   See: https://www.tensorflow.org/xla/operation_semantics#alltoall
@@ -642,14 +660,19 @@
       defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
       the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
       all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
 
   Returns:
     The result `torch.Tensor` of the `all_to_all()` operation.
   """
   token, devctx = _get_all_reduce_token()
   result = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension,
                                            concat_dimension, split_count,
-                                           groups or [])
+                                           groups or [], pin_layout)
   devctx.all_reduce_token = result[1]
   return result[0]
 
@@ -685,7 +708,8 @@ def reduce_scatter(reduce_type,
                    scatter_dim,
                    shard_count,
                    groups=None,
-                   output=None):
+                   output=None,
+                   pin_layout=True):
   """Performs a XLA `ReduceScatter()` operation on the input tensor.
 
   See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
@@ -704,6 +728,11 @@
       the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
       all the replicas in it.
     output: Optional output tensor
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
 
   Returns:
     A `torch.Tensor` with all the values reduced across replicas. Each process
@@ -717,13 +746,13 @@
                                                          input, token, scale,
                                                          scatter_dim,
                                                          shard_count, groups or
-                                                         [])
+                                                         [], pin_layout)
     devctx.all_reduce_token = new_token
     return output
 
   result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
                                                scatter_dim, shard_count,
-                                               groups or [])
+                                               groups or [], pin_layout)
   devctx.all_reduce_token = result[1]
   return result[0]
 
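The collectives above all gain the same trailing `pin_layout` flag, defaulting to `True`, so existing call sites keep the old pinned-layout behavior. A minimal usage sketch (the device setup, tensor shape, and `REDUCE_SUM` choice here are illustrative, not part of this diff):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.ones(128, 128, device=device)

# Default: pin_layout=True, identical to the behavior before this commit.
summed = xm.all_reduce(xm.REDUCE_SUM, t)

# Opt out of layout pinning when compilation fails with an error like
# "HloModule has a mix of layout constrained" instructions.
gathered = xm.all_gather(t, dim=0, pin_layout=False)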

torch_xla/csrc/cross_replica_reduces.cpp

Lines changed: 65 additions & 25 deletions
@@ -87,7 +87,7 @@ std::vector<xla::ReplicaGroup> CreateReduceGroups(
 std::vector<xla::XlaOp> BuildAllReduce(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
     xla::XlaOp token, double scale,
-    const std::vector<std::vector<int64_t>>& groups) {
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   // TODO: We use pseudo-tokens ATM, which are real values. This need to be
   // switched to use the real XLA Token once support has been added to XLA
@@ -101,11 +101,19 @@ std::vector<xla::XlaOp> BuildAllReduce(
     type_ctx.second.operand_shapes.push_back(
         XlaHelpers::ShapeOfXlaOp(token_op));
 
-    xla::XlaOp reduce = xla::AllReduce(
-        xla::Tuple(operands[0].builder(), type_ctx.second.ops),
-        GetReduceComutation(reduce_type, type_ctx.first), reduce_groups,
-        /*channel_id=*/absl::nullopt,
-        MakeReduceShape(type_ctx.second.operand_shapes));
+    xla::XlaOp reduce;
+    if (pin_layout) {
+      reduce = xla::AllReduce(
+          xla::Tuple(operands[0].builder(), type_ctx.second.ops),
+          GetReduceComutation(reduce_type, type_ctx.first), reduce_groups,
+          /*channel_id=*/absl::nullopt,
+          /*shape_with_layout=*/
+          MakeReduceShape(type_ctx.second.operand_shapes));
+    } else {
+      reduce = xla::AllReduce(
+          xla::Tuple(operands[0].builder(), type_ctx.second.ops),
+          GetReduceComutation(reduce_type, type_ctx.first), reduce_groups);
+    }
     for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
       size_t op_idx = type_ctx.second.indices[i];
       xla::XlaOp gte = xla::GetTupleElement(reduce, i);
@@ -128,28 +136,49 @@ std::vector<xla::XlaOp> BuildAllReduce(
 AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                              int64_t split_dimension, int64_t concat_dimension,
                              int64_t split_count,
-                             const std::vector<std::vector<int64_t>>& groups) {
+                             const std::vector<std::vector<int64_t>>& groups,
+                             bool pin_layout) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
-      input_shape.dimensions(), input_shape.dynamic_dimensions(),
-      input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
   TokenHandler token_handler(token);
-  xla::XlaOp reduce_result = xla::AllToAll(
-      token_handler.GetInput(input, &input_shape), split_dimension,
-      concat_dimension, split_count, reduce_groups, reduce_shape.layout());
+  xla::XlaOp reduce_result;
+  if (pin_layout) {
+    xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
+        input_shape.dimensions(), input_shape.dynamic_dimensions(),
+        input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
+    reduce_result = xla::AllToAll(token_handler.GetInput(input, &input_shape),
+                                  split_dimension, concat_dimension,
+                                  split_count, reduce_groups,
+                                  /*layout=*/reduce_shape.layout());
+  } else {
+    reduce_result = xla::AllToAll(token_handler.GetInput(input, &input_shape),
+                                  split_dimension, concat_dimension,
+                                  split_count, reduce_groups);
+  }
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
-AllGatherResult BuildAllGather(
-    xla::XlaOp input, xla::XlaOp token, int64_t dim, int64_t shard_count,
-    const std::vector<std::vector<int64_t>>& groups) {
+AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
+                               int64_t shard_count,
+                               const std::vector<std::vector<int64_t>>& groups,
+                               bool pin_layout) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   TokenHandler token_handler(token);
-  xla::XlaOp all_gather_result =
-      xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
-                     shard_count, reduce_groups);
+  xla::XlaOp all_gather_result;
+  if (pin_layout) {
+    xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
+        input_shape.dimensions(), input_shape.dynamic_dimensions(),
+        input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
+    all_gather_result =
+        xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
+                       shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
+                       /*layout=*/reduce_shape.layout());
+  } else {
+    all_gather_result =
+        xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
+                       shard_count, reduce_groups);
+  }
   return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
 }
 

@@ -169,15 +198,26 @@ CollectivePermuteResult BuildCollectivePermute(
 ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
     int64_t scatter_dim, int64_t shard_count,
-    const std::vector<std::vector<int64_t>>& groups) {
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   TokenHandler token_handler(token);
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
-
-  xla::XlaOp reduce_result = xla::ReduceScatter(
-      token_handler.GetInput(input, &input_shape),
-      GetReduceComutation(reduce_type, input_shape.element_type()), scatter_dim,
-      shard_count, reduce_groups);
+  xla::XlaOp reduce_result;
+  if (pin_layout) {
+    xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
+        input_shape.dimensions(), input_shape.dynamic_dimensions(),
+        input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
+    reduce_result = xla::ReduceScatter(
+        token_handler.GetInput(input, &input_shape),
+        GetReduceComutation(reduce_type, input_shape.element_type()),
+        scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
+        /*layout=*/reduce_shape.layout());
+  } else {
+    reduce_result = xla::ReduceScatter(
+        token_handler.GetInput(input, &input_shape),
+        GetReduceComutation(reduce_type, input_shape.element_type()),
+        scatter_dim, shard_count, reduce_groups);
+  }
 
   if (scale != 1.0) {
     xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
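Each builder changes the same way: when `pin_layout` is set, the collective is emitted with an explicit layout (`shape_with_layout` for `AllReduce`, a layout derived via `MakeArrayShapeFromDimensions` for `AllToAll`, `AllGather`, and `ReduceScatter`), so every participant agrees on the data layout; when it is unset, the layout argument is omitted and the compiler is free to choose one. A sketch of driving the unpinned `ReduceScatter` path from Python (the `world_size`-based scaling here is illustrative, not part of this diff):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.ones(128, 128, device=device)
world_size = xm.xrt_world_size()

# pin_layout=False takes the new else-branch in BuildReduceScatter,
# leaving the output layout up to the XLA compiler.
scattered = xm.reduce_scatter(xm.REDUCE_SUM, t, scale=1.0 / world_size,
                              scatter_dim=0, shard_count=world_size,
                              pin_layout=False)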

torch_xla/csrc/cross_replica_reduces.h

Lines changed: 6 additions & 4 deletions
@@ -39,16 +39,18 @@ struct ReduceScatterResult {
 std::vector<xla::XlaOp> BuildAllReduce(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
     xla::XlaOp token, double scale,
-    const std::vector<std::vector<int64_t>>& groups);
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
 
 AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                              int64_t split_dimension, int64_t concat_dimension,
                              int64_t split_count,
-                             const std::vector<std::vector<int64_t>>& groups);
+                             const std::vector<std::vector<int64_t>>& groups,
+                             bool pin_layout);
 
 AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
                                int64_t shard_count,
-                               const std::vector<std::vector<int64_t>>& groups);
+                               const std::vector<std::vector<int64_t>>& groups,
+                               bool pin_layout);
 
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
@@ -57,6 +59,6 @@ CollectivePermuteResult BuildCollectivePermute(
 ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
     int64_t scatter_dim, int64_t shard_count,
-    const std::vector<std::vector<int64_t>>& groups);
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
 
 }  // namespace torch_xla
