Skip to content

Commit 51e6043

Browse files
committed
Use promoteAdd/Div/Mul
1 parent f4bdee0 commit 51e6043

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed

torch_xla/csrc/ops/ops_lower_fn.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch_xla/csrc/matrix.h"
88
#include "torch_xla/csrc/pooling.h"
99
#include "torch_xla/csrc/reduction.h"
10+
#include "torch_xla/csrc/xla_lower_util.h"
1011

1112
namespace torch_xla {
1213
torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
@@ -79,7 +80,7 @@ torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
7980
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
8081
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
8182
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
82-
return ReturnOp(xla_input + (xla_t1 / xla_t2) * xla_val, loctx);
83+
return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
8384
}
8485

8586
torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
@@ -91,7 +92,7 @@ torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
9192
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
9293
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
9394
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
94-
return ReturnOp(xla_input + (xla_t1 * xla_t2) * xla_val, loctx);
95+
return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
9596
}
9697

9798
torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch_xla/csrc/helpers.h"
88
#include "torch_xla/csrc/pooling.h"
99
#include "torch_xla/csrc/reduction.h"
10+
#include "torch_xla/csrc/xla_lower_util.h"
1011

1112
namespace torch_xla {
1213
namespace {
@@ -114,7 +115,7 @@ xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
114115
const torch::lazy::Value& t2,
115116
const torch::lazy::Value& value) {
116117
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
117-
return operands[0] + (operands[1] / operands[2]) * operands[3];
118+
return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
118119
};
119120
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
120121
GetXlaShape(value)},
@@ -126,8 +127,9 @@ xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
126127
const torch::lazy::Value& t2,
127128
const torch::lazy::Value& value) {
128129
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
129-
return operands[0] + (operands[1] * operands[2]) * operands[3];
130+
return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
130131
};
132+
131133
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
132134
GetXlaShape(value)},
133135
shape_fn);

torch_xla/csrc/xla_lower_util.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -1012,4 +1012,16 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
10121012
return need_flatten ? xla::Reshape(input, input_shape.dimensions()) : input;
10131013
}
10141014

1015+
xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
1016+
xla::XlaOp val) {
1017+
return XlaHelpers::PromotedAdd(
1018+
input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
1019+
}
1020+
1021+
xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
1022+
xla::XlaOp val) {
1023+
return XlaHelpers::PromotedAdd(
1024+
input, XlaHelpers::PromotedMul(XlaHelpers::PromotedMul(t1, t2), val));
1025+
}
1026+
10151027
} // namespace torch_xla

torch_xla/csrc/xla_lower_util.h

+6
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,10 @@ xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);
119119
xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
120120
absl::Span<const int64_t> dims);
121121

122+
xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
123+
xla::XlaOp val);
124+
125+
xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
126+
xla::XlaOp val);
127+
122128
} // namespace torch_xla

0 commit comments

Comments
 (0)