21
21
#include " torch_xla/csrc/aten_xla_bridge.h"
22
22
#include " torch_xla/csrc/debug_util.h"
23
23
#include " torch_xla/csrc/device.h"
24
+ #include " torch_xla/csrc/generated/LazyIr.h"
24
25
#include " torch_xla/csrc/generated/XLANativeFunctions.h"
25
26
#include " torch_xla/csrc/helpers.h"
26
27
#include " torch_xla/csrc/ops/as_strided.h"
@@ -330,8 +331,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d(
330
331
&xla_cpu_fallback, ATEN_OP (_adaptive_avg_pool3d)>::call (self,
331
332
output_size);
332
333
}
333
- return bridge::AtenFromXlaTensor (XLATensor::adaptive_avg_pool3d (
334
- bridge::GetXlaTensor (self), output_size_list));
334
+ auto common_device = torch_xla::bridge::GetXlaDevice (self);
335
+ XLA_CHECK (common_device);
336
+ auto shapes =
337
+ torch::lazy::compute_shape__adaptive_avg_pool3d (self, output_size);
338
+ XLA_CHECK (shapes.size () == 1 );
339
+ torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3d>(
340
+ bridge::GetXlaTensor (self)->GetIrValue (),
341
+ std::vector<int64_t >(output_size.begin (), output_size.end ()),
342
+ std::move (shapes));
343
+ return torch_xla::bridge::AtenFromXlaTensor (
344
+ torch_xla::XLATensor::Create (std::move (node), *common_device));
335
345
}
336
346
337
347
at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward (
@@ -347,8 +357,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
347
357
&xla_cpu_fallback,
348
358
ATEN_OP (_adaptive_avg_pool3d_backward)>::call (grad_output, self);
349
359
}
350
- return bridge::AtenFromXlaTensor (XLATensor::adaptive_avg_pool3d_backward (
351
- bridge::GetXlaTensor (grad_output), bridge::GetXlaTensor (self)));
360
+ auto common_device = torch_xla::bridge::GetXlaDevice (grad_output, self);
361
+ XLA_CHECK (common_device);
362
+ auto shapes = torch::lazy::compute_shape__adaptive_avg_pool3d_backward (
363
+ grad_output, self);
364
+ XLA_CHECK (shapes.size () == 1 );
365
+ torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3dBackward>(
366
+ bridge::GetXlaTensor (grad_output)->GetIrValue (),
367
+ bridge::GetXlaTensor (self)->GetIrValue (), std::move (shapes));
368
+
369
+ return torch_xla::bridge::AtenFromXlaTensor (
370
+ torch_xla::XLATensor::Create (std::move (node), *common_device));
352
371
}
353
372
354
373
at::Tensor XLANativeFunctions::_adaptive_avg_pool2d (
0 commit comments