From 48185abcac50503fd2a0aa81dbc666e20de1db03 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Mon, 18 Apr 2022 23:57:22 +0000 Subject: [PATCH 1/2] Make layout pining optional for cross core communication --- torch_xla/core/xla_model.py | 51 ++++++-- torch_xla/csrc/cross_replica_reduces.cpp | 90 +++++++++++---- torch_xla/csrc/cross_replica_reduces.h | 10 +- torch_xla/csrc/init_python_bindings.cpp | 141 ++++++++++++----------- torch_xla/csrc/ops/all_gather.cpp | 24 ++-- torch_xla/csrc/ops/all_gather.h | 6 +- torch_xla/csrc/ops/all_reduce.cpp | 17 +-- torch_xla/csrc/ops/all_reduce.h | 5 +- torch_xla/csrc/ops/all_to_all.cpp | 25 ++-- torch_xla/csrc/ops/all_to_all.h | 5 +- torch_xla/csrc/ops/reduce_scatter.cpp | 29 +++-- torch_xla/csrc/ops/reduce_scatter.h | 6 +- torch_xla/csrc/tensor.h | 21 ++-- torch_xla/csrc/tensor_methods.cpp | 39 ++++--- 14 files changed, 295 insertions(+), 174 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 04f025c42bb9..59f6b1b84270 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -526,7 +526,12 @@ def _host_all_reduce(reduce_type, inputs, cctx, scale=None): REDUCE_SUM, inputs, token, 1.0, []) -def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None): +def all_reduce(reduce_type, + inputs, + scale=1.0, + groups=None, + cctx=None, + pin_layout=True): """Performs an inplace reduce operation on the input tensor(s). Args: @@ -542,6 +547,11 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None): defines two groups, one with the `[0, 1, 2, 3]` replicas and one with the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. + pin_layout (bool, optional): whether to pin the layout for this communication op. + Layout pining can prevent potential data corruption when each process that + participate in the communication has slightly different program, but it might + cause some xla compiation to fail. Unpin the layout when you see error message + like "HloModule has a mix of layout constrained". Returns: If a single `torch.Tensor` is passed, the return value is a `torch.Tensor` @@ -562,12 +572,13 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None): token, devctx = _get_all_reduce_token() if isinstance(inputs, torch.Tensor): result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, token, - scale, cctx.intercore_group) + scale, cctx.intercore_group, + pin_layout) devctx.all_reduce_token = result[1] results = [result[0]] else: devctx.all_reduce_token = torch_xla._XLAC._xla_all_reduce_inplace( - reduce_type, inputs, token, scale, cctx.intercore_group) + reduce_type, inputs, token, scale, cctx.intercore_group, pin_layout) results = inputs else: if isinstance(inputs, torch.Tensor): @@ -582,7 +593,7 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None): return results[0] if isinstance(inputs, torch.Tensor) else results -def all_gather(value, dim=0, groups=None, output=None): +def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): """Performs an all-gather operation along a given dimension. Args: @@ -595,6 +606,11 @@ def all_gather(value, dim=0, groups=None, output=None): the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. output (torch.Tensor): Optional output tensor. + pin_layout (bool, optional): whether to pin the layout for this communication op. 
+ Layout pining can prevent potential data corruption when each process that + participate in the communication has slightly different program, but it might + cause some xla compiation to fail. Unpin the layout when you see error message + like "HloModule has a mix of layout constrained". Returns: A tensor which has, in the ``dim`` dimension, all the values from the @@ -613,12 +629,13 @@ def all_gather(value, dim=0, groups=None, output=None): if output != None: # Call the out of place version of the all_gather new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim, - shard_count, groups or []) + shard_count, groups or [], + pin_layout) devctx.all_reduce_token = new_token return output result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count, - groups or []) + groups or [], pin_layout) devctx.all_reduce_token = result[1] return result[0] @@ -627,7 +644,8 @@ def all_to_all(value, split_dimension, concat_dimension, split_count, - groups=None): + groups=None, + pin_layout=True): """Performs an XLA `AllToAll()` operation on the input tensor. See: https://www.tensorflow.org/xla/operation_semantics#alltoall @@ -642,6 +660,11 @@ def all_to_all(value, defines two groups, one with the `[0, 1, 2, 3]` replicas and one with the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. + pin_layout (bool, optional): whether to pin the layout for this communication op. + Layout pining can prevent potential data corruption when each process that + participate in the communication has slightly different program, but it might + cause some xla compiation to fail. Unpin the layout when you see error message + like "HloModule has a mix of layout constrained". Returns: The result `torch.Tensor` of the `all_to_all()` operation. @@ -649,7 +672,7 @@ def all_to_all(value, token, devctx = _get_all_reduce_token() result = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension, concat_dimension, split_count, - groups or []) + groups or [], pin_layout) devctx.all_reduce_token = result[1] return result[0] @@ -685,7 +708,8 @@ def reduce_scatter(reduce_type, scatter_dim, shard_count, groups=None, - output=None): + output=None, + pin_layout=True): """Performs a XLA `ReduceScatter()` operation on the input tensor. See: https://www.tensorflow.org/xla/operation_semantics#reducescatter @@ -704,6 +728,11 @@ def reduce_scatter(reduce_type, the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. output: Optional output tensor + pin_layout (bool, optional): whether to pin the layout for this communication op. + Layout pining can prevent potential data corruption when each process that + participate in the communication has slightly different program, but it might + cause some xla compiation to fail. Unpin the layout when you see error message + like "HloModule has a mix of layout constrained". Returns: A `torch.Tensor` with all the values reduced accross replicas. 
Each process @@ -717,13 +746,13 @@ def reduce_scatter(reduce_type, input, token, scale, scatter_dim, shard_count, groups or - []) + [], pin_layout) devctx.all_reduce_token = new_token return output result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale, scatter_dim, shard_count, - groups or []) + groups or [], pin_layout) devctx.all_reduce_token = result[1] return result[0] diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 68f939f1b761..61a9a0c494be 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -87,7 +87,7 @@ std::vector CreateReduceGroups( std::vector BuildAllReduce( AllReduceType reduce_type, absl::Span operands, xla::XlaOp token, double scale, - const std::vector>& groups) { + const std::vector>& groups, bool pin_layout) { std::vector reduce_groups = CreateReduceGroups(groups); // TODO: We use pseudo-tokens ATM, which are real values. This need to be // switched to use the real XLA Token once support has been added to XLA @@ -101,11 +101,19 @@ std::vector BuildAllReduce( type_ctx.second.operand_shapes.push_back( XlaHelpers::ShapeOfXlaOp(token_op)); - xla::XlaOp reduce = xla::AllReduce( - xla::Tuple(operands[0].builder(), type_ctx.second.ops), - GetReduceComutation(reduce_type, type_ctx.first), reduce_groups, - /*channel_id=*/absl::nullopt, - MakeReduceShape(type_ctx.second.operand_shapes)); + xla::XlaOp reduce; + if (pin_layout) { + reduce = xla::AllReduce( + xla::Tuple(operands[0].builder(), type_ctx.second.ops), + GetReduceComutation(reduce_type, type_ctx.first), reduce_groups, + /*channel_id=*/absl::nullopt, + /*shape_with_layout=*/ + MakeReduceShape(type_ctx.second.operand_shapes)); + } else { + reduce = xla::AllReduce( + xla::Tuple(operands[0].builder(), type_ctx.second.ops), + GetReduceComutation(reduce_type, type_ctx.first), reduce_groups); + } for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) { size_t op_idx = type_ctx.second.indices[i]; xla::XlaOp gte = xla::GetTupleElement(reduce, i); @@ -128,28 +136,49 @@ std::vector BuildAllReduce( AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - const std::vector>& groups) { + const std::vector>& groups, + bool pin_layout) { std::vector reduce_groups = CreateReduceGroups(groups); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); - xla::Shape reduce_shape = MakeArrayShapeFromDimensions( - input_shape.dimensions(), input_shape.dynamic_dimensions(), - input_shape.element_type(), GetCurrentDevice().device_type.hw_type); TokenHandler token_handler(token); - xla::XlaOp reduce_result = xla::AllToAll( - token_handler.GetInput(input, &input_shape), split_dimension, - concat_dimension, split_count, reduce_groups, reduce_shape.layout()); + xla::XlaOp reduce_result; + if (pin_layout) { + xla::Shape reduce_shape = MakeArrayShapeFromDimensions( + input_shape.dimensions(), input_shape.dynamic_dimensions(), + input_shape.element_type(), GetCurrentDevice().device_type.hw_type); + reduce_result = xla::AllToAll(token_handler.GetInput(input, &input_shape), + split_dimension, concat_dimension, + split_count, reduce_groups, + /*layout=*/reduce_shape.layout()); + } else { + reduce_result = xla::AllToAll(token_handler.GetInput(input, &input_shape), + split_dimension, concat_dimension, + split_count, reduce_groups); + } return {reduce_result, token_handler.GetNewToken(reduce_result)}; } -AllGatherResult BuildAllGather( - 
xla::XlaOp input, xla::XlaOp token, int64_t dim, int64_t shard_count, - const std::vector>& groups) { +AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim, + int64_t shard_count, + const std::vector>& groups, + bool pin_layout) { std::vector reduce_groups = CreateReduceGroups(groups); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); TokenHandler token_handler(token); - xla::XlaOp all_gather_result = - xla::AllGather(token_handler.GetInput(input, &input_shape), dim, - shard_count, reduce_groups); + xla::XlaOp all_gather_result; + if (pin_layout) { + xla::Shape reduce_shape = MakeArrayShapeFromDimensions( + input_shape.dimensions(), input_shape.dynamic_dimensions(), + input_shape.element_type(), GetCurrentDevice().device_type.hw_type); + all_gather_result = + xla::AllGather(token_handler.GetInput(input, &input_shape), dim, + shard_count, reduce_groups, /*channel_id=*/absl::nullopt, + /*layout=*/reduce_shape.layout()); + } else { + all_gather_result = + xla::AllGather(token_handler.GetInput(input, &input_shape), dim, + shard_count, reduce_groups); + } return {all_gather_result, token_handler.GetNewToken(all_gather_result)}; } @@ -169,15 +198,26 @@ CollectivePermuteResult BuildCollectivePermute( ReduceScatterResult BuildReduceScatter( AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, - const std::vector>& groups) { + const std::vector>& groups, bool pin_layout) { std::vector reduce_groups = CreateReduceGroups(groups); TokenHandler token_handler(token); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); - - xla::XlaOp reduce_result = xla::ReduceScatter( - token_handler.GetInput(input, &input_shape), - GetReduceComutation(reduce_type, input_shape.element_type()), scatter_dim, - shard_count, reduce_groups); + xla::XlaOp reduce_result; + if (pin_layout) { + xla::Shape reduce_shape = MakeArrayShapeFromDimensions( + input_shape.dimensions(), input_shape.dynamic_dimensions(), + input_shape.element_type(), GetCurrentDevice().device_type.hw_type); + reduce_result = xla::ReduceScatter( + token_handler.GetInput(input, &input_shape), + GetReduceComutation(reduce_type, input_shape.element_type()), + scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt, + /*layout=*/reduce_shape.layout()); + } else { + reduce_result = xla::ReduceScatter( + token_handler.GetInput(input, &input_shape), + GetReduceComutation(reduce_type, input_shape.element_type()), + scatter_dim, shard_count, reduce_groups); + } if (scale != 1.0) { xla::XlaOp scaling_value = XlaHelpers::ScalarValue( diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h index 2e47051d5033..aa132a745598 100644 --- a/torch_xla/csrc/cross_replica_reduces.h +++ b/torch_xla/csrc/cross_replica_reduces.h @@ -39,16 +39,18 @@ struct ReduceScatterResult { std::vector BuildAllReduce( AllReduceType reduce_type, absl::Span operands, xla::XlaOp token, double scale, - const std::vector>& groups); + const std::vector>& groups, bool pin_layout); AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - const std::vector>& groups); + const std::vector>& groups, + bool pin_layout); AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim, int64_t shard_count, - const std::vector>& groups); + const std::vector>& groups, + bool pin_layout); CollectivePermuteResult BuildCollectivePermute( xla::XlaOp 
input, xla::XlaOp token, @@ -57,6 +59,6 @@ CollectivePermuteResult BuildCollectivePermute( ReduceScatterResult BuildReduceScatter( AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, - const std::vector>& groups); + const std::vector>& groups, bool pin_layout); } // namespace torch_xla diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ecec757f17e9..85ede1b42435 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -178,21 +178,22 @@ std::vector> CreateSourceTargetPairs( std::shared_ptr AllReduceInPlace( const std::string& reduce_type, const std::vector& tensors, const std::shared_ptr& token, double scale, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors = GetXlaTensors(tensors, /*want_all=*/true); - return std::make_shared(XLATensor::all_reduce( - &xtensors, *token, GetReduceType(reduce_type), scale, replica_groups)); + return std::make_shared( + XLATensor::all_reduce(&xtensors, *token, GetReduceType(reduce_type), + scale, replica_groups, pin_layout)); } std::pair> AllReduce( const std::string& reduce_type, const at::Tensor& input, const std::shared_ptr& token, double scale, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { XLATensor result; ir::Value new_token; - std::tie(result, new_token) = - XLATensor::all_reduce(bridge::GetXlaTensor(input), *token, - GetReduceType(reduce_type), scale, replica_groups); + std::tie(result, new_token) = XLATensor::all_reduce( + bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type), scale, + replica_groups, pin_layout); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -202,12 +203,12 @@ std::pair> ReduceScatter( const std::string& reduce_type, const at::Tensor& input, const std::shared_ptr& token, double scale, int64_t scatter_dim, int64_t shard_count, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { XLATensor result; ir::Value new_token; std::tie(result, new_token) = XLATensor::reduce_scatter( bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type), scale, - scatter_dim, shard_count, replica_groups); + scatter_dim, shard_count, replica_groups, pin_layout); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -217,23 +218,24 @@ std::shared_ptr ReduceScatterOut( const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, double scale, int64_t scatter_dim, int64_t shard_count, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { XLATensor out = bridge::GetXlaTensor(output); ir::Value new_token; new_token = XLATensor::reduce_scatter_out( out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type), - scale, scatter_dim, shard_count, replica_groups); + scale, scatter_dim, shard_count, replica_groups, pin_layout); return std::make_shared(new_token); } std::pair> AllGather( const at::Tensor& input, const std::shared_ptr& token, int64_t dim, int64_t shard_count, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { XLATensor result; ir::Value new_token; - std::tie(result, new_token) = XLATensor::all_gather( - bridge::GetXlaTensor(input), *token, dim, shard_count, replica_groups); + 
std::tie(result, new_token) = + XLATensor::all_gather(bridge::GetXlaTensor(input), *token, dim, + shard_count, replica_groups, pin_layout); return {bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)}; } @@ -241,24 +243,24 @@ std::pair> AllGather( std::shared_ptr AllGatherOut( at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, int64_t dim, int64_t shard_count, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { XLATensor out = bridge::GetXlaTensor(output); ir::Value new_token; new_token = XLATensor::all_gather_out(out, bridge::GetXlaTensor(input), *token, dim, - shard_count, replica_groups); + shard_count, replica_groups, pin_layout); return std::make_shared(new_token); } std::pair> AllToAll( const at::Tensor& input, const std::shared_ptr& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - const std::vector>& replica_groups) { + const std::vector>& replica_groups, bool pin_layout) { XLATensor result; ir::Value new_token; std::tie(result, new_token) = XLATensor::all_to_all( bridge::GetXlaTensor(input), *token, split_dimension, concat_dimension, - split_count, replica_groups); + split_count, replica_groups, pin_layout); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -892,32 +894,33 @@ void InitXlaModuleBindings(py::module m) { py::class_>(m, "IrValue"); m.def("_xla_create_token", [](const std::string& device) { return CreateToken(device); }); - m.def("_xla_all_reduce_inplace", [](const std::string& reduce_type, - const std::vector& tensors, - const std::shared_ptr& token, - double scale, const py::list& groups) { - std::vector> replica_groups = - CreateReduceGroups(groups); - std::shared_ptr new_token; - { - NoGilSection nogil; - new_token = - AllReduceInPlace(reduce_type, tensors, token, scale, replica_groups); - } - return new_token; - }); + m.def( + "_xla_all_reduce_inplace", + [](const std::string& reduce_type, const std::vector& tensors, + const std::shared_ptr& token, double scale, + const py::list& groups, bool pin_layout) { + std::vector> replica_groups = + CreateReduceGroups(groups); + std::shared_ptr new_token; + { + NoGilSection nogil; + new_token = AllReduceInPlace(reduce_type, tensors, token, scale, + replica_groups, pin_layout); + } + return new_token; + }); m.def("_xla_all_reduce", [](const std::string& reduce_type, const at::Tensor& input, const std::shared_ptr& token, double scale, - const py::list& groups) { + const py::list& groups, bool pin_layout) { std::vector> replica_groups = CreateReduceGroups(groups); at::Tensor result; std::shared_ptr new_token; { NoGilSection nogil; - std::tie(result, new_token) = - AllReduce(reduce_type, input, token, scale, replica_groups); + std::tie(result, new_token) = AllReduce( + reduce_type, input, token, scale, replica_groups, pin_layout); } auto result_tuple = py::tuple(2); result_tuple[0] = torch::autograd::make_variable( @@ -928,7 +931,7 @@ void InitXlaModuleBindings(py::module m) { m.def("_xla_all_to_all", [](const at::Tensor& input, const std::shared_ptr& token, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, const py::list& groups) { + int64_t split_count, const py::list& groups, bool pin_layout) { std::vector> replica_groups = CreateReduceGroups(groups); at::Tensor result; @@ -937,7 +940,7 @@ void InitXlaModuleBindings(py::module m) { NoGilSection nogil; std::tie(result, new_token) = AllToAll(input, token, split_dimension, 
concat_dimension, - split_count, replica_groups); + split_count, replica_groups, pin_layout); } auto result_tuple = py::tuple(2); result_tuple[0] = torch::autograd::make_variable( @@ -945,39 +948,40 @@ void InitXlaModuleBindings(py::module m) { result_tuple[1] = new_token; return result_tuple; }); - m.def("_xla_all_gather", - [](const at::Tensor& input, const std::shared_ptr& token, - int64_t dim, int64_t shard_count, const py::list& groups) { - std::vector> replica_groups = - CreateReduceGroups(groups); - at::Tensor result; - std::shared_ptr new_token; - { - NoGilSection nogil; - std::tie(result, new_token) = - AllGather(input, token, dim, shard_count, replica_groups); - } - auto result_tuple = py::tuple(2); - result_tuple[0] = torch::autograd::make_variable( - result, /*requires_grad=*/input.requires_grad()); - result_tuple[1] = new_token; - return result_tuple; - }); - m.def("_xla_all_gather_out", [](at::Tensor& output, const at::Tensor& input, - const std::shared_ptr& token, - int64_t dim, int64_t shard_count, - const py::list& groups) { + m.def("_xla_all_gather", [](const at::Tensor& input, + const std::shared_ptr& token, + int64_t dim, int64_t shard_count, + const py::list& groups, bool pin_layout) { std::vector> replica_groups = CreateReduceGroups(groups); at::Tensor result; std::shared_ptr new_token; { NoGilSection nogil; - new_token = - AllGatherOut(output, input, token, dim, shard_count, replica_groups); + std::tie(result, new_token) = + AllGather(input, token, dim, shard_count, replica_groups, pin_layout); } - return new_token; + auto result_tuple = py::tuple(2); + result_tuple[0] = torch::autograd::make_variable( + result, /*requires_grad=*/input.requires_grad()); + result_tuple[1] = new_token; + return result_tuple; }); + m.def("_xla_all_gather_out", + [](at::Tensor& output, const at::Tensor& input, + const std::shared_ptr& token, int64_t dim, + int64_t shard_count, const py::list& groups, bool pin_layout) { + std::vector> replica_groups = + CreateReduceGroups(groups); + at::Tensor result; + std::shared_ptr new_token; + { + NoGilSection nogil; + new_token = AllGatherOut(output, input, token, dim, shard_count, + replica_groups, pin_layout); + } + return new_token; + }); m.def("_xla_collective_permute", [](const at::Tensor& input, const std::shared_ptr& token, const py::list& pairs) { @@ -999,7 +1003,8 @@ void InitXlaModuleBindings(py::module m) { m.def("_xla_reduce_scatter", [](const std::string& reduce_type, const at::Tensor& input, const std::shared_ptr& token, double scale, - int64_t scatter_dim, int64_t shard_count, const py::list& groups) { + int64_t scatter_dim, int64_t shard_count, const py::list& groups, + bool pin_layout) { std::vector> replica_groups = CreateReduceGroups(groups); at::Tensor result; @@ -1008,7 +1013,7 @@ void InitXlaModuleBindings(py::module m) { NoGilSection nogil; std::tie(result, new_token) = ReduceScatter(reduce_type, input, token, scale, scatter_dim, - shard_count, replica_groups); + shard_count, replica_groups, pin_layout); } auto result_tuple = py::tuple(2); result_tuple[0] = torch::autograd::make_variable( @@ -1020,16 +1025,16 @@ void InitXlaModuleBindings(py::module m) { [](const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, double scale, int64_t scatter_dim, int64_t shard_count, - const py::list& groups) { + const py::list& groups, bool pin_layout) { std::vector> replica_groups = CreateReduceGroups(groups); at::Tensor result; std::shared_ptr new_token; { NoGilSection nogil; - new_token = 
- ReduceScatterOut(reduce_type, output, input, token, scale, - scatter_dim, shard_count, replica_groups); + new_token = ReduceScatterOut(reduce_type, output, input, token, + scale, scatter_dim, shard_count, + replica_groups, pin_layout); } return new_token; }); diff --git a/torch_xla/csrc/ops/all_gather.cpp b/torch_xla/csrc/ops/all_gather.cpp index 3049bab97cd9..f062226a2c2d 100644 --- a/torch_xla/csrc/ops/all_gather.cpp +++ b/torch_xla/csrc/ops/all_gather.cpp @@ -14,10 +14,11 @@ namespace { xla::Shape NodeOutputShape(const Value& input, const Value& token, int64_t dim, int64_t shard_count, - const std::vector>& groups) { + const std::vector>& groups, + bool pin_layout) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { - AllGatherResult result = - BuildAllGather(operands[0], operands[1], dim, shard_count, groups); + AllGatherResult result = BuildAllGather(operands[0], operands[1], dim, + shard_count, groups, pin_layout); return xla::Tuple(operands[0].builder(), {result.result, result.token}); }; return InferOutputShape({input.xla_shape(), token.xla_shape()}, shape_fn); @@ -27,33 +28,36 @@ xla::Shape NodeOutputShape(const Value& input, const Value& token, int64_t dim, AllGather::AllGather(const Value& input, const Value& token, int64_t dim, int64_t shard_count, - std::vector> groups) + std::vector> groups, bool pin_layout) : Node(xla_all_gather, {input, token}, [&]() { - return NodeOutputShape(input, token, dim, shard_count, groups); + return NodeOutputShape(input, token, dim, shard_count, groups, + pin_layout); }, - /*num_outputs=*/2, torch::lazy::MHash(dim, shard_count, groups)), + /*num_outputs=*/2, + torch::lazy::MHash(dim, shard_count, groups, pin_layout)), dim_(dim), shard_count_(shard_count), - groups_(std::move(groups)) {} + groups_(std::move(groups)), + pin_layout_(pin_layout) {} torch::lazy::NodePtr AllGather::Clone(OpList operands) const { return ir::MakeNode(operands.at(0), operands.at(1), dim_, - shard_count_, groups_); + shard_count_, groups_, pin_layout_); } XlaOpVector AllGather::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); xla::XlaOp token = loctx->GetOutputOp(operand(1)); AllGatherResult result = - BuildAllGather(input, token, dim_, shard_count_, groups_); + BuildAllGather(input, token, dim_, shard_count_, groups_, pin_layout_); return ReturnOps({result.result, result.token}, loctx); } std::string AllGather::ToString() const { std::stringstream ss; ss << Node::ToString() << ", dim=" << dim_ << ", shard_count=" << shard_count_ - << ", groups=("; + << ", pin_layout=" << pin_layout_ << ", groups=("; for (size_t i = 0; i < groups_.size(); ++i) { ss << (i == 0 ? 
"(" : ",("); ss << absl::StrJoin(groups_[i], ", ") << ")"; diff --git a/torch_xla/csrc/ops/all_gather.h b/torch_xla/csrc/ops/all_gather.h index b21492a02c38..7733204a6003 100644 --- a/torch_xla/csrc/ops/all_gather.h +++ b/torch_xla/csrc/ops/all_gather.h @@ -10,7 +10,8 @@ namespace ops { class AllGather : public Node { public: AllGather(const Value& input, const Value& token, int64_t dim, - int64_t shard_count, std::vector> groups); + int64_t shard_count, std::vector> groups, + bool pin_layout); std::string ToString() const override; @@ -24,10 +25,13 @@ class AllGather : public Node { const std::vector>& groups() const { return groups_; } + bool pin_layout() const { return pin_layout_; } + private: int64_t dim_; int64_t shard_count_; std::vector> groups_; + bool pin_layout_; }; } // namespace ops diff --git a/torch_xla/csrc/ops/all_reduce.cpp b/torch_xla/csrc/ops/all_reduce.cpp index 3da8b933a7d7..7aea57bae5df 100644 --- a/torch_xla/csrc/ops/all_reduce.cpp +++ b/torch_xla/csrc/ops/all_reduce.cpp @@ -34,20 +34,22 @@ std::vector GetOperandList(absl::Span operands, AllReduce::AllReduce(AllReduceType reduce_type, absl::Span operands, const Value& token, - double scale, std::vector> groups) + double scale, std::vector> groups, + bool pin_layout) : Node(xla_cross_replica_sum, GetOperandList(operands, token), [&]() { return NodeOutputShape(operands, token); }, /*num_outputs=*/operands.size() + 1, torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale, - groups)), + groups, pin_layout)), reduce_type_(reduce_type), scale_(scale), - groups_(std::move(groups)) {} + groups_(std::move(groups)), + pin_layout_(pin_layout) {} torch::lazy::NodePtr AllReduce::Clone(OpList operands) const { std::vector operand_list(operands.begin(), operands.end() - 1); return ir::MakeNode(reduce_type_, operand_list, operands.back(), - scale_, groups_); + scale_, groups_, pin_layout_); } XlaOpVector AllReduce::Lower(LoweringContext* loctx) const { @@ -58,15 +60,16 @@ XlaOpVector AllReduce::Lower(LoweringContext* loctx) const { inputs.push_back(loctx->GetOutputOp(operand_list[i])); } xla::XlaOp token = loctx->GetOutputOp(operand_list.back()); - return ReturnOps(BuildAllReduce(reduce_type_, inputs, token, scale_, groups_), - loctx); + return ReturnOps( + BuildAllReduce(reduce_type_, inputs, token, scale_, groups_, pin_layout_), + loctx); } std::string AllReduce::ToString() const { std::stringstream ss; ss << Node::ToString() << ", reduce_type=" << torch::lazy::GetEnumValue(reduce_type_) - << ", scale=" << scale_ << ", groups=("; + << ", scale=" << scale_ << ", pin_layout=" << pin_layout_ << ", groups=("; for (size_t i = 0; i < groups_.size(); ++i) { ss << (i == 0 ? 
"(" : ",("); ss << absl::StrJoin(groups_[i], ", ") << ")"; diff --git a/torch_xla/csrc/ops/all_reduce.h b/torch_xla/csrc/ops/all_reduce.h index 331ecfde0dbe..da7187327fe1 100644 --- a/torch_xla/csrc/ops/all_reduce.h +++ b/torch_xla/csrc/ops/all_reduce.h @@ -11,7 +11,7 @@ class AllReduce : public Node { public: AllReduce(AllReduceType reduce_type, absl::Span operands, const Value& token, double scale, - std::vector> groups); + std::vector> groups, bool pin_layout); std::string ToString() const override; @@ -25,10 +25,13 @@ class AllReduce : public Node { const std::vector>& groups() const { return groups_; } + bool pin_layout() const { return pin_layout_; } + private: AllReduceType reduce_type_; double scale_; std::vector> groups_; + bool pin_layout_; }; } // namespace ops diff --git a/torch_xla/csrc/ops/all_to_all.cpp b/torch_xla/csrc/ops/all_to_all.cpp index aa25dc1cb1bc..c367d2fd5cb7 100644 --- a/torch_xla/csrc/ops/all_to_all.cpp +++ b/torch_xla/csrc/ops/all_to_all.cpp @@ -14,11 +14,12 @@ namespace { xla::Shape NodeOutputShape(const Value& input, const Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - const std::vector>& groups) { + const std::vector>& groups, + bool pin_layout) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { AllToAllResult result = BuildAllToAll(operands[0], operands[1], split_dimension, - concat_dimension, split_count, groups); + concat_dimension, split_count, groups, pin_layout); return xla::Tuple(operands[0].builder(), {result.result, result.token}); }; return InferOutputShape({input.xla_shape(), token.xla_shape()}, shape_fn); @@ -29,31 +30,34 @@ xla::Shape NodeOutputShape(const Value& input, const Value& token, AllToAll::AllToAll(const Value& input, const Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - std::vector> groups) + std::vector> groups, bool pin_layout) : Node(xla_all_to_all, {input, token}, [&]() { return NodeOutputShape(input, token, split_dimension, - concat_dimension, split_count, groups); + concat_dimension, split_count, groups, + pin_layout); }, /*num_outputs=*/2, torch::lazy::MHash(split_dimension, concat_dimension, split_count, - groups)), + groups, pin_layout)), split_dimension_(split_dimension), concat_dimension_(concat_dimension), split_count_(split_count), - groups_(std::move(groups)) {} + groups_(std::move(groups)), + pin_layout_(pin_layout) {} torch::lazy::NodePtr AllToAll::Clone(OpList operands) const { return ir::MakeNode(operands.at(0), operands.at(1), split_dimension_, concat_dimension_, - split_count_, groups_); + split_count_, groups_, pin_layout_); } XlaOpVector AllToAll::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); xla::XlaOp token = loctx->GetOutputOp(operand(1)); - AllToAllResult result = BuildAllToAll( - input, token, split_dimension_, concat_dimension_, split_count_, groups_); + AllToAllResult result = + BuildAllToAll(input, token, split_dimension_, concat_dimension_, + split_count_, groups_, pin_layout_); return ReturnOps({result.result, result.token}, loctx); } @@ -61,7 +65,8 @@ std::string AllToAll::ToString() const { std::stringstream ss; ss << Node::ToString() << ", split_dimension=" << split_dimension_ << ", concat_dimension=" << concat_dimension_ - << ", split_count=" << split_count_ << ", groups=("; + << ", split_count=" << split_count_ << ", pin_layout=" << pin_layout_ + << ", groups=("; for (size_t i = 0; i < groups_.size(); ++i) { ss << (i == 0 ? 
"(" : ",("); ss << absl::StrJoin(groups_[i], ", ") << ")"; diff --git a/torch_xla/csrc/ops/all_to_all.h b/torch_xla/csrc/ops/all_to_all.h index 4315116c824f..c2742dff6abe 100644 --- a/torch_xla/csrc/ops/all_to_all.h +++ b/torch_xla/csrc/ops/all_to_all.h @@ -11,7 +11,7 @@ class AllToAll : public Node { public: AllToAll(const Value& input, const Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - std::vector> groups); + std::vector> groups, bool pin_layout); std::string ToString() const override; @@ -27,11 +27,14 @@ class AllToAll : public Node { const std::vector>& groups() const { return groups_; } + bool pin_layout() const { return pin_layout_; } + private: int64_t split_dimension_; int64_t concat_dimension_; int64_t split_count_; std::vector> groups_; + bool pin_layout_; }; } // namespace ops diff --git a/torch_xla/csrc/ops/reduce_scatter.cpp b/torch_xla/csrc/ops/reduce_scatter.cpp index 0c343938cddf..8af4818ce412 100644 --- a/torch_xla/csrc/ops/reduce_scatter.cpp +++ b/torch_xla/csrc/ops/reduce_scatter.cpp @@ -15,12 +15,14 @@ namespace { xla::Shape NodeOutputShape(AllReduceType reduce_type, const Value input, const Value& token, double scale, int64_t scatter_dim, int64_t shard_count, - const std::vector>& groups) { + const std::vector>& groups, + bool pin_layout) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { xla::XlaOp inputOp = operands[0]; xla::XlaOp tokenOp = operands[1]; - ReduceScatterResult result = BuildReduceScatter( - reduce_type, inputOp, tokenOp, scale, scatter_dim, shard_count, groups); + ReduceScatterResult result = + BuildReduceScatter(reduce_type, inputOp, tokenOp, scale, scatter_dim, + shard_count, groups, pin_layout); return xla::Tuple(operands[0].builder(), {result.result, result.token}); }; return InferOutputShape({input.xla_shape(), token.xla_shape()}, shape_fn); @@ -31,32 +33,36 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type, const Value input, ReduceScatter::ReduceScatter(AllReduceType reduce_type, const Value& input, const Value& token, double scale, int64_t scatter_dim, int64_t shard_count, - std::vector> groups) + std::vector> groups, + bool pin_layout) : Node(xla_reduce_scatter, {input, token}, [&]() { return NodeOutputShape(reduce_type, input, token, scale, - scatter_dim, shard_count, groups); + scatter_dim, shard_count, groups, + pin_layout); }, /*num_outputs=*/2, torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale, - scatter_dim, shard_count, groups)), + scatter_dim, shard_count, groups, pin_layout)), reduce_type_(reduce_type), scale_(scale), scatter_dim_(scatter_dim), shard_count_(shard_count), - groups_(std::move(groups)) {} + groups_(std::move(groups)), + pin_layout_(pin_layout) {} torch::lazy::NodePtr ReduceScatter::Clone(OpList operands) const { return ir::MakeNode(reduce_type_, operands.at(0), operands.at(1), scale_, scatter_dim_, - shard_count_, groups_); + shard_count_, groups_, pin_layout_); } XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); xla::XlaOp token = loctx->GetOutputOp(operand(1)); - ReduceScatterResult result = BuildReduceScatter( - reduce_type_, input, token, scale_, scatter_dim_, shard_count_, groups_); + ReduceScatterResult result = + BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_, + shard_count_, groups_, pin_layout_); return ReturnOps({result.result, result.token}, loctx); } @@ -65,7 +71,8 @@ std::string ReduceScatter::ToString() const { ss << Node::ToString() << ", 
reduce_type=" << torch::lazy::GetEnumValue(reduce_type_) << ", scale=" << scale_ << ", scatter_dim=" << scatter_dim_ - << ", shard_count=" << shard_count_ << ", groups=("; + << ", shard_count=" << shard_count_ << ", pin_layout=" << pin_layout_ + << ", groups=("; for (size_t i = 0; i < groups_.size(); ++i) { ss << (i == 0 ? "(" : ",("); ss << absl::StrJoin(groups_[i], ", ") << ")"; diff --git a/torch_xla/csrc/ops/reduce_scatter.h b/torch_xla/csrc/ops/reduce_scatter.h index b9ece78643fb..c541fc0da5e5 100644 --- a/torch_xla/csrc/ops/reduce_scatter.h +++ b/torch_xla/csrc/ops/reduce_scatter.h @@ -11,7 +11,8 @@ class ReduceScatter : public Node { public: ReduceScatter(AllReduceType reduce_type, const Value& input, const Value& token, double scale, int64_t scatter_dim, - int64_t shard_count, std::vector> groups); + int64_t shard_count, std::vector> groups, + bool pin_layout); std::string ToString() const override; @@ -25,12 +26,15 @@ class ReduceScatter : public Node { const std::vector>& groups() const { return groups_; } + bool pin_layout() const { return pin_layout_; } + private: AllReduceType reduce_type_; double scale_; int64_t scatter_dim_; int64_t shard_count_; std::vector> groups_; + bool pin_layout_; }; } // namespace ops diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index b181348342b9..189a608862e0 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -210,41 +210,46 @@ class XLATensor { ////////////////////////////////////////////////////////////////////////////// static std::pair all_reduce( const XLATensor& input, const ir::Value& token, AllReduceType reduce_type, - double scale, std::vector> groups); + double scale, std::vector> groups, bool pin_layout); static ir::Value all_reduce_(XLATensor& input, const ir::Value& token, AllReduceType reduce_type, double scale, - std::vector> groups); + std::vector> groups, + bool pin_layout); static ir::Value all_reduce(std::vector* inputs, const ir::Value& token, AllReduceType reduce_type, double scale, - std::vector> groups); + std::vector> groups, + bool pin_layout); static std::pair reduce_scatter( const XLATensor& input, const ir::Value& token, AllReduceType reduce_type, double scale, int64_t scatter_dim, int64_t shard_count, - std::vector> groups); + std::vector> groups, bool pin_layout); static ir::Value reduce_scatter_out(XLATensor& output, const XLATensor& input, const ir::Value& token, AllReduceType reduce_type, double scale, int64_t scatter_dim, int64_t shard_count, - std::vector> groups); + std::vector> groups, + bool pin_layout); static std::pair all_to_all( const XLATensor& input, const ir::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - std::vector> groups); + std::vector> groups, bool pin_layout); static std::pair all_gather( const XLATensor& input, const ir::Value& token, int64_t dim, - int64_t shard_count, std::vector> groups); + int64_t shard_count, std::vector> groups, + bool pin_layout); static ir::Value all_gather_out(XLATensor& output, const XLATensor& input, const ir::Value& token, int64_t dim, int64_t shard_count, - std::vector> groups); + std::vector> groups, + bool pin_layout); static std::pair collective_permute( const XLATensor& input, const ir::Value& token, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index cbbfe0ff6b5a..f8e939d0eaff 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -342,19 +342,20 @@ ViewInfo CreateAsStridedViewInfo(const xla::Shape& input_shape, 
////////////////////////////////////////////////////////////////////////////// std::pair XLATensor::all_reduce( const XLATensor& input, const ir::Value& token, AllReduceType reduce_type, - double scale, std::vector> groups) { + double scale, std::vector> groups, bool pin_layout) { std::vector input_values({input.GetIrValue()}); torch::lazy::NodePtr node = ir::MakeNode( - reduce_type, input_values, token, scale, std::move(groups)); + reduce_type, input_values, token, scale, std::move(groups), pin_layout); return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)}; } ir::Value XLATensor::all_reduce_(XLATensor& input, const ir::Value& token, AllReduceType reduce_type, double scale, - std::vector> groups) { + std::vector> groups, + bool pin_layout) { std::vector input_values({input.GetIrValue()}); torch::lazy::NodePtr node = ir::MakeNode( - reduce_type, input_values, token, scale, std::move(groups)); + reduce_type, input_values, token, scale, std::move(groups), pin_layout); input.SetInPlaceIrValue(ir::Value(node, 0)); return ir::Value(node, 1); } @@ -362,14 +363,15 @@ ir::Value XLATensor::all_reduce_(XLATensor& input, const ir::Value& token, ir::Value XLATensor::all_reduce(std::vector* inputs, const ir::Value& token, AllReduceType reduce_type, double scale, - std::vector> groups) { + std::vector> groups, + bool pin_layout) { std::vector input_values; input_values.reserve(inputs->size()); for (auto& input : *inputs) { input_values.push_back(input.GetIrValue()); } torch::lazy::NodePtr node = ir::MakeNode( - reduce_type, input_values, token, scale, std::move(groups)); + reduce_type, input_values, token, scale, std::move(groups), pin_layout); for (size_t i = 0; i < inputs->size(); ++i) { (*inputs)[i].SetInPlaceIrValue(ir::Value(node, i)); } @@ -379,20 +381,21 @@ ir::Value XLATensor::all_reduce(std::vector* inputs, std::pair XLATensor::reduce_scatter( const XLATensor& input, const ir::Value& token, AllReduceType reduce_type, double scale, int64_t scatter_dim, int64_t shard_count, - std::vector> groups) { + std::vector> groups, bool pin_layout) { torch::lazy::NodePtr node = ir::MakeNode( reduce_type, input.GetIrValue(), token, scale, scatter_dim, shard_count, - std::move(groups)); + std::move(groups), pin_layout); return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)}; } ir::Value XLATensor::reduce_scatter_out( XLATensor& output, const XLATensor& input, const ir::Value& token, AllReduceType reduce_type, double scale, int64_t scatter_dim, - int64_t shard_count, std::vector> groups) { + int64_t shard_count, std::vector> groups, + bool pin_layout) { torch::lazy::NodePtr node = ir::MakeNode( reduce_type, input.GetIrValue(), token, scale, scatter_dim, shard_count, - std::move(groups)); + std::move(groups), pin_layout); output.SetIrValue(ir::Value(node, 0)); return ir::Value(node, 1); } @@ -400,27 +403,31 @@ ir::Value XLATensor::reduce_scatter_out( std::pair XLATensor::all_to_all( const XLATensor& input, const ir::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, - std::vector> groups) { + std::vector> groups, bool pin_layout) { torch::lazy::NodePtr node = ir::MakeNode( input.GetIrValue(), token, split_dimension, concat_dimension, split_count, - std::move(groups)); + std::move(groups), pin_layout); return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)}; } std::pair XLATensor::all_gather( const XLATensor& input, const ir::Value& token, int64_t dim, - int64_t shard_count, std::vector> groups) { + int64_t shard_count, std::vector> groups, + 
bool pin_layout) { torch::lazy::NodePtr node = ir::MakeNode( - input.GetIrValue(), token, dim, shard_count, std::move(groups)); + input.GetIrValue(), token, dim, shard_count, std::move(groups), + pin_layout); return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)}; } ir::Value XLATensor::all_gather_out(XLATensor& output, const XLATensor& input, const ir::Value& token, int64_t dim, int64_t shard_count, - std::vector> groups) { + std::vector> groups, + bool pin_layout) { torch::lazy::NodePtr node = ir::MakeNode( - input.GetIrValue(), token, dim, shard_count, std::move(groups)); + input.GetIrValue(), token, dim, shard_count, std::move(groups), + pin_layout); output.SetIrValue(ir::Value(node, 0)); return ir::Value(node, 1); } From e4e01e4a45787bef0948ff86683921ddd3b7649f Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Tue, 19 Apr 2022 07:41:42 +0000 Subject: [PATCH 2/2] Only pin all_reduce layout --- torch_xla/core/xla_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 59f6b1b84270..44d6e3f193d8 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -593,7 +593,7 @@ def all_reduce(reduce_type, return results[0] if isinstance(inputs, torch.Tensor) else results -def all_gather(value, dim=0, groups=None, output=None, pin_layout=True): +def all_gather(value, dim=0, groups=None, output=None, pin_layout=False): """Performs an all-gather operation along a given dimension. Args: @@ -645,7 +645,7 @@ def all_to_all(value, concat_dimension, split_count, groups=None, - pin_layout=True): + pin_layout=False): """Performs an XLA `AllToAll()` operation on the input tensor. See: https://www.tensorflow.org/xla/operation_semantics#alltoall @@ -709,7 +709,7 @@ def reduce_scatter(reduce_type, shard_count, groups=None, output=None, - pin_layout=True): + pin_layout=False): """Performs a XLA `ReduceScatter()` operation on the input tensor. See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
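
Usage note (illustrative, not part of the patch): the sketch below shows how the new `pin_layout` keyword is meant to be used from the Python API once both commits land. The helper name, tensor shape, and `scale` value are hypothetical; the collective signatures, `xm.REDUCE_SUM`, and the "HloModule has a mix of layout constrained" error text come from the docstrings above.

# Minimal sketch of the pin_layout flag added by this patch series.
# Assumes torch_xla is installed and an XLA device is available.
import torch
import torch_xla.core.xla_model as xm


def collective_example():
  device = xm.xla_device()
  grad = torch.randn(4, 4, device=device)

  # all_reduce keeps pin_layout=True by default: the collective's layout is
  # pinned, which guards against data corruption when the participating
  # processes trace slightly different programs.
  reduced = xm.all_reduce(
      xm.REDUCE_SUM, grad, scale=1.0 / xm.xrt_world_size())

  # If compilation fails with an error such as
  # "HloModule has a mix of layout constrained ...", drop the pin explicitly.
  reduced_unpinned = xm.all_reduce(
      xm.REDUCE_SUM, grad, scale=1.0 / xm.xrt_world_size(), pin_layout=False)

  # After the second commit, all_gather/all_to_all/reduce_scatter default to
  # pin_layout=False; pass pin_layout=True to opt back into layout pinning.
  gathered = xm.all_gather(grad, dim=0, pin_layout=True)
  return reduced, reduced_unpinned, gathered

The second commit ("Only pin all_reduce layout") keeps the pinned default only for all_reduce; the other collectives default to unpinned layouts, so callers that rely on pinned layouts for those ops have to opt in explicitly.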