
Commit 35625e7

Lower IsNan (#2969)
* Lower IsNan
* Update OP_LOWERING_GUIDE
1 parent f34281f commit 35625e7

8 files changed (+31, -1 lines)

OP_LOWERING_GUIDE.md (+1, -1)

@@ -24,4 +24,4 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e
 Our CircleCI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters.
 
 ## Tips
-The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/1592). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/1940).
+The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/1940).
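
For the IsNan lowering added in this commit, that "sequence of XlaOp" can be a single instruction: NaN is the only floating-point value that compares unequal to itself, so an element-wise self-comparison already produces the mask. The snippet below is only a sketch of that idea against the XLA client builder API; the name `BuildIsNanSketch` is hypothetical and this is an illustration rather than a claim about the exact helper the commit binds through `PTXLA_UNARY_OP` in `ops.cpp` further down.

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch: build an isnan mask for a floating-point operand.
// NaN != NaN, so xla::Ne(x, x) yields a PRED (boolean) tensor that is
// true exactly at the NaN positions of the input.
xla::XlaOp BuildIsNanSketch(xla::XlaOp input) {
  return xla::Ne(input, input);
}
```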

test/cpp/test_aten_xla_tensor.cpp (+13)

@@ -4350,6 +4350,19 @@ TEST_F(AtenXlaTensorTest, TestInverse) {
   ExpectCounterChanged("xla::inverse", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestIsnan) {
+  torch::Tensor a = torch::tensor({1.0, 2.0, std::nan("1"), 4.0},
+                                  torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::isnan(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::isnan(xla_a);
+    AllEqual(b, xla_b);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::isnan", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestExpand) {
   torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = a.expand({2, 3, 4}, /*implicit=*/false);
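
As the guide excerpt above says, the counter assertions carry the real weight of this test: `ExpectCounterNotChanged("aten::.*", ...)` checks that nothing silently fell back to the ATen implementation, and `ExpectCounterChanged("xla::isnan", ...)` checks that the new XLA lowering was actually hit when the input is an XLA tensor. `AllEqual(b, xla_b)` then compares the device result against the native `torch::isnan` result computed on CPU.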

torch_xla/csrc/aten_xla_type.cpp (+6)

@@ -1443,6 +1443,12 @@ at::Tensor inverse(const at::Tensor& self) {
       XLATensor::inverse(bridge::GetXlaTensor(self)));
 }
 
+at::Tensor isnan(const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(
+      XLATensor::isnan(bridge::GetXlaTensor(self)));
+}
+
 at::Tensor kl_div(const at::Tensor& self, const at::Tensor& target,
                   int64_t reduction, bool log_target) {
   XLA_FN_COUNTER("xla::");
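
This is the ATen-facing entry point: `bridge::GetXlaTensor` unwraps the incoming `at::Tensor` into an `XLATensor`, `XLATensor::isnan` builds the IR node, and `bridge::AtenFromXlaTensor` wraps the result back into an `at::Tensor`. The `XLA_FN_COUNTER("xla::")` call is presumably what increments the `xla::isnan` counter that the C++ test above asserts on.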

torch_xla/csrc/ops/ops.cpp (+1)

@@ -90,6 +90,7 @@ PTXLA_UNARY_OP(Ceil, at::aten::ceil, xla::Ceil);
 PTXLA_UNARY_OP(Floor, at::aten::floor, xla::Floor);
 PTXLA_UNARY_OP(Round, at::aten::round, xla::RoundToEven);
 PTXLA_UNARY_OP(Not, at::aten::bitwise_not, xla::Not);
+PTXLA_UNARY_OP(IsNan, at::aten::isnan, IsNan);
 
 PTXLA_BINARY_OP(Min, at::aten::min, xla::Min);
 PTXLA_BINARY_OP(Max, at::aten::max, xla::Max);
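
Judging from the neighbouring entries (`Ceil` paired with `xla::Ceil`, `Not` with `xla::Not`), `PTXLA_UNARY_OP` appears to take the IR node name, the ATen symbol, and the XlaOp-level lowering function to call; the third argument here would be an `IsNan` builder along the lines of the self-comparison sketch after the guide excerpt above. The macro definition itself is not part of this diff.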

torch_xla/csrc/ops/ops.h (+2)

@@ -207,6 +207,8 @@ NodePtr LogDet(const Value& input);
 
 NodePtr Inverse(const Value& input);
 
+NodePtr IsNan(const Value& input);
+
 NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
                 const Value& product_multiplier, const Value& bias_multiplier);
 
torch_xla/csrc/tensor.h (+2)

@@ -608,6 +608,8 @@ class XLATensor {
 
   static XLATensor inverse(const XLATensor& input);
 
+  static XLATensor isnan(const XLATensor& input);
+
   static XLATensor kl_div_backward(const XLATensor& grad_output,
                                    const XLATensor& input,
                                    const XLATensor& target,

torch_xla/csrc/tensor_methods.cpp (+5)

@@ -1365,6 +1365,11 @@ XLATensor XLATensor::inverse(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Inverse(input.GetIrValue()));
 }
 
+XLATensor XLATensor::isnan(const XLATensor& input) {
+  return input.CreateFrom(ir::ops::IsNan(input.GetIrValue()),
+                          at::ScalarType::Bool);
+}
+
 XLATensor XLATensor::kl_div_backward(const XLATensor& grad_output,
                                      const XLATensor& input,
                                      const XLATensor& target,
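
Note the explicit `at::ScalarType::Bool` passed to `CreateFrom`: unlike `inverse` just above, `isnan` does not preserve the input dtype, because the result is a boolean mask. A quick sanity check with the plain PyTorch C++ API (no XLA involved, illustration only, not part of this commit) shows the expected dtype:

```cpp
#include <cmath>
#include <iostream>
#include <torch/torch.h>

int main() {
  torch::Tensor a = torch::tensor({1.0f, std::nanf("1"), 4.0f});
  torch::Tensor m = torch::isnan(a);
  // The mask is a Bool tensor regardless of the floating-point input dtype.
  std::cout << std::boolalpha << (m.scalar_type() == torch::kBool) << "\n";  // true
  std::cout << m << "\n";  // only the middle element is true
  return 0;
}
```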

xla_native_functions.yaml (+1)

@@ -70,6 +70,7 @@ supported:
 - index_put_
 - _index_put_impl_
 - inverse
+- isnan
 - kl_div
 - kl_div_backward
 - kthvalue
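
Adding `isnan` to the `supported:` list appears to be what registers the op with the XLA backend's dispatch machinery, mirroring how the neighbouring entries such as `inverse` are wired; without it, calls would presumably keep going to the ATen fallback instead of the hand-written `isnan` in `aten_xla_type.cpp`. The code generation that consumes this list is outside this diff.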
