Skip to content

Commit c175cbd

Browse files
committed
Use XlaHelpers::PromotedAdd/PromotedDiv/PromotedMul for addcdiv/addcmul lowering
1 parent 71ba7c3 commit c175cbd

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed

torch_xla/csrc/ops/ops_lower_fn.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch_xla/csrc/matrix.h"
88
#include "torch_xla/csrc/pooling.h"
99
#include "torch_xla/csrc/reduction.h"
10+
#include "torch_xla/csrc/xla_lower_util.h"
1011

1112
namespace torch_xla {
1213

@@ -63,7 +64,7 @@ torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
6364
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
6465
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
6566
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
66-
return ReturnOp(xla_input + (xla_t1 / xla_t2) * xla_val, loctx);
67+
return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
6768
}
6869

6970
torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
@@ -75,7 +76,7 @@ torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
7576
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
7677
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
7778
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
78-
return ReturnOp(xla_input + (xla_t1 * xla_t2) * xla_val, loctx);
79+
return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
7980
}
8081

8182
torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch_xla/csrc/helpers.h"
88
#include "torch_xla/csrc/pooling.h"
99
#include "torch_xla/csrc/reduction.h"
10+
#include "torch_xla/csrc/xla_lower_util.h"
1011

1112
namespace torch_xla {
1213
namespace {
@@ -85,7 +86,7 @@ xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
8586
const torch::lazy::Value& t2,
8687
const torch::lazy::Value& value) {
8788
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
88-
return operands[0] + (operands[1] / operands[2]) * operands[3];
89+
return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
8990
};
9091
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
9192
GetXlaShape(value)},
@@ -97,8 +98,9 @@ xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
9798
const torch::lazy::Value& t2,
9899
const torch::lazy::Value& value) {
99100
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
100-
return operands[0] + (operands[1] * operands[2]) * operands[3];
101+
return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
101102
};
103+
102104
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
103105
GetXlaShape(value)},
104106
shape_fn);

torch_xla/csrc/xla_lower_util.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -1018,4 +1018,16 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
10181018
return need_flatten ? xla::Reshape(input, input_shape.dimensions()) : input;
10191019
}
10201020

1021+
// Builds the XLA graph for addcdiv: input + (t1 / t2) * val.
// Each binary op goes through an XlaHelpers::Promoted* wrapper so that
// operands with differing element types are promoted to a common type
// (presumably mirroring aten type-promotion rules — see XlaHelpers).
xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
                        xla::XlaOp val) {
  xla::XlaOp quotient = XlaHelpers::PromotedDiv(t1, t2);
  xla::XlaOp scaled = XlaHelpers::PromotedMul(quotient, val);
  return XlaHelpers::PromotedAdd(input, scaled);
}
1026+
1027+
// Builds the XLA graph for addcmul: input + (t1 * t2) * val.
// Each binary op goes through an XlaHelpers::Promoted* wrapper so that
// operands with differing element types are promoted to a common type
// (presumably mirroring aten type-promotion rules — see XlaHelpers).
xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
                        xla::XlaOp val) {
  xla::XlaOp product = XlaHelpers::PromotedMul(t1, t2);
  xla::XlaOp scaled = XlaHelpers::PromotedMul(product, val);
  return XlaHelpers::PromotedAdd(input, scaled);
}
1032+
10211033
} // namespace torch_xla

torch_xla/csrc/xla_lower_util.h

+6
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,10 @@ xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);
121121
xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
122122
absl::Span<const int64_t> dims);
123123

124+
xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
125+
xla::XlaOp val);
126+
127+
xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
128+
xla::XlaOp val);
129+
124130
} // namespace torch_xla

0 commit comments

Comments
 (0)