
Make layout pinning optional for cross core communication #3511


Merged 2 commits on Apr 19, 2022
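This PR threads a new `pin_layout` flag through the Python collective-communication wrappers (`all_reduce`, `all_gather`, `all_to_all`, `reduce_scatter`) and their C++ builders, so callers can trade layout safety for compilation flexibility. A minimal usage sketch of the flag (the device setup and tensor shape below are illustrative, not part of this diff):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.ones(128, 128, device=device)

# all_reduce keeps layout pinning on by default; unpin it if compilation
# fails with "HloModule has a mix of layout constrained".
summed = xm.all_reduce(xm.REDUCE_SUM, t, pin_layout=False)
```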
51 changes: 40 additions & 11 deletions torch_xla/core/xla_model.py
@@ -526,7 +526,12 @@ def _host_all_reduce(reduce_type, inputs, cctx, scale=None):
REDUCE_SUM, inputs, token, 1.0, [])


def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
def all_reduce(reduce_type,
inputs,
scale=1.0,
groups=None,
cctx=None,
pin_layout=True):
"""Performs an inplace reduce operation on the input tensor(s).

Args:
@@ -542,6 +547,11 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it
might cause some XLA compilations to fail. Unpin the layout when you see an
error message like "HloModule has a mix of layout constrained".

Returns:
If a single `torch.Tensor` is passed, the return value is a `torch.Tensor`
@@ -562,12 +572,13 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
token, devctx = _get_all_reduce_token()
if isinstance(inputs, torch.Tensor):
result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, token,
scale, cctx.intercore_group)
scale, cctx.intercore_group,
pin_layout)
devctx.all_reduce_token = result[1]
results = [result[0]]
else:
devctx.all_reduce_token = torch_xla._XLAC._xla_all_reduce_inplace(
reduce_type, inputs, token, scale, cctx.intercore_group)
reduce_type, inputs, token, scale, cctx.intercore_group, pin_layout)
results = inputs
else:
if isinstance(inputs, torch.Tensor):
@@ -582,7 +593,7 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None):
return results[0] if isinstance(inputs, torch.Tensor) else results


def all_gather(value, dim=0, groups=None, output=None):
def all_gather(value, dim=0, groups=None, output=None, pin_layout=False):
"""Performs an all-gather operation along a given dimension.

Args:
@@ -595,6 +606,11 @@ def all_gather(value, dim=0, groups=None, output=None):
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output (torch.Tensor): Optional output tensor.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it
might cause some XLA compilations to fail. Unpin the layout when you see an
error message like "HloModule has a mix of layout constrained".

Returns:
A tensor which has, in the ``dim`` dimension, all the values from the
@@ -613,12 +629,13 @@ def all_gather(value, dim=0, groups=None, output=None):
if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [])
shard_count, groups or [],
pin_layout)
devctx.all_reduce_token = new_token
return output

result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
groups or [])
groups or [], pin_layout)
devctx.all_reduce_token = result[1]
return result[0]

@@ -627,7 +644,8 @@ def all_to_all(value,
split_dimension,
concat_dimension,
split_count,
groups=None):
groups=None,
pin_layout=False):
"""Performs an XLA `AllToAll()` operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#alltoall
@@ -642,14 +660,19 @@ def all_to_all(value,
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it
might cause some XLA compilations to fail. Unpin the layout when you see an
error message like "HloModule has a mix of layout constrained".

Returns:
The result `torch.Tensor` of the `all_to_all()` operation.
"""
token, devctx = _get_all_reduce_token()
result = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension,
concat_dimension, split_count,
groups or [])
groups or [], pin_layout)
devctx.all_reduce_token = result[1]
return result[0]

@@ -685,7 +708,8 @@ def reduce_scatter(reduce_type,
scatter_dim,
shard_count,
groups=None,
output=None):
output=None,
pin_layout=False):
"""Performs a XLA `ReduceScatter()` operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
@@ -704,6 +728,11 @@ def reduce_scatter(reduce_type,
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output: Optional output tensor
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it
might cause some XLA compilations to fail. Unpin the layout when you see an
error message like "HloModule has a mix of layout constrained".

Returns:
A `torch.Tensor` with all the values reduced across replicas. Each process
@@ -717,13 +746,13 @@ def reduce_scatter(reduce_type,
input, token, scale,
scatter_dim,
shard_count, groups or
[])
[], pin_layout)
devctx.all_reduce_token = new_token
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [])
groups or [], pin_layout)
devctx.all_reduce_token = result[1]
return result[0]

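Note the asymmetry in the defaults introduced above: `all_reduce` pins the layout by default (`pin_layout=True`), while `all_gather`, `all_to_all`, and `reduce_scatter` default to `pin_layout=False`. A sketch of opting in or out per op (shapes and the world-size query are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
world_size = xm.xrt_world_size()
shard = torch.randn(16, 64, device=device)

# Opt in to layout pinning for all_gather (off by default for this op).
gathered = xm.all_gather(shard, dim=0, pin_layout=True)

# reduce_scatter with its default, unpinned layout.
scattered = xm.reduce_scatter(xm.REDUCE_SUM, gathered, scale=1.0,
                              scatter_dim=0, shard_count=world_size,
                              pin_layout=False)
```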
90 changes: 65 additions & 25 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -87,7 +87,7 @@ std::vector<xla::ReplicaGroup> CreateReduceGroups(
std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
xla::XlaOp token, double scale,
const std::vector<std::vector<int64_t>>& groups) {
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
// TODO: We use pseudo-tokens ATM, which are real values. This needs to be
// switched to use the real XLA Token once support has been added to XLA
@@ -101,11 +101,19 @@ std::vector<xla::XlaOp> BuildAllReduce(
type_ctx.second.operand_shapes.push_back(
XlaHelpers::ShapeOfXlaOp(token_op));

xla::XlaOp reduce = xla::AllReduce(
xla::Tuple(operands[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), reduce_groups,
/*channel_id=*/absl::nullopt,
MakeReduceShape(type_ctx.second.operand_shapes));
xla::XlaOp reduce;
if (pin_layout) {
reduce = xla::AllReduce(
xla::Tuple(operands[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), reduce_groups,
/*channel_id=*/absl::nullopt,
/*shape_with_layout=*/
MakeReduceShape(type_ctx.second.operand_shapes));
} else {
reduce = xla::AllReduce(
xla::Tuple(operands[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), reduce_groups);
}
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
xla::XlaOp gte = xla::GetTupleElement(reduce, i);
@@ -128,28 +136,49 @@ std::vector<xla::XlaOp> BuildAllReduce(
AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
int64_t split_dimension, int64_t concat_dimension,
int64_t split_count,
const std::vector<std::vector<int64_t>>& groups) {
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
TokenHandler token_handler(token);
xla::XlaOp reduce_result = xla::AllToAll(
token_handler.GetInput(input, &input_shape), split_dimension,
concat_dimension, split_count, reduce_groups, reduce_shape.layout());
xla::XlaOp reduce_result;
if (pin_layout) {
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
reduce_result = xla::AllToAll(token_handler.GetInput(input, &input_shape),
split_dimension, concat_dimension,
split_count, reduce_groups,
/*layout=*/reduce_shape.layout());
} else {
reduce_result = xla::AllToAll(token_handler.GetInput(input, &input_shape),
split_dimension, concat_dimension,
split_count, reduce_groups);
}
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

AllGatherResult BuildAllGather(
xla::XlaOp input, xla::XlaOp token, int64_t dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups) {
AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
TokenHandler token_handler(token);
xla::XlaOp all_gather_result =
xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
shard_count, reduce_groups);
xla::XlaOp all_gather_result;
if (pin_layout) {
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
all_gather_result =
xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
/*layout=*/reduce_shape.layout());
} else {
all_gather_result =
xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
shard_count, reduce_groups);
}
return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
}

@@ -169,15 +198,26 @@ CollectivePermuteResult BuildCollectivePermute(
ReduceScatterResult BuildReduceScatter(
AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups) {
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
TokenHandler token_handler(token);
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);

xla::XlaOp reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()), scatter_dim,
shard_count, reduce_groups);
xla::XlaOp reduce_result;
if (pin_layout) {
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
/*layout=*/reduce_shape.layout());
} else {
reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, reduce_groups);
}

if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
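All four C++ builders follow the same pattern: with `pin_layout` set, the collective is emitted with an explicit layout built by `MakeArrayShapeFromDimensions` (the device-default layout for the operand shape); otherwise the layout argument is omitted and the XLA compiler is free to choose. For reference, XLA's default array layout lists dimensions minor-to-major in descending order; a small illustrative sketch of that convention (plain Python, not torch_xla code):

```python
def default_minor_to_major(rank):
    # XLA's default layout orders dimensions minor-to-major, i.e. the last
    # logical dimension varies fastest: [rank - 1, ..., 1, 0].
    return list(range(rank - 1, -1, -1))

# A rank-3 operand gets minor_to_major [2, 1, 0]. Pinning makes every
# participant agree on one layout up front instead of leaving the choice
# to the compiler, which may differ across slightly different programs.
assert default_minor_to_major(3) == [2, 1, 0]
```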
10 changes: 6 additions & 4 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -39,16 +39,18 @@ struct ReduceScatterResult {
std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
xla::XlaOp token, double scale,
const std::vector<std::vector<int64_t>>& groups);
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
int64_t split_dimension, int64_t concat_dimension,
int64_t split_count,
const std::vector<std::vector<int64_t>>& groups);
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout);

AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups);
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout);

CollectivePermuteResult BuildCollectivePermute(
xla::XlaOp input, xla::XlaOp token,
@@ -57,6 +59,6 @@ CollectivePermuteResult BuildCollectivePermute(
ReduceScatterResult BuildReduceScatter(
AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups);
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

} // namespace torch_xla