Lower IsNan #2969

Merged · 2 commits · Jun 2, 2021

OP_LOWERING_GUIDE.md (2 changes: 1 addition & 1 deletion)

```diff
@@ -24,4 +24,4 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e
 Our CircleCI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters.
 
 ## Tips
-The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/1592). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/1940).
+The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/1940).
```

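The PR follows that advice directly: `isnan` is expressed in terms of operations XLA already provides. The XlaOp-level `IsNan` builder itself is not shown in this diff; purely as a hedged sketch of one way such a helper could work (the signature and placement are assumptions, not code from this PR), it can exploit the fact that NaN is the only value that compares unequal to itself:

```cpp
// Hypothetical sketch, not part of this diff: produce a PRED-typed XlaOp
// that is true exactly where the input is NaN. xla::Ne is XLA's element-wise
// "not equal" builder, and x != x holds only when x is NaN.
xla::XlaOp IsNan(xla::XlaOp input) {
  return xla::Ne(input, input);
}
```
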
test/cpp/test_aten_xla_tensor.cpp (13 changes: 13 additions & 0 deletions)

```diff
@@ -4350,6 +4350,19 @@ TEST_F(AtenXlaTensorTest, TestInverse) {
   ExpectCounterChanged("xla::inverse", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestIsnan) {
+  torch::Tensor a = torch::tensor({1.0, 2.0, std::nan("1"), 4.0},
+                                  torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::isnan(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::isnan(xla_a);
+    AllEqual(b, xla_b);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::isnan", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestExpand) {
   torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = a.expand({2, 3, 4}, /*implicit=*/false);
```

torch_xla/csrc/aten_xla_type.cpp (6 changes: 6 additions & 0 deletions)

```diff
@@ -1443,6 +1443,12 @@ at::Tensor inverse(const at::Tensor& self) {
       XLATensor::inverse(bridge::GetXlaTensor(self)));
 }
 
+at::Tensor isnan(const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(
+      XLATensor::isnan(bridge::GetXlaTensor(self)));
+}
+
 at::Tensor kl_div(const at::Tensor& self, const at::Tensor& target,
                   int64_t reduction, bool log_target) {
   XLA_FN_COUNTER("xla::");
```

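For orientation: `aten_xla_type.cpp` is where calls dispatched to the XLA backend arrive. The override unwraps the ATen tensor via `bridge::GetXlaTensor`, delegates to the `XLATensor` method, and re-wraps the result with `bridge::AtenFromXlaTensor`. `XLA_FN_COUNTER("xla::")` is also what bumps the `xla::isnan` counter that the new C++ test asserts on with `ExpectCounterChanged`.
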
torch_xla/csrc/ops/ops.cpp (1 change: 1 addition & 0 deletions)

```diff
@@ -90,6 +90,7 @@ PTXLA_UNARY_OP(Ceil, at::aten::ceil, xla::Ceil);
 PTXLA_UNARY_OP(Floor, at::aten::floor, xla::Floor);
 PTXLA_UNARY_OP(Round, at::aten::round, xla::RoundToEven);
 PTXLA_UNARY_OP(Not, at::aten::bitwise_not, xla::Not);
+PTXLA_UNARY_OP(IsNan, at::aten::isnan, IsNan);
 
 PTXLA_BINARY_OP(Min, at::aten::min, xla::Min);
 PTXLA_BINARY_OP(Max, at::aten::max, xla::Max);
```

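`PTXLA_UNARY_OP` is a convenience macro for one-operand ops: it generates the `IsNan` IR-node constructor declared in `ops.h` and wires its lowering to the given builder. As a simplified sketch only (the exact macro body in `ops.cpp` differs, and helper names like `GenericOp` and `OpKind` follow the surrounding codebase's conventions):

```cpp
// Rough expansion of PTXLA_UNARY_OP(IsNan, at::aten::isnan, IsNan); details
// are approximate, not the literal macro from this repository.
NodePtr IsNan(const Value& input) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    // Delegates to the XlaOp-level IsNan helper (third macro argument).
    return node.ReturnOp(IsNan(xla_input), loctx);
  };
  return GenericOp(OpKind(at::aten::isnan), {input}, input.shape(),
                   std::move(lower_fn));
}
```
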
torch_xla/csrc/ops/ops.h (2 changes: 2 additions & 0 deletions)

```diff
@@ -207,6 +207,8 @@ NodePtr LogDet(const Value& input);
 
 NodePtr Inverse(const Value& input);
 
+NodePtr IsNan(const Value& input);
+
 NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
                 const Value& product_multiplier, const Value& bias_multiplier);
 
```

torch_xla/csrc/tensor.h (2 changes: 2 additions & 0 deletions)

```diff
@@ -608,6 +608,8 @@ class XLATensor {
 
   static XLATensor inverse(const XLATensor& input);
 
+  static XLATensor isnan(const XLATensor& input);
+
   static XLATensor kl_div_backward(const XLATensor& grad_output,
                                    const XLATensor& input,
                                    const XLATensor& target,
```

torch_xla/csrc/tensor_methods.cpp (5 changes: 5 additions & 0 deletions)

```diff
@@ -1365,6 +1365,11 @@ XLATensor XLATensor::inverse(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Inverse(input.GetIrValue()));
 }
 
+XLATensor XLATensor::isnan(const XLATensor& input) {
+  return input.CreateFrom(ir::ops::IsNan(input.GetIrValue()),
+                          at::ScalarType::Bool);
+}
+
 XLATensor XLATensor::kl_div_backward(const XLATensor& grad_output,
                                      const XLATensor& input,
                                      const XLATensor& target,
```

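Note the second argument to `CreateFrom`: unlike `inverse` just above, `isnan` explicitly pins the result dtype to `at::ScalarType::Bool`, since `torch::isnan` returns a boolean tensor; without the override the new `XLATensor` would inherit the input's floating-point dtype rather than reflect the comparison's PRED-typed XLA result.
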
xla_native_functions.yaml (1 change: 1 addition & 0 deletions)

```diff
@@ -70,6 +70,7 @@ supported:
   - index_put_
   - _index_put_impl_
   - inverse
+  - isnan
   - kl_div
   - kl_div_backward
   - kthvalue
```