9
9
namespace torch_xla {
10
10
namespace ir {
11
11
namespace ops {
12
- namespace {
13
-
14
- xla::Shape NodeOutputShape (const Value& input, const Value& index,
15
- int64_t dim) {
16
- auto lower_for_shape_fn =
17
- [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
18
- return xla::TorchGather (operands[0 ], operands[1 ], dim,
19
- IsSparseGather (operands[0 ], operands[1 ], dim));
20
- };
21
- return InferOutputShape ({input.xla_shape (), index .xla_shape ()},
22
- lower_for_shape_fn);
23
- }
24
-
25
- } // namespace
26
12
27
13
Gather::Gather (const Value& input, int64_t dim, const Value& index)
28
14
: Node(torch::lazy::OpKind(at::aten::gather), {input, index },
29
- [&]() { return NodeOutputShape (input, index , dim); },
15
+ xla::ShapeUtil::MakeShape (input.xla_shape().element_type(),
16
+ index.xla_shape().dimensions()),
30
17
/* num_outputs=*/ 1 , torch::lazy::MHash(dim)),
31
18
dim_ (dim) {}
32
19
@@ -37,9 +24,7 @@ NodePtr Gather::Clone(OpList operands) const {
37
24
XlaOpVector Gather::Lower (LoweringContext* loctx) const {
38
25
xla::XlaOp input = loctx->GetOutputOp (operand (0 ));
39
26
xla::XlaOp index = loctx->GetOutputOp (operand (1 ));
40
- return ReturnOp (
41
- xla::TorchGather (input, index , dim_, IsSparseGather (input, index , dim_)),
42
- loctx);
27
+ return ReturnOp (xla::TorchGather (input, index , dim_, /* sparse=*/ true ), loctx);
43
28
}
44
29
45
30
std::string Gather::ToString () const {
0 commit comments