Commit 204d9e5

Codegen addcdiv and addcmul
1 parent 1f154ce commit 204d9e5

7 files changed, +63 -64 lines

scripts/gen_lazy_tensor.py  (+2, -1)

@@ -7,6 +7,7 @@
 from torchgen.api.types import (
     BaseCType,
     OptionalCType,
+    scalarT,
     VectorCType,
     kernel_signature,
 )
@@ -45,7 +46,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
     base_ctor_value_args = ", ".join(base_ctor_value_args_list)

     shape_fn_inputs_list = [
-        f"{a.name}" for a in schema.positional_args
+        f"{a.name}" for a in (schema.positional_args + schema.keyword_args)
         if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
             a.name == 'reduction')
     ]
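
The codegen change above widens the shape-function argument list from positional args only to positional plus keyword args. That matters for addcdiv/addcmul, whose `value` operand is keyword-only in the ATen schema but is still lowered as an IR value, so the generated calls to AddcdivOutputShape/AddcmulOutputShape have to pass it. Below is an illustrative sketch of the filter's effect, using stand-in objects rather than the real torchgen/LazyIrSchema types:

# Illustrative only: stand-in argument objects, not the real LazyIrSchema API.
from collections import namedtuple

Arg = namedtuple("Arg", ["name", "is_lazy_value"])

# addcdiv: self/tensor1/tensor2 are positional; `value` is keyword-only in the
# ATen schema but is promoted to a lazy IR value, so the shape fn needs it too.
positional_args = [Arg("self", True), Arg("tensor1", True), Arg("tensor2", True)]
keyword_args = [Arg("value", True)]

# Simplified version of the filter (the VectorCType check is omitted here).
shape_fn_inputs_list = [
    a.name for a in (positional_args + keyword_args)
    if a.is_lazy_value or a.name == 'reduction'
]
print(", ".join(shape_fn_inputs_list))  # self, tensor1, tensor2, value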

torch_xla/csrc/aten_xla_type.cpp  (-31)

@@ -612,37 +612,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
   });
 }

-at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcdiv(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
-at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
-                                         const at::Tensor& tensor1,
-                                         const at::Tensor& tensor2,
-                                         const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
-                      bridge::GetXlaTensor(tensor2));
-  return self;
-}
-
-at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcmul(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
 at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
                                      const at::Tensor& mat1,
                                      const at::Tensor& mat2,

torch_xla/csrc/ops/ops_lower_fn.cpp  (+25)

@@ -53,6 +53,31 @@ torch_xla::XlaOpVector AdaptiveAvgPool3dBackward::Lower(
   return ReturnOp(xla_output, loctx);
 }

+torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
+  // xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  // torch::lazy::Value constant = GetIrValueForScalar(
+  //     value, tensor1->shape().get().element_type(), input->GetDevice());
+  // torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
+  // return input->CreateFrom(input->GetIrValue() + div * constant);
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(xla_input + (xla_t1 / xla_t2) * xla_val, loctx);
+}
+
+torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
+  // torch::lazy::Value constant = GetIrValueForScalar(
+  //     value, tensor1->shape().get().element_type(), input->GetDevice());
+  // torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
+  // return input->CreateFrom(input->GetIrValue() + mul * constant);
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(xla_input + (xla_t1 * xla_t2) * xla_val, loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
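
These lowerings implement the ATen element-wise semantics directly on XLA ops: addcdiv(input, t1, t2, value) = input + value * t1 / t2 and addcmul(input, t1, t2, value) = input + value * t1 * t2, with the scalar `value` arriving as operand(3) of the node. A quick eager-mode check of that algebra (plain CPU PyTorch, no XLA required):

import torch

# Check the formulas the lowerings implement against eager addcdiv/addcmul.
input_, t1 = torch.randn(4, 3), torch.randn(4, 3)
t2 = torch.rand(4, 3) + 0.5  # keep the divisor away from zero
value = 0.25

assert torch.allclose(torch.addcdiv(input_, t1, t2, value=value),
                      input_ + (t1 / t2) * value)
assert torch.allclose(torch.addcmul(input_, t1, t2, value=value),
                      input_ + (t1 * t2) * value)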

torch_xla/csrc/ops/ops_xla_shape_fn.cpp  (+24)

@@ -80,6 +80,30 @@ xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
                           lower_for_shape_fn);
 }

+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return operands[0] + (operands[1] / operands[2]) * operands[3];
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return operands[0] + (operands[1] * operands[2]) * operands[3];
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
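
The shape functions rerun the same element-wise expression through InferOutputShape, so the inferred result shape follows the usual broadcasting of input, t1, and t2 (value contributes a scalar shape). For reference, the broadcast behavior these ops are expected to match, shown with eager PyTorch:

import torch

# Elementwise broadcasting example: (3, 1) with (1, 4) broadcasts to (3, 4).
input_ = torch.randn(3, 1)
t1 = torch.randn(1, 4)
t2 = torch.rand(1, 4) + 0.5

print(torch.addcdiv(input_, t1, t2, value=2.0).shape)  # torch.Size([3, 4])
print(torch.addcmul(input_, t1, t2, value=2.0).shape)  # torch.Size([3, 4])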

torch_xla/csrc/ops/ops_xla_shape_fn.h  (+10)

@@ -21,6 +21,16 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
 xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
     const torch::lazy::Value& grad_output, const torch::lazy::Value& input);

+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);

 xla::Shape AsinhOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor_methods.cpp  (-29)

@@ -662,35 +662,6 @@ XLATensorPtr XLATensor::add(
                       logical_element_type);
 }

-XLATensorPtr XLATensor::addcdiv(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + div * constant);
-}
-
-void XLATensor::addcdiv_(XLATensorPtr& input, const at::Scalar& value,
-                         const XLATensorPtr& tensor1,
-                         const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
-}
-
-XLATensorPtr XLATensor::addcmul(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + mul * constant);
-}
-
 XLATensorPtr XLATensor::addmm(const XLATensorPtr& input,
                               const XLATensorPtr& weight,
                               const XLATensorPtr& bias) {

xla_native_functions.yaml  (+2, -3)

@@ -4,6 +4,8 @@ full_codegen:
 - acos
 - acosh
 - abs
+- addcdiv
+- addcmul
 - asin
 - asinh
 - atan
@@ -86,9 +88,6 @@ supported:
 - adaptive_max_pool2d_backward
 - add.Scalar
 - add.Tensor
-- addcdiv
-- addcdiv_
-- addcmul
 - addmm
 - alias
 - all
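
With addcdiv and addcmul moved from the hand-written `supported:` section to `full_codegen:`, the ATen wrappers, IR nodes, lowerings, and shape inference are now generated, and the ops should behave the same from Python. A minimal end-to-end sketch, assuming a torch_xla install with an XLA device available:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # assumes an XLA device/backend is available
input_ = torch.randn(4, 3, device=device)
t1 = torch.randn(4, 3, device=device)
t2 = torch.rand(4, 3, device=device) + 0.5

out_div = torch.addcdiv(input_, t1, t2, value=0.5)
out_mul = torch.addcmul(input_, t1, t2, value=0.5)
xm.mark_step()  # force the lazy graph to be lowered and executed

print(out_div.cpu())
print(out_mul.cpu())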
