Skip to content

Commit 36583d4

Browse files
authored
Codegen lt le (#3876)
1 parent 6bfcd24 commit 36583d4

File tree

6 files changed

+86
-32
lines changed

6 files changed

+86
-32
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,8 @@ TEST_F(AtenXlaTensorTest, TestLe) {
612612
torch::Tensor xla_c = torch::le(xla_a, xla_b);
613613
AllEqual(c, xla_c);
614614
});
615+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
616+
ExpectCounterChanged("xla::le", cpp_test::GetIgnoredCounters());
615617
}
616618

617619
TEST_F(AtenXlaTensorTest, TestLeInplace) {
@@ -627,6 +629,8 @@ TEST_F(AtenXlaTensorTest, TestLeInplace) {
627629
xla_a.le_(xla_b);
628630
AllClose(xla_a, a);
629631
});
632+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
633+
ExpectCounterChanged("xla::le", cpp_test::GetIgnoredCounters());
630634
}
631635

632636
TEST_F(AtenXlaTensorTest, TestGt) {
@@ -670,6 +674,8 @@ TEST_F(AtenXlaTensorTest, TestLt) {
670674
torch::Tensor xla_c = torch::lt(xla_a, xla_b);
671675
AllEqual(c, xla_c);
672676
});
677+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
678+
ExpectCounterChanged("xla::lt", cpp_test::GetIgnoredCounters());
673679
}
674680

675681
TEST_F(AtenXlaTensorTest, TestLtInplace) {
@@ -685,6 +691,8 @@ TEST_F(AtenXlaTensorTest, TestLtInplace) {
685691
xla_a.lt_(xla_b);
686692
AllClose(xla_a, a);
687693
});
694+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
695+
ExpectCounterChanged("xla::lt", cpp_test::GetIgnoredCounters());
688696
}
689697

690698
TEST_F(AtenXlaTensorTest, TestNeScalar) {
@@ -746,6 +754,8 @@ TEST_F(AtenXlaTensorTest, TestLeScalar) {
746754
torch::Tensor xla_result = torch::le(xla_input, other);
747755
AllEqual(result, xla_result);
748756
});
757+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
758+
ExpectCounterChanged("xla::le", cpp_test::GetIgnoredCounters());
749759
}
750760

751761
TEST_F(AtenXlaTensorTest, TestLeScalarInplace) {
@@ -759,6 +769,8 @@ TEST_F(AtenXlaTensorTest, TestLeScalarInplace) {
759769
xla_input.le_(other);
760770
AllClose(xla_input, input);
761771
});
772+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
773+
ExpectCounterChanged("xla::le", cpp_test::GetIgnoredCounters());
762774
}
763775

764776
TEST_F(AtenXlaTensorTest, TestGtScalar) {
@@ -798,6 +810,8 @@ TEST_F(AtenXlaTensorTest, TestLtScalar) {
798810
torch::Tensor xla_result = torch::lt(xla_input, other);
799811
AllEqual(result, xla_result);
800812
});
813+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
814+
ExpectCounterChanged("xla::lt", cpp_test::GetIgnoredCounters());
801815
}
802816

803817
TEST_F(AtenXlaTensorTest, TestLtScalarInplace) {
@@ -811,6 +825,8 @@ TEST_F(AtenXlaTensorTest, TestLtScalarInplace) {
811825
xla_input.lt_(other);
812826
AllClose(xla_input, input);
813827
});
828+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
829+
ExpectCounterChanged("xla::lt", cpp_test::GetIgnoredCounters());
814830
}
815831

816832
TEST_F(AtenXlaTensorTest, TestIntegerAdd) {

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,20 +1524,6 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::kthvalue(
15241524
bridge::AtenFromXlaTensor(std::get<1>(results)));
15251525
}
15261526

1527-
at::Tensor XLANativeFunctions::le(const at::Tensor& self,
1528-
const at::Scalar& other) {
1529-
XLA_FN_COUNTER("xla::");
1530-
return bridge::AtenFromXlaTensor(
1531-
XLATensor::le(bridge::GetXlaTensor(self), other));
1532-
}
1533-
1534-
at::Tensor XLANativeFunctions::le(const at::Tensor& self,
1535-
const at::Tensor& other) {
1536-
XLA_FN_COUNTER("xla::");
1537-
return bridge::AtenFromXlaTensor(
1538-
XLATensor::le(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
1539-
}
1540-
15411527
at::Tensor XLANativeFunctions::leaky_relu(const at::Tensor& self,
15421528
const at::Scalar& negative_slope) {
15431529
XLA_FN_COUNTER("xla::");
@@ -1641,20 +1627,6 @@ at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self,
16411627
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
16421628
}
16431629

1644-
at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
1645-
const at::Scalar& other) {
1646-
XLA_FN_COUNTER("xla::");
1647-
return bridge::AtenFromXlaTensor(
1648-
XLATensor::lt(bridge::GetXlaTensor(self), other));
1649-
}
1650-
1651-
at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
1652-
const at::Tensor& other) {
1653-
XLA_FN_COUNTER("xla::");
1654-
return bridge::AtenFromXlaTensor(
1655-
XLATensor::lt(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
1656-
}
1657-
16581630
at::Tensor& XLANativeFunctions::masked_fill_(at::Tensor& self,
16591631
const at::Tensor& mask,
16601632
const at::Scalar& value) {

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,30 @@ torch_xla::XlaOpVector Logdet::Lower(LoweringContext* loctx) const {
252252
return ReturnOp(xla::LogDet(xla_input), loctx);
253253
}
254254

255+
torch_xla::XlaOpVector LeScalar::Lower(LoweringContext* loctx) const {
256+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
257+
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
258+
return ReturnOp(BuildComparisonOp(at::aten::le, xla_input, xla_other), loctx);
259+
}
260+
261+
torch_xla::XlaOpVector LeTensor::Lower(LoweringContext* loctx) const {
262+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
263+
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
264+
return ReturnOp(BuildComparisonOp(at::aten::le, xla_input, xla_other), loctx);
265+
}
266+
267+
torch_xla::XlaOpVector LtScalar::Lower(LoweringContext* loctx) const {
268+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
269+
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
270+
return ReturnOp(BuildComparisonOp(at::aten::lt, xla_input, xla_other), loctx);
271+
}
272+
273+
torch_xla::XlaOpVector LtTensor::Lower(LoweringContext* loctx) const {
274+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
275+
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
276+
return ReturnOp(BuildComparisonOp(at::aten::lt, xla_input, xla_other), loctx);
277+
}
278+
255279
torch_xla::XlaOpVector LogicalAnd::Lower(LoweringContext* loctx) const {
256280
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
257281
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,36 @@ xla::Shape IsnanOutputShape(const torch::lazy::Value& input) {
307307
return isnan_shape;
308308
}
309309

310+
xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
311+
const torch::lazy::Value& other) {
312+
auto lower_for_shape_fn =
313+
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
314+
return BuildComparisonOp(at::aten::le, operands[0], operands[1]);
315+
};
316+
return InferOutputShape({GetXlaShape(self), GetXlaShape(other)},
317+
lower_for_shape_fn);
318+
}
319+
320+
xla::Shape LeTensorOutputShape(const torch::lazy::Value& self,
321+
const torch::lazy::Value& other) {
322+
return LeScalarOutputShape(self, other);
323+
}
324+
325+
xla::Shape LtScalarOutputShape(const torch::lazy::Value& self,
326+
const torch::lazy::Value& other) {
327+
auto lower_for_shape_fn =
328+
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
329+
return BuildComparisonOp(at::aten::lt, operands[0], operands[1]);
330+
};
331+
return InferOutputShape({GetXlaShape(self), GetXlaShape(other)},
332+
lower_for_shape_fn);
333+
}
334+
335+
xla::Shape LtTensorOutputShape(const torch::lazy::Value& self,
336+
const torch::lazy::Value& other) {
337+
return LtScalarOutputShape(self, other);
338+
}
339+
310340
xla::Shape LogdetOutputShape(const torch::lazy::Value& input) {
311341
const xla::Shape& input_shape = GetXlaShape(input);
312342
XLA_CHECK_GE(input_shape.rank(), 2) << input_shape;

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,18 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input);
106106

107107
xla::Shape IsnanOutputShape(const torch::lazy::Value& input);
108108

109+
xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
110+
const torch::lazy::Value& other);
111+
112+
xla::Shape LeTensorOutputShape(const torch::lazy::Value& self,
113+
const torch::lazy::Value& other);
114+
115+
xla::Shape LtScalarOutputShape(const torch::lazy::Value& self,
116+
const torch::lazy::Value& other);
117+
118+
xla::Shape LtTensorOutputShape(const torch::lazy::Value& self,
119+
const torch::lazy::Value& other);
120+
109121
xla::Shape LogdetOutputShape(const torch::lazy::Value& input);
110122

111123
xla::Shape LogicalAndOutputShape(const torch::lazy::Value& input,

xla_native_functions.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ full_codegen:
3636
- hardswish_backward
3737
- inverse
3838
- isnan
39+
- le.Scalar
40+
- le.Tensor
3941
- logdet
4042
- logical_and
4143
- logical_not
4244
- logical_or
4345
- logical_xor
4446
- log_sigmoid_backward
4547
- log_sigmoid_forward
48+
- lt.Scalar
49+
- lt.Tensor
4650
- maximum
4751
- minimum
4852
- reciprocal
@@ -182,8 +186,6 @@ supported:
182186
- index_select
183187
- kl_div
184188
- kthvalue
185-
- le.Scalar
186-
- le.Tensor
187189
- leaky_relu
188190
- leaky_relu_backward
189191
- lerp.Scalar
@@ -194,8 +196,6 @@ supported:
194196
- log2
195197
- log10
196198
- logsumexp
197-
- lt.Scalar
198-
- lt.Tensor
199199
- masked_fill_.Scalar
200200
- masked_fill_.Tensor
201201
- masked_scatter_

0 commit comments

Comments
 (0)