Skip to content

Commit cfb6e69

Browse files
authored
Fix gather by using xla sparse gather (#3450)
All tests passing. Merging the PR.
1 parent c496278 commit cfb6e69

File tree

2 files changed

+14
-21
lines changed

2 files changed

+14
-21
lines changed

torch_xla/csrc/ops/gather.cpp

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,11 @@
99
namespace torch_xla {
1010
namespace ir {
1111
namespace ops {
12-
namespace {
13-
14-
xla::Shape NodeOutputShape(const Value& input, const Value& index,
15-
int64_t dim) {
16-
auto lower_for_shape_fn =
17-
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
18-
return xla::TorchGather(operands[0], operands[1], dim,
19-
IsSparseGather(operands[0], operands[1], dim));
20-
};
21-
return InferOutputShape({input.xla_shape(), index.xla_shape()},
22-
lower_for_shape_fn);
23-
}
24-
25-
} // namespace
2612

2713
Gather::Gather(const Value& input, int64_t dim, const Value& index)
2814
: Node(torch::lazy::OpKind(at::aten::gather), {input, index},
29-
[&]() { return NodeOutputShape(input, index, dim); },
15+
xla::ShapeUtil::MakeShape(input.xla_shape().element_type(),
16+
index.xla_shape().dimensions()),
3017
/*num_outputs=*/1, torch::lazy::MHash(dim)),
3118
dim_(dim) {}
3219

@@ -37,9 +24,7 @@ NodePtr Gather::Clone(OpList operands) const {
3724
// Lowers this node to an XLA computation by emitting a TorchGather op over
// the two operands. Sparse gather is requested unconditionally.
XlaOpVector Gather::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_index = loctx->GetOutputOp(operand(1));
  xla::XlaOp output =
      xla::TorchGather(xla_input, xla_index, dim_, /*sparse=*/true);
  return ReturnOp(output, loctx);
}
4429

4530
std::string Gather::ToString() const {

torch_xla/csrc/tensor_methods.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,10 +1391,18 @@ XLATensor XLATensor::full_like(const XLATensor& input,
13911391

13921392
// Builds a Gather IR node for torch.gather(input, dim, index).
//
// Validates the torch.gather contract before constructing the node: input and
// index must have the same rank, and on every dimension other than the gather
// dimension the index extent must not exceed the input extent. The gather
// dimension itself may have any extent.
//
// Fix: the validation loop previously declared `size_t dim`, shadowing the
// `int64_t dim` parameter and forcing a signed/unsigned comparison against
// `canonical_dim`. The loop variable is renamed and typed `int64_t`.
XLATensor XLATensor::gather(const XLATensor& input, int64_t dim,
                            const XLATensor& index) {
  xla::Shape input_shape = input.shape();
  xla::Shape index_shape = index.shape();
  XLA_CHECK_EQ(input_shape.rank(), index_shape.rank());
  int64_t canonical_dim =
      torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.rank());
  for (int64_t d = 0; d < input_shape.rank(); ++d) {
    if (d != canonical_dim) {
      XLA_CHECK_LE(index.size(d), input.size(d));
    }
  }
  return input.CreateFrom(ir::MakeNode<ir::ops::Gather>(
      input.GetIrValue(), canonical_dim, index.GetIrValue()));
}
13991407

14001408
XLATensor XLATensor::ge(const XLATensor& input, const at::Scalar& other) {

0 commit comments

Comments
 (0)