Skip to content

Commit 05fa2aa

Browse files
authored
Full codegen erf, erfc, erfinv, and exp (#3659)
1 parent 105f077 commit 05fa2aa

7 files changed

+48
-43
lines changed

torch_xla/csrc/aten_xla_type.cpp

-21
Original file line numberDiff line numberDiff line change
@@ -1322,27 +1322,6 @@ at::Tensor XLANativeFunctions::eq(const at::Tensor& self,
13221322
XLATensor::eq(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
13231323
}
13241324

1325-
at::Tensor XLANativeFunctions::erf(const at::Tensor& self) {
1326-
XLA_FN_COUNTER("xla::");
1327-
return bridge::AtenFromXlaTensor(XLATensor::erf(bridge::GetXlaTensor(self)));
1328-
}
1329-
1330-
at::Tensor XLANativeFunctions::erfc(const at::Tensor& self) {
1331-
XLA_FN_COUNTER("xla::");
1332-
return bridge::AtenFromXlaTensor(XLATensor::erfc(bridge::GetXlaTensor(self)));
1333-
}
1334-
1335-
at::Tensor XLANativeFunctions::erfinv(const at::Tensor& self) {
1336-
XLA_FN_COUNTER("xla::");
1337-
return bridge::AtenFromXlaTensor(
1338-
XLATensor::erfinv(bridge::GetXlaTensor(self)));
1339-
}
1340-
1341-
at::Tensor XLANativeFunctions::exp(const at::Tensor& self) {
1342-
XLA_FN_COUNTER("xla::");
1343-
return bridge::AtenFromXlaTensor(XLATensor::exp(bridge::GetXlaTensor(self)));
1344-
}
1345-
13461325
at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
13471326
at::IntArrayRef size, bool implicit) {
13481327
XLA_FN_COUNTER("xla::");

torch_xla/csrc/ops/ops_lower_fn.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,26 @@ torch_xla::XlaOpVector Cosh::Lower(LoweringContext* loctx) const {
5353
return ReturnOp(xla::Cosh(xla_input), loctx);
5454
}
5555

56+
torch_xla::XlaOpVector Erf::Lower(LoweringContext* loctx) const {
57+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
58+
return ReturnOp(xla::Erf(xla_input), loctx);
59+
}
60+
61+
torch_xla::XlaOpVector Erfc::Lower(LoweringContext* loctx) const {
62+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
63+
return ReturnOp(xla::Erfc(xla_input), loctx);
64+
}
65+
66+
torch_xla::XlaOpVector Erfinv::Lower(LoweringContext* loctx) const {
67+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
68+
return ReturnOp(xla::ErfInv(xla_input), loctx);
69+
}
70+
71+
torch_xla::XlaOpVector Exp::Lower(LoweringContext* loctx) const {
72+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
73+
return ReturnOp(xla::Exp(xla_input), loctx);
74+
}
75+
5676
torch_xla::XlaOpVector Floor::Lower(LoweringContext* loctx) const {
5777
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
5878
return ReturnOp(xla::Floor(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ xla::Shape CoshOutputShape(const torch::lazy::Value& input) {
4141
return GetXlaShape(input);
4242
}
4343

44+
xla::Shape ErfOutputShape(const torch::lazy::Value& input) {
45+
return GetXlaShape(input);
46+
}
47+
48+
xla::Shape ErfcOutputShape(const torch::lazy::Value& input) {
49+
return GetXlaShape(input);
50+
}
51+
52+
xla::Shape ErfinvOutputShape(const torch::lazy::Value& input) {
53+
return GetXlaShape(input);
54+
}
55+
56+
xla::Shape ExpOutputShape(const torch::lazy::Value& input) {
57+
return GetXlaShape(input);
58+
}
59+
4460
xla::Shape FloorOutputShape(const torch::lazy::Value& input) {
4561
return GetXlaShape(input);
4662
}

torch_xla/csrc/ops/ops_xla_shape_fn.h

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ xla::Shape CosOutputShape(const torch::lazy::Value& input);
2121

2222
xla::Shape CoshOutputShape(const torch::lazy::Value& input);
2323

24+
xla::Shape ErfOutputShape(const torch::lazy::Value& input);
25+
26+
xla::Shape ErfcOutputShape(const torch::lazy::Value& input);
27+
28+
xla::Shape ErfinvOutputShape(const torch::lazy::Value& input);
29+
30+
xla::Shape ExpOutputShape(const torch::lazy::Value& input);
31+
2432
xla::Shape FloorOutputShape(const torch::lazy::Value& input);
2533

2634
xla::Shape InverseOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor.h

-6
Original file line numberDiff line numberDiff line change
@@ -555,12 +555,6 @@ class XLATensor : public c10::intrusive_ptr_target {
555555

556556
static XLATensor eq(const XLATensor& input, const XLATensor& other);
557557

558-
static XLATensor erf(const XLATensor& input);
559-
560-
static XLATensor erfc(const XLATensor& input);
561-
562-
static XLATensor erfinv(const XLATensor& input);
563-
564558
static XLATensor exp(const XLATensor& input);
565559

566560
static XLATensor expand(const XLATensor& input, std::vector<int64_t> size);

torch_xla/csrc/tensor_methods.cpp

-12
Original file line numberDiff line numberDiff line change
@@ -1271,18 +1271,6 @@ XLATensor XLATensor::embedding_dense_backward(const XLATensor& grad_output,
12711271
padding_idx, scale_grad_by_freq);
12721272
}
12731273

1274-
XLATensor XLATensor::erf(const XLATensor& input) {
1275-
return input.CreateFrom(Erf(input.GetIrValue()));
1276-
}
1277-
1278-
XLATensor XLATensor::erfc(const XLATensor& input) {
1279-
return input.CreateFrom(Erfc(input.GetIrValue()));
1280-
}
1281-
1282-
XLATensor XLATensor::erfinv(const XLATensor& input) {
1283-
return input.CreateFrom(Erfinv(input.GetIrValue()));
1284-
}
1285-
12861274
XLATensor XLATensor::exp(const XLATensor& input) {
12871275
return input.CreateFrom(Exp(input.GetIrValue()));
12881276
}

xla_native_functions.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ full_codegen:
1010
- atanh
1111
- cos
1212
- cosh
13+
- erf
14+
- erfc
15+
- erfinv
16+
- exp
1317
- floor
1418
- inverse
1519
- logdet
@@ -121,10 +125,6 @@ supported:
121125
- empty_strided
122126
- eq.Scalar
123127
- eq.Tensor
124-
- erf
125-
- erfc
126-
- erfinv
127-
- exp
128128
- expand
129129
- expm1
130130
- exponential_

0 commit comments

Comments
 (0)