Commit a4d3abb

Lower Lerp
1 parent 35625e7 commit a4d3abb

File tree

7 files changed, +227 -0 lines changed


test/cpp/test_aten_xla_tensor.cpp (+109)
@@ -10285,5 +10285,114 @@ TEST_F(AtenXlaTensorTest, TestEarlySyncLiveTensors) {
       cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestLerp) {
+  torch::Tensor start =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor weight =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor res = torch::lerp(start, end, weight);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_start = CopyToDevice(start, device);
+    torch::Tensor xla_end = CopyToDevice(end, device);
+    torch::Tensor xla_weight = CopyToDevice(weight, device);
+    torch::Tensor xla_res = torch::lerp(xla_start, xla_end, xla_weight);
+    AllClose(res, xla_res);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestLerpScalar) {
+  torch::Tensor start =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Scalar weight = torch::Scalar(3.0);
+  torch::Tensor res = torch::lerp(start, end, weight);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_start = CopyToDevice(start, device);
+    torch::Tensor xla_end = CopyToDevice(end, device);
+    torch::Tensor xla_res = torch::lerp(xla_start, xla_end, weight);
+    AllClose(res, xla_res);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestLerpInplace) {
+  torch::Tensor input =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor weight =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor input_copy = input.clone();
+  input.lerp_(end, weight);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input_copy, device);
+    torch::Tensor xla_end = CopyToDevice(end, device);
+    torch::Tensor xla_weight = CopyToDevice(weight, device);
+    xla_input.lerp_(xla_end, xla_weight);
+    AllClose(xla_input, input);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestLerpScalarInplace) {
+  torch::Tensor input =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Scalar weight = torch::Scalar(3.0);
+  torch::Tensor input_copy = input.clone();
+  input.lerp_(end, weight);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input_copy, device);
+    torch::Tensor xla_end = CopyToDevice(end, device);
+    xla_input.lerp_(xla_end, weight);
+    AllClose(xla_input, input);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestLerpOut) {
+  torch::Tensor start =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor weight =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::lerp_out(res, start, end, weight);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_start = CopyToDevice(start, device);
+    torch::Tensor xla_end = CopyToDevice(end, device);
+    torch::Tensor xla_weight = CopyToDevice(weight, device);
+    torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
+    torch::lerp_out(xla_res, xla_start, xla_end, xla_weight);
+    AllClose(res, xla_res);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestLerpScalarOut) {
+  torch::Tensor start =
+      torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Scalar weight = torch::Scalar(3.0);
+  torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::lerp_out(res, start, end, weight);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_start = CopyToDevice(start, device);
+    torch::Tensor xla_end = CopyToDevice(end, device);
+    torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
+    torch::lerp_out(xla_res, xla_start, xla_end, weight);
+    AllClose(res, xla_res);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
+}
+
 } // namespace cpp_test
 } // namespace torch_xla
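
Note: torch::lerp computes elementwise start + weight * (end - start), and the scalar-weight tests above use weight = 3.0, which lies outside [0, 1], so they also cover extrapolation past end. A minimal standalone sketch of that reference behavior (plain LibTorch on CPU, independent of the XLA lowering; the helper lerp_reference is ours, not part of the change):

#include <torch/torch.h>

#include <iostream>

// Hypothetical reference helper: elementwise start + weight * (end - start).
torch::Tensor lerp_reference(const torch::Tensor& start,
                             const torch::Tensor& end, double weight) {
  return start + weight * (end - start);
}

int main() {
  torch::Tensor start = torch::zeros({3, 4});
  torch::Tensor end = torch::ones({3, 4});
  // weight = 0.5 interpolates halfway; weight = 3.0 extrapolates beyond end.
  std::cout << torch::allclose(torch::lerp(start, end, 0.5),
                               lerp_reference(start, end, 0.5))
            << "\n";  // prints 1
  std::cout << torch::allclose(torch::lerp(start, end, 3.0),
                               lerp_reference(start, end, 3.0))
            << "\n";  // prints 1; every element is 3.0 here
  return 0;
}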

torch_xla/csrc/aten_xla_type.cpp (+50)
@@ -1519,6 +1519,56 @@ at::Tensor leaky_relu_backward(const at::Tensor& grad_output,
       negative_slope.to<double>()));
 }
 
+at::Tensor lerp(const at::Tensor& self, const at::Tensor& end,
+                const at::Tensor& weight) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(
+      XLATensor::lerp(bridge::GetXlaTensor(self), bridge::GetXlaTensor(end),
+                      bridge::GetXlaTensor(weight)));
+}
+
+at::Tensor lerp(const at::Tensor& self, const at::Tensor& end,
+                const at::Scalar& weight) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::lerp(
+      bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight));
+}
+
+at::Tensor& lerp_(at::Tensor& self, const at::Tensor& end,
+                  const at::Tensor& weight) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  XLATensor::lerp_(self_tensor, bridge::GetXlaTensor(end),
+                   bridge::GetXlaTensor(weight));
+  return self;
+}
+
+at::Tensor& lerp_(at::Tensor& self, const at::Tensor& end,
+                  const at::Scalar& weight) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  XLATensor::lerp_(self_tensor, bridge::GetXlaTensor(end), weight);
+  return self;
+}
+
+at::Tensor& lerp_out(const at::Tensor& self, const at::Tensor& end,
+                     const at::Tensor& weight, at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor out_tensor = bridge::GetXlaTensor(out);
+  XLATensor::lerp_out(out_tensor, bridge::GetXlaTensor(self),
+                      bridge::GetXlaTensor(end), bridge::GetXlaTensor(weight));
+  return out;
+}
+
+at::Tensor& lerp_out(const at::Tensor& self, const at::Tensor& end,
+                     const at::Scalar& weight, at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor out_tensor = bridge::GetXlaTensor(out);
+  XLATensor::lerp_out(out_tensor, bridge::GetXlaTensor(self),
+                      bridge::GetXlaTensor(end), weight);
+  return out;
+}
+
 at::Tensor log(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(XLATensor::log(bridge::GetXlaTensor(self)));

torch_xla/csrc/ops/ops.cpp (+5)
@@ -738,6 +738,11 @@ NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
       std::move(lower_fn));
 }
 
+NodePtr Lerp(const Value& start, const Value& end, const Value& weight) {
+  ScopePusher ir_scope(at::aten::lerp.toQualString());
+  return start + weight * (end - start);
+}
+
 } // namespace ops
 } // namespace ir
 } // namespace torch_xla
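
Note: the Lerp node above is not backed by a dedicated XLA primitive; it is composed from the existing elementwise subtract, multiply, and add nodes via the arithmetic operator overloads defined for ir::Value, so the emitted computation is simply start + weight * (end - start). The sketch below mirrors that decomposition with ordinary tensor ops to show the algebra only (plain LibTorch on CPU; it does not touch the IR types):

#include <torch/torch.h>

#include <iostream>

int main() {
  torch::Tensor start = torch::rand({3, 4});
  torch::Tensor end = torch::rand({3, 4});
  torch::Tensor weight = torch::rand({3, 4});

  // Same composition the IR node builds: sub/mul/add, no special lerp op.
  torch::Tensor composed = start + weight * (end - start);
  std::cout << torch::allclose(torch::lerp(start, end, weight), composed)
            << "\n";  // prints 1
  return 0;
}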

torch_xla/csrc/ops/ops.h (+2)
@@ -212,6 +212,8 @@ NodePtr IsNan(const Value& input);
 NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
                 const Value& product_multiplier, const Value& bias_multiplier);
 
+NodePtr Lerp(const Value& start, const Value& end, const Value& weight);
+
 } // namespace ops
 } // namespace ir
 } // namespace torch_xla

torch_xla/csrc/tensor.h (+13)
@@ -651,6 +651,19 @@ class XLATensor {
                                        const XLATensor& input,
                                        double negative_slope);
 
+  static XLATensor lerp(const XLATensor& input, const XLATensor& end,
+                        const XLATensor& weight);
+  static XLATensor lerp(const XLATensor& input, const XLATensor& end,
+                        const at::Scalar& weight);
+  static void lerp_(XLATensor& input, const XLATensor& end,
+                    const XLATensor& weight);
+  static void lerp_(XLATensor& input, const XLATensor& end,
+                    const at::Scalar& weight);
+  static void lerp_out(XLATensor& out, const XLATensor& input,
+                       const XLATensor& end, const XLATensor& weight);
+  static void lerp_out(XLATensor& out, const XLATensor& input,
+                       const XLATensor& end, const at::Scalar& weight);
+
   static XLATensor log(const XLATensor& input);
 
   static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base);

torch_xla/csrc/tensor_methods.cpp (+42)
@@ -1458,6 +1458,48 @@ XLATensor XLATensor::leaky_relu_backward(const XLATensor& grad_output,
       grad_output.GetIrValue(), input.GetIrValue(), negative_slope));
 }
 
+XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
+                          const XLATensor& weight) {
+  return input.CreateFrom(
+      ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
+}
+
+XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
+                          const at::Scalar& weight) {
+  ir::Value weight_val = GetIrValueForScalar(
+      weight, input.shape().get().element_type(), input.GetDevice());
+  return input.CreateFrom(
+      ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
+}
+
+void XLATensor::lerp_(XLATensor& input, const XLATensor& end,
+                      const XLATensor& weight) {
+  input.SetInPlaceIrValue(
+      ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
+}
+
+void XLATensor::lerp_(XLATensor& input, const XLATensor& end,
+                      const at::Scalar& weight) {
+  ir::Value weight_val = GetIrValueForScalar(
+      weight, input.shape().get().element_type(), input.GetDevice());
+  input.SetInPlaceIrValue(
+      ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
+}
+
+void XLATensor::lerp_out(XLATensor& out, const XLATensor& input,
+                         const XLATensor& end, const XLATensor& weight) {
+  out.SetInPlaceIrValue(
+      ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
+}
+
+void XLATensor::lerp_out(XLATensor& out, const XLATensor& input,
+                         const XLATensor& end, const at::Scalar& weight) {
+  ir::Value weight_val = GetIrValueForScalar(
+      weight, input.shape().get().element_type(), input.GetDevice());
+  out.SetInPlaceIrValue(
+      ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
+}
+
 XLATensor XLATensor::log(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Log(input.GetIrValue()));
 }

xla_native_functions.yaml (+6)
@@ -307,6 +307,12 @@ supported:
 - sigmoid_backward
 - tanh_backward
 - ger
+- lerp_.Scalar
+- lerp_.Tensor
+- lerp.Scalar_out
+- lerp.Tensor_out
+- lerp.Scalar
+- lerp.Tensor
 autograd:
 - max_pool2d
 - max_pool3d
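
Note: these six entries register every lerp overload with the XLA backend, so each call form reaches the corresponding xla:: kernel added in aten_xla_type.cpp. A rough sketch of which call resolves to which registered overload; the mapping is an assumption based on the standard ATen schema names, and the tensors here live on CPU purely for illustration (on an XLA device these calls would bump the xla::lerp* counters checked in the tests):

#include <torch/torch.h>

int main() {
  torch::Tensor a = torch::rand({3, 4});
  torch::Tensor b = torch::rand({3, 4});
  torch::Tensor w = torch::rand({3, 4});
  torch::Tensor out = torch::empty({3, 4});

  torch::lerp(a, b, w);             // lerp.Tensor
  torch::lerp(a, b, 0.5);           // lerp.Scalar
  a.lerp_(b, w);                    // lerp_.Tensor
  a.lerp_(b, 0.5);                  // lerp_.Scalar
  torch::lerp_out(out, a, b, w);    // lerp.Tensor_out
  torch::lerp_out(out, a, b, 0.5);  // lerp.Scalar_out
  return 0;
}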
