Commit 6bfcd24

Codegen addcdiv and addcmul (#3768)

* Codegen addcdiv and addcmul
* pin
* Use promoteAdd/Div/Mul
* remove comment
* Convert scalar to the right type
* Delete .torch_pin

Parent: 6639bcc

9 files changed (+75, -64 lines)
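For context, the two ops being moved to codegen have the standard PyTorch semantics below; this eager-mode sketch just restates their definitions and is not part of the commit:

```python
import torch

# addcdiv(input, t1, t2, value) = input + value * t1 / t2
# addcmul(input, t1, t2, value) = input + value * t1 * t2
input = torch.randn(3, 4)
t1 = torch.randn(3, 4)
t2 = torch.rand(3, 4) + 1.0  # keep the divisor away from zero
value = 0.5

assert torch.allclose(torch.addcdiv(input, t1, t2, value=value),
                      input + value * t1 / t2)
assert torch.allclose(torch.addcmul(input, t1, t2, value=value),
                      input + value * t1 * t2)
```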

scripts/gen_lazy_tensor.py (+1, -1)

```diff
@@ -50,7 +50,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
         base_ctor_value_args = ", ".join(base_ctor_value_args_list)
 
         shape_fn_inputs_list = [
-            f"{a.name}" for a in schema.positional_args
+            f"{a.name}" for a in (schema.positional_args + schema.keyword_args)
             if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
                 is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
         ]
```
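The likely motivation for this one-liner: `value` is a keyword-only scalar in the ATen schemas for addcdiv/addcmul (`..., *, Scalar value=1`), so a shape function built from `schema.positional_args` alone would never receive it. A toy sketch of the effect, using `SimpleNamespace` stand-ins rather than the real torchgen `LazyIrSchema` objects:

```python
from types import SimpleNamespace

def arg(name, is_lazy_value=True):
    # Hypothetical stand-in for a torchgen argument object; not the real API.
    return SimpleNamespace(name=name, is_lazy_value=is_lazy_value)

schema = SimpleNamespace(
    positional_args=[arg("self"), arg("tensor1"), arg("tensor2")],
    keyword_args=[arg("value")],  # keyword-only in the ATen schema
)

# Before the change: keyword-only args never reached the shape function.
before = [a.name for a in schema.positional_args if a.is_lazy_value]
# After the change: positional and keyword args are both scanned.
after = [a.name for a in (schema.positional_args + schema.keyword_args)
         if a.is_lazy_value]

print(before)  # ['self', 'tensor1', 'tensor2']
print(after)   # ['self', 'tensor1', 'tensor2', 'value']
```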

torch_xla/csrc/aten_xla_type.cpp (-31)

```diff
@@ -612,37 +612,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
   });
 }
 
-at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcdiv(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
-at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
-                                         const at::Tensor& tensor1,
-                                         const at::Tensor& tensor2,
-                                         const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
-                      bridge::GetXlaTensor(tensor2));
-  return self;
-}
-
-at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcmul(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
 at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
                                      const at::Tensor& mat1,
                                      const at::Tensor& mat2,
```
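These hand-written ATen entry points (including the in-place `addcdiv_`) are deleted because the codegen now emits the bindings. Whichever path produces them, the Python-visible contract stays the same; a quick eager-mode sanity check of the in-place variant:

```python
import torch

# addcdiv_ mutates self in place; its result must match the functional op.
self_t = torch.ones(2, 2)
t1 = torch.full((2, 2), 4.0)
t2 = torch.full((2, 2), 2.0)

self_t.addcdiv_(t1, t2, value=0.5)  # 1 + 0.5 * (4 / 2) = 2
assert torch.equal(self_t, torch.full((2, 2), 2.0))
```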

torch_xla/csrc/ops/ops_lower_fn.cpp (+17)

```diff
@@ -7,6 +7,7 @@
 #include "torch_xla/csrc/matrix.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
@@ -69,6 +70,22 @@ torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
 }
 
+torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
+}
+
+torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
```

torch_xla/csrc/ops/ops_xla_shape_fn.cpp (+26)

```diff
@@ -7,6 +7,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/reduction.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 namespace {
@@ -109,6 +110,31 @@ xla::Shape AllOutputShape(const torch::lazy::Value& input) {
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
+  };
+
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
```
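`InferOutputShape` runs the builder on dummy parameters and reads off the resulting shape, so the output shape follows ordinary broadcasting across `input`, `t1`, and `t2` (the `value` scalar contributes a dtype, not a shape). The expected behavior, shown in eager PyTorch:

```python
import torch

# Broadcasting determines the output shape of both ops.
input = torch.randn(3, 1)
t1 = torch.randn(1, 4)
t2 = torch.rand(3, 4) + 1.0  # keep the divisor away from zero

print(torch.addcdiv(input, t1, t2, value=2.0).shape)  # torch.Size([3, 4])
print(torch.addcmul(input, t1, t2, value=2.0).shape)  # torch.Size([3, 4])
```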

torch_xla/csrc/ops/ops_xla_shape_fn.h (+10)

```diff
@@ -29,6 +29,16 @@ xla::Shape AminOutputShape(const torch::lazy::Value& input,
 
 xla::Shape AllOutputShape(const torch::lazy::Value& input);
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input);
```

torch_xla/csrc/tensor_methods.cpp (-29)

```diff
@@ -660,35 +660,6 @@ XLATensorPtr XLATensor::add(
       logical_element_type);
 }
 
-XLATensorPtr XLATensor::addcdiv(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + div * constant);
-}
-
-void XLATensor::addcdiv_(XLATensorPtr& input, const at::Scalar& value,
-                         const XLATensorPtr& tensor1,
-                         const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
-}
-
-XLATensorPtr XLATensor::addcmul(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + mul * constant);
-}
-
 XLATensorPtr XLATensor::addmm(const XLATensorPtr& input,
                               const XLATensorPtr& weight,
                               const XLATensorPtr& bias) {
```

torch_xla/csrc/xla_lower_util.cpp (+13)

```diff
@@ -1012,4 +1012,17 @@ xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
   return need_flatten ? xla::Reshape(input, input_shape.dimensions()) : input;
 }
 
+xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val) {
+  return XlaHelpers::PromotedAdd(
+      input, XlaHelpers::PromotedMul(XlaHelpers::PromotedDiv(t1, t2), val));
+}
+
+xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val) {
+  val = MaybeConvertTo(val, XlaHelpers::ShapeOfXlaOp(t1).element_type());
+  return XlaHelpers::PromotedAdd(
+      input, XlaHelpers::PromotedMul(XlaHelpers::PromotedMul(t1, t2), val));
+}
+
 }  // namespace torch_xla
```
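The `Promoted*` helpers (the "Use promoteAdd/Div/Mul" bullet in the commit message) give the lowering PyTorch-style type promotion, and the `MaybeConvertTo` on `val` corresponds to the "Convert scalar to the right type" bullet. The promotion behavior being matched, demonstrated in eager mode:

```python
import torch

# Mixed-dtype operands promote instead of erroring; the result takes
# the wider floating-point type.
input = torch.zeros(2, dtype=torch.float64)
t1 = torch.ones(2, dtype=torch.float32)
t2 = torch.full((2,), 2.0, dtype=torch.float32)

out = torch.addcmul(input, t1, t2, value=3)
print(out.dtype)  # torch.float64 (float32 operands promoted to float64)
print(out)        # tensor([6., 6.], dtype=torch.float64)
```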

torch_xla/csrc/xla_lower_util.h (+6)

```diff
@@ -119,4 +119,10 @@ xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);
 xla::XlaOp BuildRoll(xla::XlaOp input, absl::Span<const int64_t> shifts,
                      absl::Span<const int64_t> dims);
 
+xla::XlaOp BuildAddcdiv(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val);
+
+xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
+                        xla::XlaOp val);
+
 }  // namespace torch_xla
```

xla_native_functions.yaml (+2, -3)

```diff
@@ -7,6 +7,8 @@ full_codegen:
 - all
 - amax
 - amin
+- addcdiv
+- addcmul
 - asin
 - asinh
 - atan
@@ -95,9 +97,6 @@ supported:
 - adaptive_max_pool2d_backward
 - add.Scalar
 - add.Tensor
-- addcdiv
-- addcdiv_
-- addcmul
 - addmm
 - alias
 - all.dim
```
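With the two ops moved from the hand-written `supported` list to `full_codegen`, tracing them on an XLA device should go through the generated IR nodes transparently. A minimal smoke test, assuming torch_xla is installed; the exact `xla::addcdiv` counter name is an assumption based on the usual naming convention:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
input = torch.randn(3, 4, device=device)
t1 = torch.randn(3, 4, device=device)
t2 = torch.rand(3, 4, device=device) + 1.0

out = torch.addcdiv(input, t1, t2, value=0.5)
xm.mark_step()  # force the lazy graph to compile and run

# Expect a counter for the traced op to show up in the metrics report.
print([name for name in met.counter_names() if "addcdiv" in name])
```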
