21
21
#include " torch_xla/csrc/aten_xla_bridge.h"
22
22
#include " torch_xla/csrc/debug_util.h"
23
23
#include " torch_xla/csrc/device.h"
24
+ #include " torch_xla/csrc/generated/LazyIr.h"
24
25
#include " torch_xla/csrc/generated/XLANativeFunctions.h"
25
26
#include " torch_xla/csrc/helpers.h"
26
27
#include " torch_xla/csrc/ops/as_strided.h"
@@ -334,8 +335,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d(
334
335
&xla_cpu_fallback, ATEN_OP (_adaptive_avg_pool3d)>::call (self,
335
336
output_size);
336
337
}
337
- return bridge::AtenFromXlaTensor (XLATensor::adaptive_avg_pool3d (
338
- bridge::GetXlaTensor (self), output_size_list));
338
+ auto common_device = torch_xla::bridge::GetXlaDevice (self);
339
+ XLA_CHECK (common_device);
340
+ auto shapes =
341
+ torch::lazy::compute_shape__adaptive_avg_pool3d (self, output_size);
342
+ XLA_CHECK (shapes.size () == 1 );
343
+ torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3d>(
344
+ bridge::GetXlaTensor (self)->GetIrValue (),
345
+ std::vector<int64_t >(output_size.begin (), output_size.end ()),
346
+ std::move (shapes));
347
+ return torch_xla::bridge::AtenFromXlaTensor (
348
+ torch_xla::XLATensor::Create (std::move (node), *common_device));
339
349
}
340
350
341
351
at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward (
@@ -351,8 +361,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
351
361
&xla_cpu_fallback,
352
362
ATEN_OP (_adaptive_avg_pool3d_backward)>::call (grad_output, self);
353
363
}
354
- return bridge::AtenFromXlaTensor (XLATensor::adaptive_avg_pool3d_backward (
355
- bridge::GetXlaTensor (grad_output), bridge::GetXlaTensor (self)));
364
+ auto common_device = torch_xla::bridge::GetXlaDevice (grad_output, self);
365
+ XLA_CHECK (common_device);
366
+ auto shapes = torch::lazy::compute_shape__adaptive_avg_pool3d_backward (
367
+ grad_output, self);
368
+ XLA_CHECK (shapes.size () == 1 );
369
+ torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3dBackward>(
370
+ bridge::GetXlaTensor (grad_output)->GetIrValue (),
371
+ bridge::GetXlaTensor (self)->GetIrValue (), std::move (shapes));
372
+
373
+ return torch_xla::bridge::AtenFromXlaTensor (
374
+ torch_xla::XLATensor::Create (std::move (node), *common_device));
356
375
}
357
376
358
377
at::Tensor XLANativeFunctions::_adaptive_avg_pool2d (
0 commit comments