Skip to content

Commit f8f5dc4

Browse files
alanwaketan authored and ManfeiBai committed
Code-gen addcdiv again (#4447)
Summary: This pull request redoes the addcdiv and addcmul code-gen, and adds a test case to verify that we reuse the DataCache for scalars. This needs pytorch/pytorch#92066 to function. Test Plan: PJRT_DEVICE=CPU python test/test_operations.py -v -k test_cached_addcdiv. Fixes #4213.
1 parent e99c295 commit f8f5dc4

7 files changed

+70
-58
lines changed

test/test_operations.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,6 +1893,23 @@ def test_sigmoid_bounds(self):
18931893
assert torch.all(lower_bound >= 0.0)
18941894
assert torch.all(upper_bound <= 1.0)
18951895

1896+
def test_cached_addcdiv(self):
1897+
xla_device = xm.xla_device()
1898+
met.clear_all()
1899+
1900+
t1 = torch.randn(1, 3).to(xla_device)
1901+
t2 = torch.randn(1, 3).to(xla_device)
1902+
t3 = torch.randn(1, 3).to(xla_device)
1903+
t1.addcdiv_(t2, t3, value=0.1)
1904+
xm.mark_step()
1905+
self.assertEqual(met.metric_data("TransferToServerTime")[0], 4)
1906+
1907+
# The following two scalars shouldn't trigger TransferToServerTime.
1908+
t1.addcdiv_(t2, t3, value=0.1)
1909+
t1.addcdiv_(t2, t3, value=0.1)
1910+
xm.mark_step()
1911+
self.assertEqual(met.metric_data("TransferToServerTime")[0], 4)
1912+
18961913

18971914
class MNISTComparator(nn.Module):
18981915

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -596,37 +596,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
596596
});
597597
}
598598

599-
at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
600-
const at::Tensor& tensor1,
601-
const at::Tensor& tensor2,
602-
const at::Scalar& value) {
603-
TORCH_LAZY_FN_COUNTER("xla::");
604-
return bridge::AtenFromXlaTensor(tensor_methods::addcdiv(
605-
bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
606-
bridge::GetXlaTensor(tensor2)));
607-
}
608-
609-
at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
610-
const at::Tensor& tensor1,
611-
const at::Tensor& tensor2,
612-
const at::Scalar& value) {
613-
TORCH_LAZY_FN_COUNTER("xla::");
614-
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
615-
tensor_methods::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
616-
bridge::GetXlaTensor(tensor2));
617-
return self;
618-
}
619-
620-
at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
621-
const at::Tensor& tensor1,
622-
const at::Tensor& tensor2,
623-
const at::Scalar& value) {
624-
TORCH_LAZY_FN_COUNTER("xla::");
625-
return bridge::AtenFromXlaTensor(tensor_methods::addcmul(
626-
bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
627-
bridge::GetXlaTensor(tensor2)));
628-
}
629-
630599
at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
631600
const at::Tensor& mat1,
632601
const at::Tensor& mat2,

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ torch_xla::XlaOpVector AdaptiveAvgPool3dBackward::Lower(
5656
return ReturnOp(xla_output, loctx);
5757
}
5858

59+
torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
60+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
61+
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
62+
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
63+
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
64+
return ReturnOp(BuildAddcdiv(xla_input, xla_t1, xla_t2, xla_val), loctx);
65+
}
66+
67+
torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
68+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
69+
xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
70+
xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
71+
xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
72+
return ReturnOp(BuildAddcmul(xla_input, xla_t1, xla_t2, xla_val), loctx);
73+
}
74+
5975
torch_xla::XlaOpVector All::Lower(LoweringContext* loctx) const {
6076
xla::XlaOp input = loctx->GetOutputOp(operand(0));
6177
std::vector<int64_t> dimensions =

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,31 @@ xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
9797
lower_for_shape_fn);
9898
}
9999

100+
xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
101+
const torch::lazy::Value& t1,
102+
const torch::lazy::Value& t2,
103+
const torch::lazy::Value& value) {
104+
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
105+
return BuildAddcdiv(operands[0], operands[1], operands[2], operands[3]);
106+
};
107+
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
108+
GetXlaShape(value)},
109+
shape_fn);
110+
}
111+
112+
xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
113+
const torch::lazy::Value& t1,
114+
const torch::lazy::Value& t2,
115+
const torch::lazy::Value& value) {
116+
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
117+
return BuildAddcmul(operands[0], operands[1], operands[2], operands[3]);
118+
};
119+
120+
return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
121+
GetXlaShape(value)},
122+
shape_fn);
123+
}
124+
100125
xla::Shape AllOutputShape(const torch::lazy::Value& input) {
101126
std::vector<int64_t> dimensions =
102127
torch::lazy::Iota<int64_t>(GetXlaShape(input).rank());

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
2121
xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
2222
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
2323

24+
xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
25+
const torch::lazy::Value& t1,
26+
const torch::lazy::Value& t2,
27+
const torch::lazy::Value& value);
28+
29+
xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
30+
const torch::lazy::Value& t1,
31+
const torch::lazy::Value& t2,
32+
const torch::lazy::Value& value);
33+
2434
xla::Shape AllOutputShape(const torch::lazy::Value& input);
2535

2636
xla::Shape AllDimOutputShape(const torch::lazy::Value& input, const int64_t dim,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -671,30 +671,6 @@ XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
671671
logical_element_type);
672672
}
673673

674-
XLATensorPtr addcdiv(const XLATensorPtr& input, const at::Scalar& value,
675-
const XLATensorPtr& tensor1, const XLATensorPtr& tensor2) {
676-
torch::lazy::Value constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
677-
value, tensor1->shape().get().element_type(), input->GetDevice());
678-
torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
679-
return input->CreateFrom(input->GetIrValue() + div * constant);
680-
}
681-
682-
void addcdiv_(XLATensorPtr& input, const at::Scalar& value,
683-
const XLATensorPtr& tensor1, const XLATensorPtr& tensor2) {
684-
torch::lazy::Value constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
685-
value, tensor1->shape().get().element_type(), input->GetDevice());
686-
torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
687-
input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
688-
}
689-
690-
XLATensorPtr addcmul(const XLATensorPtr& input, const at::Scalar& value,
691-
const XLATensorPtr& tensor1, const XLATensorPtr& tensor2) {
692-
torch::lazy::Value constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
693-
value, tensor1->shape().get().element_type(), input->GetDevice());
694-
torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
695-
return input->CreateFrom(input->GetIrValue() + mul * constant);
696-
}
697-
698674
XLATensorPtr addmm(const XLATensorPtr& input, const XLATensorPtr& weight,
699675
const XLATensorPtr& bias) {
700676
return input->CreateFrom(AddMatMulOp(

xla_native_functions.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ full_codegen:
66
- abs
77
- all
88
- all.dim
9+
- addcdiv
10+
- addcmul
911
- amax
1012
- amin
1113
- any
@@ -123,9 +125,6 @@ supported:
123125
- adaptive_max_pool2d_backward
124126
- add.Scalar
125127
- add.Tensor
126-
- addcdiv
127-
- addcdiv_
128-
- addcmul
129128
- addmm
130129
- alias
131130
- arange.start_out

0 commit comments

Comments (0)