
Revert PR#3450 and use sparse_gather in gather #3566


Status: Merged (4 commits, May 16, 2022). Changes shown from 3 commits.
7 changes: 4 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1495,10 +1495,11 @@ at::Tensor XLANativeFunctions::frac(const at::Tensor& self) {

 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
                                       const at::Tensor& index,
-                                      bool /* sparse_grad */) {
+                                      bool sparse_grad) {
   XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::gather(
-      bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index)));
+  return bridge::AtenFromXlaTensor(
+      XLATensor::gather(bridge::GetXlaTensor(self), dim,
+                        bridge::GetXlaTensor(index), sparse_grad));
 }

 at::Tensor XLANativeFunctions::ge(const at::Tensor& self,

Review comment from @yeounoh (Contributor Author), May 13, 2022, on the return statement:
ditto --> re-ran clang-format-7, nothing changed. Will leave it for now.
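A minimal caller sketch (not part of the diff), assuming the standard ATen C++ entry point at::gather; the helper Example is hypothetical. With this change the flag is forwarded through the ATen boundary instead of being dropped:

#include <ATen/ATen.h>

// Hypothetical caller: sparse_grad now reaches XLATensor::gather and, from
// there, the Gather IR node and its lowering.
at::Tensor Example(const at::Tensor& self, const at::Tensor& index) {
  return at::gather(self, /*dim=*/0, index, /*sparse_grad=*/true);
}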
14 changes: 9 additions & 5 deletions torch_xla/csrc/ops/gather.cpp
@@ -8,26 +8,30 @@

 namespace torch_xla {

-Gather::Gather(const XlaValue& input, int64_t dim, const XlaValue& index)
+Gather::Gather(const XlaValue& input, int64_t dim, const XlaValue& index,
+               bool sparse_grad)
     : XlaNode(torch::lazy::OpKind(at::aten::gather), {input, index},
               xla::ShapeUtil::MakeShape(input.xla_shape().element_type(),
                                         index.xla_shape().dimensions()),
               /*num_outputs=*/1, torch::lazy::MHash(dim)),
-      dim_(dim) {}
+      dim_(dim),
+      sparse_grad_(sparse_grad) {}

 torch::lazy::NodePtr Gather::Clone(OpList operands) const {
-  return torch::lazy::MakeNode<Gather>(operands.at(0), dim_, operands.at(1));
+  return torch::lazy::MakeNode<Gather>(operands.at(0), dim_, operands.at(1),
+                                       sparse_grad_);
 }

 XlaOpVector Gather::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   xla::XlaOp index = loctx->GetOutputOp(operand(1));
-  return ReturnOp(xla::TorchGather(input, index, dim_, /*sparse=*/true), loctx);
+  return ReturnOp(xla::TorchGather(input, index, dim_, sparse_grad_), loctx);
 }

 std::string Gather::ToString() const {
   std::stringstream ss;
-  ss << XlaNode::ToString() << ", dim=" << dim_;
+  ss << XlaNode::ToString() << ", dim=" << dim_
+     << ", sparse_grad=" << sparse_grad_;
   return ss.str();
 }
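A short construction sketch (not part of the diff) showing the updated node API, mirroring the Clone() call above; input_value and index_value are hypothetical XlaValue operands:

// Record the flag on the node so Lower() can select the gather variant.
torch::lazy::NodePtr node = torch::lazy::MakeNode<Gather>(
    input_value, /*dim=*/1, index_value, /*sparse_grad=*/true);
// node->ToString() now reports the flag, e.g. "..., dim=1, sparse_grad=1".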
4 changes: 3 additions & 1 deletion torch_xla/csrc/ops/gather.h
@@ -6,7 +6,8 @@ namespace torch_xla {

 class Gather : public XlaNode {
  public:
-  Gather(const XlaValue& input, int64_t dim, const XlaValue& index);
+  Gather(const XlaValue& input, int64_t dim, const XlaValue& index,
+         bool sparse_grad);

   std::string ToString() const override;

@@ -18,6 +19,7 @@ class Gather : public XlaNode {

  private:
   int64_t dim_;
+  bool sparse_grad_;
 };

 }  // namespace torch_xla
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.h
@@ -616,7 +616,7 @@ class XLATensor {
       c10::optional<at::ScalarType> scalar_type);

   static XLATensor gather(const XLATensor& input, int64_t dim,
-                          const XLATensor& index);
+                          const XLATensor& index, bool sparse_grad);

   static XLATensor ge(const XLATensor& input, const at::Scalar& other);
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.cpp
@@ -1403,7 +1403,7 @@ XLATensor XLATensor::full_like(const XLATensor& input,
 }

 XLATensor XLATensor::gather(const XLATensor& input, int64_t dim,
-                            const XLATensor& index) {
+                            const XLATensor& index, bool sparse_grad) {
   xla::Shape input_shape = input.shape();
   xla::Shape index_shape = index.shape();
   XLA_CHECK_EQ(input_shape.rank(), index_shape.rank());

Review thread on this signature:

Collaborator: Don't you need to pass sparse_grad to line 1418?

Collaborator, quoting the expected call:
  return input.CreateFrom(torch::lazy::MakeNode<Gather>(
      input.GetIrValue(), canonical_dim, index.GetIrValue(), sparse_grad));

@yeounoh (Contributor Author): It's being passed already?

@yeounoh (Contributor Author): Actually, it's not picked up in this diff?! My local branch is up to date and shows the correct one -- but the GitHub file view doesn't.

@yeounoh (Contributor Author): OK, VS Code didn't write out the change to the file; resyncing and recommitting. Thanks @miladm
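For reference, a sketch of the corrected method body after the resync, assembled from the hunk context and the snippet quoted above; the canonical_dim computation via torch::lazy::GetCanonicalDimensionIndex is an assumption, since that line is hidden in this diff:

XLATensor XLATensor::gather(const XLATensor& input, int64_t dim,
                            const XLATensor& index, bool sparse_grad) {
  xla::Shape input_shape = input.shape();
  xla::Shape index_shape = index.shape();
  XLA_CHECK_EQ(input_shape.rank(), index_shape.rank());
  // Assumed helper: normalizes a possibly negative dim against the rank.
  int64_t canonical_dim =
      torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.rank());
  // Bake the flag into the IR node so Lower() emits the matching
  // xla::TorchGather call.
  return input.CreateFrom(torch::lazy::MakeNode<Gather>(
      input.GetIrValue(), canonical_dim, index.GetIrValue(), sparse_grad));
}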