@@ -87,7 +87,7 @@ std::vector<xla::ReplicaGroup> CreateReduceGroups(
87
87
std::vector<xla::XlaOp> BuildAllReduce (
88
88
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
89
89
xla::XlaOp token, double scale,
90
- const std::vector<std::vector<int64_t >>& groups) {
90
+ const std::vector<std::vector<int64_t >>& groups, bool pin_layout ) {
91
91
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups (groups);
92
92
// TODO: We use pseudo-tokens ATM, which are real values. This need to be
93
93
// switched to use the real XLA Token once support has been added to XLA
@@ -101,11 +101,19 @@ std::vector<xla::XlaOp> BuildAllReduce(
101
101
type_ctx.second .operand_shapes .push_back (
102
102
XlaHelpers::ShapeOfXlaOp (token_op));
103
103
104
- xla::XlaOp reduce = xla::AllReduce (
105
- xla::Tuple (operands[0 ].builder (), type_ctx.second .ops ),
106
- GetReduceComutation (reduce_type, type_ctx.first ), reduce_groups,
107
- /* channel_id=*/ absl::nullopt,
108
- MakeReduceShape (type_ctx.second .operand_shapes ));
104
+ xla::XlaOp reduce;
105
+ if (pin_layout) {
106
+ reduce = xla::AllReduce (
107
+ xla::Tuple (operands[0 ].builder (), type_ctx.second .ops ),
108
+ GetReduceComutation (reduce_type, type_ctx.first ), reduce_groups,
109
+ /* channel_id=*/ absl::nullopt,
110
+ /* shape_with_layout=*/
111
+ MakeReduceShape (type_ctx.second .operand_shapes ));
112
+ } else {
113
+ reduce = xla::AllReduce (
114
+ xla::Tuple (operands[0 ].builder (), type_ctx.second .ops ),
115
+ GetReduceComutation (reduce_type, type_ctx.first ), reduce_groups);
116
+ }
109
117
for (size_t i = 0 ; i < type_ctx.second .indices .size (); ++i) {
110
118
size_t op_idx = type_ctx.second .indices [i];
111
119
xla::XlaOp gte = xla::GetTupleElement (reduce, i);
@@ -128,28 +136,49 @@ std::vector<xla::XlaOp> BuildAllReduce(
128
136
// Emits an XLA AllToAll on `input`, splitting along `split_dimension` into
// `split_count` pieces and concatenating along `concat_dimension`, with
// `groups` selecting the participating replicas. `token` is threaded through
// the op for ordering; the returned token is derived from the result. When
// `pin_layout` is true the op's layout is pinned to the shape computed by
// MakeArrayShapeFromDimensions for the current device's hardware type
// (presumably the device-canonical layout — confirm against that helper).
AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                             int64_t split_dimension, int64_t concat_dimension,
                             int64_t split_count,
                             const std::vector<std::vector<int64_t>>& groups,
                             bool pin_layout) {
  std::vector<xla::ReplicaGroup> replica_groups = CreateReduceGroups(groups);
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  TokenHandler handler(token);
  xla::XlaOp result;
  if (!pin_layout) {
    // No layout constraint: let XLA pick the result layout.
    result = xla::AllToAll(handler.GetInput(input, &input_shape),
                           split_dimension, concat_dimension, split_count,
                           replica_groups);
  } else {
    xla::Shape pinned_shape = MakeArrayShapeFromDimensions(
        input_shape.dimensions(), input_shape.dynamic_dimensions(),
        input_shape.element_type(), GetCurrentDevice().device_type.hw_type);
    result = xla::AllToAll(handler.GetInput(input, &input_shape),
                           split_dimension, concat_dimension, split_count,
                           replica_groups,
                           /*layout=*/pinned_shape.layout());
  }
  return {result, handler.GetNewToken(result)};
}
143
160
144
- AllGatherResult BuildAllGather (
145
- xla::XlaOp input, xla::XlaOp token, int64_t dim, int64_t shard_count,
146
- const std::vector<std::vector<int64_t >>& groups) {
161
+ AllGatherResult BuildAllGather (xla::XlaOp input, xla::XlaOp token, int64_t dim,
162
+ int64_t shard_count,
163
+ const std::vector<std::vector<int64_t >>& groups,
164
+ bool pin_layout) {
147
165
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups (groups);
148
166
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp (input);
149
167
TokenHandler token_handler (token);
150
- xla::XlaOp all_gather_result =
151
- xla::AllGather (token_handler.GetInput (input, &input_shape), dim,
152
- shard_count, reduce_groups);
168
+ xla::XlaOp all_gather_result;
169
+ if (pin_layout) {
170
+ xla::Shape reduce_shape = MakeArrayShapeFromDimensions (
171
+ input_shape.dimensions (), input_shape.dynamic_dimensions (),
172
+ input_shape.element_type (), GetCurrentDevice ().device_type .hw_type );
173
+ all_gather_result =
174
+ xla::AllGather (token_handler.GetInput (input, &input_shape), dim,
175
+ shard_count, reduce_groups, /* channel_id=*/ absl::nullopt,
176
+ /* layout=*/ reduce_shape.layout ());
177
+ } else {
178
+ all_gather_result =
179
+ xla::AllGather (token_handler.GetInput (input, &input_shape), dim,
180
+ shard_count, reduce_groups);
181
+ }
153
182
return {all_gather_result, token_handler.GetNewToken (all_gather_result)};
154
183
}
155
184
@@ -169,15 +198,26 @@ CollectivePermuteResult BuildCollectivePermute(
169
198
ReduceScatterResult BuildReduceScatter (
170
199
AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
171
200
int64_t scatter_dim, int64_t shard_count,
172
- const std::vector<std::vector<int64_t >>& groups) {
201
+ const std::vector<std::vector<int64_t >>& groups, bool pin_layout ) {
173
202
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups (groups);
174
203
TokenHandler token_handler (token);
175
204
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp (input);
176
-
177
- xla::XlaOp reduce_result = xla::ReduceScatter (
178
- token_handler.GetInput (input, &input_shape),
179
- GetReduceComutation (reduce_type, input_shape.element_type ()), scatter_dim,
180
- shard_count, reduce_groups);
205
+ xla::XlaOp reduce_result;
206
+ if (pin_layout) {
207
+ xla::Shape reduce_shape = MakeArrayShapeFromDimensions (
208
+ input_shape.dimensions (), input_shape.dynamic_dimensions (),
209
+ input_shape.element_type (), GetCurrentDevice ().device_type .hw_type );
210
+ reduce_result = xla::ReduceScatter (
211
+ token_handler.GetInput (input, &input_shape),
212
+ GetReduceComutation (reduce_type, input_shape.element_type ()),
213
+ scatter_dim, shard_count, reduce_groups, /* channel_id=*/ absl::nullopt,
214
+ /* layout=*/ reduce_shape.layout ());
215
+ } else {
216
+ reduce_result = xla::ReduceScatter (
217
+ token_handler.GetInput (input, &input_shape),
218
+ GetReduceComutation (reduce_type, input_shape.element_type ()),
219
+ scatter_dim, shard_count, reduce_groups);
220
+ }
181
221
182
222
if (scale != 1.0 ) {
183
223
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float >(
0 commit comments