Commit 5d92dc7

partially codegen adaptive_avg_pool3d and backward

1 parent 63a85d9

12 files changed: +68 -128 lines changed

torch_patches/.torch_pin (+1)

@@ -0,0 +1 @@
+#82297

torch_xla/csrc/aten_xla_type.cpp (+23 -4)

@@ -21,6 +21,7 @@
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/debug_util.h"
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/generated/LazyIr.h"
 #include "torch_xla/csrc/generated/XLANativeFunctions.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/as_strided.h"
@@ -334,8 +335,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d(
         &xla_cpu_fallback, ATEN_OP(_adaptive_avg_pool3d)>::call(self,
                                                                 output_size);
   }
-  return bridge::AtenFromXlaTensor(XLATensor::adaptive_avg_pool3d(
-      bridge::GetXlaTensor(self), output_size_list));
+  auto common_device = torch_xla::bridge::GetXlaDevice(self);
+  XLA_CHECK(common_device);
+  auto shapes =
+      torch::lazy::compute_shape__adaptive_avg_pool3d(self, output_size);
+  XLA_CHECK(shapes.size() == 1);
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3d>(
+      bridge::GetXlaTensor(self)->GetIrValue(),
+      std::vector<int64_t>(output_size.begin(), output_size.end()),
+      std::move(shapes));
+  return torch_xla::bridge::AtenFromXlaTensor(
+      torch_xla::XLATensor::Create(std::move(node), *common_device));
 }
 
 at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
@@ -351,8 +361,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
         &xla_cpu_fallback,
         ATEN_OP(_adaptive_avg_pool3d_backward)>::call(grad_output, self);
   }
-  return bridge::AtenFromXlaTensor(XLATensor::adaptive_avg_pool3d_backward(
-      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
+  auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self);
+  XLA_CHECK(common_device);
+  auto shapes = torch::lazy::compute_shape__adaptive_avg_pool3d_backward(
+      grad_output, self);
+  XLA_CHECK(shapes.size() == 1);
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3dBackward>(
+      bridge::GetXlaTensor(grad_output)->GetIrValue(),
+      bridge::GetXlaTensor(self)->GetIrValue(), std::move(shapes));
+
+  return torch_xla::bridge::AtenFromXlaTensor(
+      torch_xla::XLATensor::Create(std::move(node), *common_device));
 }
 
 at::Tensor XLANativeFunctions::_adaptive_avg_pool2d(
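
The MakeNode<AdaptiveAvgPool3d> call above constructs a node class that the lazy-tensor codegen now emits into generated/LazyIr.h (hence the new include), replacing the hand-written node deleted below. A rough sketch of what that generated declaration plausibly looks like, inferred from how this commit constructs and lowers the node; the actual generated code may differ:

// Sketch only: the real class is produced by gen_lazy_tensor from
// xla_native_functions.yaml; base-class and member details are assumptions.
class AdaptiveAvgPool3d : public XlaNode {
 public:
  // Mirrors the MakeNode<AdaptiveAvgPool3d>(...) call in aten_xla_type.cpp.
  AdaptiveAvgPool3d(const torch::lazy::Value& self,
                    const std::vector<int64_t>& output_size,
                    std::vector<torch::lazy::Shape>&& shapes);

  std::string ToString() const override;

  // Hand-written in ops_lower_fn.cpp below; the xla::Shape side comes from
  // AdaptiveAvgPool3dOutputShape in ops_xla_shape_fn.cpp.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  std::vector<int64_t> output_size;
};

The codegen supplies the boilerplate (constructor, hashing, ToString), while Lower and the shape function stay hand-written, which is what makes this a partial codegen.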

torch_xla/csrc/ops/adaptive_avg_pool3d.cpp (-48)

This file was deleted.

torch_xla/csrc/ops/adaptive_avg_pool3d.h (-28)

This file was deleted.

torch_xla/csrc/ops/ops.cpp (-26)

@@ -383,32 +383,6 @@ torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
       std::move(lower_fn));
 }
 
-torch::lazy::NodePtr AdaptiveAvgPool3dBackward(
-    const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
-  auto lower_fn = [](const XlaNode& node,
-                     LoweringContext* loctx) -> XlaOpVector {
-    xla::XlaOp grad_output = loctx->GetOutputOp(node.operand(0));
-    xla::XlaOp input = loctx->GetOutputOp(node.operand(1));
-    xla::XlaOp xla_output = BuildAdaptiveAvgPool3dBackward(
-        /*out_backprop=*/grad_output, /*input=*/input);
-    return node.ReturnOp(xla_output, loctx);
-  };
-  auto lower_for_shape_fn =
-      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    XLA_CHECK_EQ(operands.size(), 2);
-    return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0],
-                                          /*input=*/operands[1]);
-  };
-  return GenericOp(torch::lazy::OpKind(at::aten::adaptive_avg_pool3d_backward),
-                   {grad_output, input},
-                   [&]() {
-                     return InferOutputShape(
-                         {GetXlaShape(grad_output), GetXlaShape(input)},
-                         lower_for_shape_fn);
-                   },
-                   std::move(lower_fn));
-}
-
 torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
                                   const torch::lazy::Value& input,
                                   const torch::lazy::Value& other) {

torch_xla/csrc/ops/ops.h (-3)

@@ -160,9 +160,6 @@ torch::lazy::NodePtr MatMul(const torch::lazy::Value& lhs,
 torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
     const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
 
-torch::lazy::NodePtr AdaptiveAvgPool3dBackward(
-    const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
-
 torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
                                   const torch::lazy::Value& input,
                                   const torch::lazy::Value& other);

torch_xla/csrc/ops/ops_lower_fn.cpp (+14)

@@ -38,6 +38,20 @@ torch_xla::XlaOpVector AdaptiveAvgPool2dBackward::Lower(
                   loctx);
 }
 
+torch_xla::XlaOpVector AdaptiveAvgPool3d::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildAdaptiveAvgPool3d(input, output_size), loctx);
+}
+
+torch_xla::XlaOpVector AdaptiveAvgPool3dBackward::Lower(
+    LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_output = BuildAdaptiveAvgPool3dBackward(
+      /*out_backprop=*/grad_output, /*input=*/input);
+  return ReturnOp(xla_output, loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp (+22)

@@ -42,6 +42,28 @@ xla::Shape AdaptiveAvgPool2dBackwardOutputShape(
                           lower_for_shape_fn);
 }
 
+xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
+                                        absl::Span<const int64_t> output_size) {
+  auto lower_for_shape_fn =
+      [output_size](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 1);
+    return BuildAdaptiveAvgPool3d(operands[0], output_size);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
+xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
+    const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 2);
+    return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0],
+                                          /*input=*/operands[1]);
+  };
+  return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
+                          lower_for_shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

torch_xla/csrc/ops/ops_xla_shape_fn.h (+6)

@@ -15,6 +15,12 @@ xla::Shape AdaptiveAvgPool2dOutputShape(const torch::lazy::Value& input,
 xla::Shape AdaptiveAvgPool2dBackwardOutputShape(
     const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
 
+xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
+                                        absl::Span<const int64_t> output_size);
+
+xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
+    const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor.h (-6)

@@ -328,12 +328,6 @@ class XLATensor : public c10::intrusive_ptr_target {
   static XLATensorPtr adaptive_max_pool2d_backward(
       const XLATensorPtr& grad_output, const XLATensorPtr& input);
 
-  static XLATensorPtr adaptive_avg_pool3d(const XLATensorPtr& input,
-                                          std::vector<int64_t> output_size);
-
-  static XLATensorPtr adaptive_avg_pool3d_backward(
-      const XLATensorPtr& grad_output, const XLATensorPtr& input);
-
   static XLATensorPtr _adaptive_avg_pool2d(
       const XLATensorPtr& input, std::vector<int64_t> output_size,
       std::vector<torch::lazy::Shape>&& shapes);

torch_xla/csrc/tensor_methods.cpp (-13)

@@ -21,7 +21,6 @@
 #include "torch_xla/csrc/layout_manager.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/ops/adam_optimizer_step.h"
-#include "torch_xla/csrc/ops/adaptive_avg_pool3d.h"
 #include "torch_xla/csrc/ops/adaptive_max_pool2d.h"
 #include "torch_xla/csrc/ops/all.h"
 #include "torch_xla/csrc/ops/all_gather.h"
@@ -605,18 +604,6 @@ XLATensorPtr XLATensor::adaptive_max_pool2d_backward(
                                              input->GetIrValue()));
 }
 
-XLATensorPtr XLATensor::adaptive_avg_pool3d(const XLATensorPtr& input,
-                                            std::vector<int64_t> output_size) {
-  return input->CreateFrom(torch::lazy::MakeNode<AdaptiveAvgPool3d>(
-      input->GetIrValue(), std::move(output_size)));
-}
-
-XLATensorPtr XLATensor::adaptive_avg_pool3d_backward(
-    const XLATensorPtr& grad_output, const XLATensorPtr& input) {
-  return input->CreateFrom(AdaptiveAvgPool3dBackward(grad_output->GetIrValue(),
-                                                     input->GetIrValue()));
-}
-
 XLATensorPtr XLATensor::_adaptive_avg_pool2d(
     const XLATensorPtr& input, std::vector<int64_t> output_size,
     std::vector<torch::lazy::Shape>&& shapes) {

xla_native_functions.yaml (+2)

@@ -48,6 +48,8 @@ full_codegen:
 ir_gen:
   - _adaptive_avg_pool2d
   - _adaptive_avg_pool2d_backward
+  - _adaptive_avg_pool3d
+  - _adaptive_avg_pool3d_backward
 supported:
   - __ilshift__.Scalar
   - __ilshift__.Tensor
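
Listing the two ops under ir_gen tells the codegen to generate their IR node classes while the Lower and shape functions above remain hand-written. A minimal smoke check for the new path, sketched under the assumption of a working torch_xla build with an XLA device available; the helper below is illustrative and not part of this commit:

#include <torch/torch.h>

// Illustrative only: compares the XLA result against CPU.
// adaptive_avg_pool3d on an XLA tensor should route through
// XLANativeFunctions::_adaptive_avg_pool3d above.
void SmokeTestAdaptiveAvgPool3d() {
  torch::Tensor input = torch::rand({2, 3, 8, 8, 8});
  torch::Tensor cpu_out = torch::adaptive_avg_pool3d(input, {4, 4, 4});
  torch::Device xla_device(torch::kXLA, 0);
  torch::Tensor xla_out =
      torch::adaptive_avg_pool3d(input.to(xla_device), {4, 4, 4});
  TORCH_CHECK(torch::allclose(cpu_out, xla_out.to(torch::kCPU)));
}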
