@@ -10285,5 +10285,114 @@ TEST_F(AtenXlaTensorTest, TestEarlySyncLiveTensors) {
10285
10285
cpp_test::GetIgnoredCounters());
10286
10286
}
10287
10287
10288
+ TEST_F(AtenXlaTensorTest, TestLerp) {
10289
+ torch::Tensor start =
10290
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10291
+ torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10292
+ torch::Tensor weight =
10293
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10294
+ torch::Tensor res = torch::lerp(start, end, weight);
10295
+ ForEachDevice([&](const torch::Device& device) {
10296
+ torch::Tensor xla_start = CopyToDevice(start, device);
10297
+ torch::Tensor xla_end = CopyToDevice(end, device);
10298
+ torch::Tensor xla_weight = CopyToDevice(weight, device);
10299
+ torch::Tensor xla_res = torch::lerp(xla_start, xla_end, xla_weight);
10300
+ AllClose(res, xla_res);
10301
+ });
10302
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
10303
+ ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
10304
+ }
10305
+
10306
+ TEST_F(AtenXlaTensorTest, TestLerpScalar) {
10307
+ torch::Tensor start =
10308
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10309
+ torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10310
+ torch::Scalar weight = torch::Scalar(3.0);
10311
+ torch::Tensor res = torch::lerp(start, end, weight);
10312
+ ForEachDevice([&](const torch::Device& device) {
10313
+ torch::Tensor xla_start = CopyToDevice(start, device);
10314
+ torch::Tensor xla_end = CopyToDevice(end, device);
10315
+ torch::Tensor xla_res = torch::lerp(xla_start, xla_end, weight);
10316
+ AllClose(res, xla_res);
10317
+ });
10318
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
10319
+ ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
10320
+ }
10321
+
10322
+ TEST_F(AtenXlaTensorTest, TestLerpInplace) {
10323
+ torch::Tensor input =
10324
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10325
+ torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10326
+ torch::Tensor weight =
10327
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10328
+ torch::Tensor input_copy = input.clone();
10329
+ input.lerp_(end, weight);
10330
+ ForEachDevice([&](const torch::Device& device) {
10331
+ torch::Tensor xla_input = CopyToDevice(input_copy, device);
10332
+ torch::Tensor xla_end = CopyToDevice(end, device);
10333
+ torch::Tensor xla_weight = CopyToDevice(weight, device);
10334
+ xla_input.lerp_(xla_end, xla_weight);
10335
+ AllClose(xla_input, input);
10336
+ });
10337
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
10338
+ ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
10339
+ }
10340
+
10341
+ TEST_F(AtenXlaTensorTest, TestLerpScalarInplace) {
10342
+ torch::Tensor input =
10343
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10344
+ torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10345
+ torch::Scalar weight = torch::Scalar(3.0);
10346
+ torch::Tensor input_copy = input.clone();
10347
+ input.lerp_(end, weight);
10348
+ ForEachDevice([&](const torch::Device& device) {
10349
+ torch::Tensor xla_input = CopyToDevice(input_copy, device);
10350
+ torch::Tensor xla_end = CopyToDevice(end, device);
10351
+ xla_input.lerp_(xla_end, weight);
10352
+ AllClose(xla_input, input);
10353
+ });
10354
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
10355
+ ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
10356
+ }
10357
+
10358
+ TEST_F(AtenXlaTensorTest, TestLerpOut) {
10359
+ torch::Tensor start =
10360
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10361
+ torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10362
+ torch::Tensor weight =
10363
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10364
+ torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
10365
+ ;
10366
+ torch::lerp_out(res, start, end, weight);
10367
+ ForEachDevice([&](const torch::Device& device) {
10368
+ torch::Tensor xla_start = CopyToDevice(start, device);
10369
+ torch::Tensor xla_end = CopyToDevice(end, device);
10370
+ torch::Tensor xla_weight = CopyToDevice(weight, device);
10371
+ torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
10372
+ torch::lerp_out(xla_res, xla_start, xla_end, xla_weight);
10373
+ AllClose(res, xla_res);
10374
+ });
10375
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
10376
+ ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
10377
+ }
10378
+
10379
+ TEST_F(AtenXlaTensorTest, TestLerpScalarOut) {
10380
+ torch::Tensor start =
10381
+ torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10382
+ torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
10383
+ torch::Scalar weight = torch::Scalar(3.0);
10384
+ torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
10385
+ torch::lerp_out(res, start, end, weight);
10386
+ ForEachDevice([&](const torch::Device& device) {
10387
+ torch::Tensor xla_start = CopyToDevice(start, device);
10388
+ torch::Tensor xla_end = CopyToDevice(end, device);
10389
+ torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
10390
+ torch::lerp_out(xla_res, xla_start, xla_end, weight);
10391
+ AllClose(res, xla_res);
10392
+ });
10393
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
10394
+ ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
10395
+ }
10396
+
10288
10397
} // namespace cpp_test
10289
10398
} // namespace torch_xla
0 commit comments