Skip to content

Commit 1f154ce

Browse files
authored
partially codegen adaptive_avgpool3d and backward (#3790)
* partially codegen adaptive_avg_pool3d and backward * Delete .torch_pin
1 parent 477ca24 commit 1f154ce

11 files changed

+67
-128
lines changed

torch_xla/csrc/aten_xla_type.cpp

+23-4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "torch_xla/csrc/aten_xla_bridge.h"
2222
#include "torch_xla/csrc/debug_util.h"
2323
#include "torch_xla/csrc/device.h"
24+
#include "torch_xla/csrc/generated/LazyIr.h"
2425
#include "torch_xla/csrc/generated/XLANativeFunctions.h"
2526
#include "torch_xla/csrc/helpers.h"
2627
#include "torch_xla/csrc/ops/as_strided.h"
@@ -330,8 +331,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d(
330331
&xla_cpu_fallback, ATEN_OP(_adaptive_avg_pool3d)>::call(self,
331332
output_size);
332333
}
333-
return bridge::AtenFromXlaTensor(XLATensor::adaptive_avg_pool3d(
334-
bridge::GetXlaTensor(self), output_size_list));
334+
auto common_device = torch_xla::bridge::GetXlaDevice(self);
335+
XLA_CHECK(common_device);
336+
auto shapes =
337+
torch::lazy::compute_shape__adaptive_avg_pool3d(self, output_size);
338+
XLA_CHECK(shapes.size() == 1);
339+
torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3d>(
340+
bridge::GetXlaTensor(self)->GetIrValue(),
341+
std::vector<int64_t>(output_size.begin(), output_size.end()),
342+
std::move(shapes));
343+
return torch_xla::bridge::AtenFromXlaTensor(
344+
torch_xla::XLATensor::Create(std::move(node), *common_device));
335345
}
336346

337347
at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
@@ -347,8 +357,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
347357
&xla_cpu_fallback,
348358
ATEN_OP(_adaptive_avg_pool3d_backward)>::call(grad_output, self);
349359
}
350-
return bridge::AtenFromXlaTensor(XLATensor::adaptive_avg_pool3d_backward(
351-
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
360+
auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self);
361+
XLA_CHECK(common_device);
362+
auto shapes = torch::lazy::compute_shape__adaptive_avg_pool3d_backward(
363+
grad_output, self);
364+
XLA_CHECK(shapes.size() == 1);
365+
torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3dBackward>(
366+
bridge::GetXlaTensor(grad_output)->GetIrValue(),
367+
bridge::GetXlaTensor(self)->GetIrValue(), std::move(shapes));
368+
369+
return torch_xla::bridge::AtenFromXlaTensor(
370+
torch_xla::XLATensor::Create(std::move(node), *common_device));
352371
}
353372

354373
at::Tensor XLANativeFunctions::_adaptive_avg_pool2d(

torch_xla/csrc/ops/adaptive_avg_pool3d.cpp

-48
This file was deleted.

torch_xla/csrc/ops/adaptive_avg_pool3d.h

-28
This file was deleted.

torch_xla/csrc/ops/ops.cpp

-26
Original file line numberDiff line numberDiff line change
@@ -355,32 +355,6 @@ torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
355355
std::move(lower_fn));
356356
}
357357

358-
torch::lazy::NodePtr AdaptiveAvgPool3dBackward(
359-
const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
360-
auto lower_fn = [](const XlaNode& node,
361-
LoweringContext* loctx) -> XlaOpVector {
362-
xla::XlaOp grad_output = loctx->GetOutputOp(node.operand(0));
363-
xla::XlaOp input = loctx->GetOutputOp(node.operand(1));
364-
xla::XlaOp xla_output = BuildAdaptiveAvgPool3dBackward(
365-
/*out_backprop=*/grad_output, /*input=*/input);
366-
return node.ReturnOp(xla_output, loctx);
367-
};
368-
auto lower_for_shape_fn =
369-
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
370-
XLA_CHECK_EQ(operands.size(), 2);
371-
return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0],
372-
/*input=*/operands[1]);
373-
};
374-
return GenericOp(torch::lazy::OpKind(at::aten::adaptive_avg_pool3d_backward),
375-
{grad_output, input},
376-
[&]() {
377-
return InferOutputShape(
378-
{GetXlaShape(grad_output), GetXlaShape(input)},
379-
lower_for_shape_fn);
380-
},
381-
std::move(lower_fn));
382-
}
383-
384358
torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
385359
const torch::lazy::Value& input,
386360
const torch::lazy::Value& other) {

torch_xla/csrc/ops/ops.h

-3
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,6 @@ torch::lazy::NodePtr MatMul(const torch::lazy::Value& lhs,
156156
torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
157157
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
158158

159-
torch::lazy::NodePtr AdaptiveAvgPool3dBackward(
160-
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
161-
162159
torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
163160
const torch::lazy::Value& input,
164161
const torch::lazy::Value& other);

torch_xla/csrc/ops/ops_lower_fn.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ torch_xla::XlaOpVector AdaptiveAvgPool2dBackward::Lower(
3939
loctx);
4040
}
4141

42+
torch_xla::XlaOpVector AdaptiveAvgPool3d::Lower(LoweringContext* loctx) const {
43+
xla::XlaOp input = loctx->GetOutputOp(operand(0));
44+
return ReturnOp(BuildAdaptiveAvgPool3d(input, output_size), loctx);
45+
}
46+
47+
torch_xla::XlaOpVector AdaptiveAvgPool3dBackward::Lower(
48+
LoweringContext* loctx) const {
49+
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
50+
xla::XlaOp input = loctx->GetOutputOp(operand(1));
51+
xla::XlaOp xla_output = BuildAdaptiveAvgPool3dBackward(
52+
/*out_backprop=*/grad_output, /*input=*/input);
53+
return ReturnOp(xla_output, loctx);
54+
}
55+
4256
torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
4357
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
4458
return ReturnOp(xla::Asin(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,28 @@ xla::Shape AdaptiveAvgPool2dBackwardOutputShape(
5858
lower_for_shape_fn);
5959
}
6060

61+
xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
62+
absl::Span<const int64_t> output_size) {
63+
auto lower_for_shape_fn =
64+
[output_size](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
65+
XLA_CHECK_EQ(operands.size(), 1);
66+
return BuildAdaptiveAvgPool3d(operands[0], output_size);
67+
};
68+
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
69+
}
70+
71+
xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
72+
const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
73+
auto lower_for_shape_fn =
74+
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
75+
XLA_CHECK_EQ(operands.size(), 2);
76+
return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0],
77+
/*input=*/operands[1]);
78+
};
79+
return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
80+
lower_for_shape_fn);
81+
}
82+
6183
xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
6284
return GetXlaShape(input);
6385
}

torch_xla/csrc/ops/ops_xla_shape_fn.h

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ xla::Shape AdaptiveAvgPool2dOutputShape(const torch::lazy::Value& input,
1515
xla::Shape AdaptiveAvgPool2dBackwardOutputShape(
1616
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
1717

18+
xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
19+
absl::Span<const int64_t> output_size);
20+
21+
xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
22+
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
23+
1824
xla::Shape AsinOutputShape(const torch::lazy::Value& input);
1925

2026
xla::Shape AsinhOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor.h

-6
Original file line numberDiff line numberDiff line change
@@ -328,12 +328,6 @@ class XLATensor : public c10::intrusive_ptr_target {
328328
static XLATensorPtr adaptive_max_pool2d_backward(
329329
const XLATensorPtr& grad_output, const XLATensorPtr& input);
330330

331-
static XLATensorPtr adaptive_avg_pool3d(const XLATensorPtr& input,
332-
std::vector<int64_t> output_size);
333-
334-
static XLATensorPtr adaptive_avg_pool3d_backward(
335-
const XLATensorPtr& grad_output, const XLATensorPtr& input);
336-
337331
static XLATensorPtr _adaptive_avg_pool2d(
338332
const XLATensorPtr& input, std::vector<int64_t> output_size,
339333
std::vector<torch::lazy::Shape>&& shapes);

torch_xla/csrc/tensor_methods.cpp

-13
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include "torch_xla/csrc/layout_manager.h"
2222
#include "torch_xla/csrc/lowering_context.h"
2323
#include "torch_xla/csrc/ops/adam_optimizer_step.h"
24-
#include "torch_xla/csrc/ops/adaptive_avg_pool3d.h"
2524
#include "torch_xla/csrc/ops/adaptive_max_pool2d.h"
2625
#include "torch_xla/csrc/ops/all.h"
2726
#include "torch_xla/csrc/ops/all_gather.h"
@@ -591,18 +590,6 @@ XLATensorPtr XLATensor::adaptive_max_pool2d_backward(
591590
input->GetIrValue()));
592591
}
593592

594-
XLATensorPtr XLATensor::adaptive_avg_pool3d(const XLATensorPtr& input,
595-
std::vector<int64_t> output_size) {
596-
return input->CreateFrom(torch::lazy::MakeNode<AdaptiveAvgPool3d>(
597-
input->GetIrValue(), std::move(output_size)));
598-
}
599-
600-
XLATensorPtr XLATensor::adaptive_avg_pool3d_backward(
601-
const XLATensorPtr& grad_output, const XLATensorPtr& input) {
602-
return input->CreateFrom(AdaptiveAvgPool3dBackward(grad_output->GetIrValue(),
603-
input->GetIrValue()));
604-
}
605-
606593
XLATensorPtr XLATensor::_adaptive_avg_pool2d(
607594
const XLATensorPtr& input, std::vector<int64_t> output_size,
608595
std::vector<torch::lazy::Shape>&& shapes) {

xla_native_functions.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ full_codegen:
5353
ir_gen:
5454
- _adaptive_avg_pool2d
5555
- _adaptive_avg_pool2d_backward
56+
- _adaptive_avg_pool3d
57+
- _adaptive_avg_pool3d_backward
5658
supported:
5759
- __ilshift__.Scalar
5860
- __ilshift__.Tensor

0 commit comments

Comments (0)