partially codegen adaptive_avgpool3d and backward #3790

Merged · 2 commits · Aug 8, 2022

27 changes: 23 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -21,6 +21,7 @@
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/generated/LazyIr.h"
#include "torch_xla/csrc/generated/XLANativeFunctions.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/as_strided.h"
@@ -334,8 +335,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d(
&xla_cpu_fallback, ATEN_OP(_adaptive_avg_pool3d)>::call(self,
output_size);
}
-  return bridge::AtenFromXlaTensor(XLATensor::adaptive_avg_pool3d(
-      bridge::GetXlaTensor(self), output_size_list));
+  auto common_device = torch_xla::bridge::GetXlaDevice(self);
+  XLA_CHECK(common_device);
+  auto shapes =
+      torch::lazy::compute_shape__adaptive_avg_pool3d(self, output_size);
+  XLA_CHECK(shapes.size() == 1);
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3d>(
+      bridge::GetXlaTensor(self)->GetIrValue(),
+      std::vector<int64_t>(output_size.begin(), output_size.end()),
+      std::move(shapes));
+  return torch_xla::bridge::AtenFromXlaTensor(
+      torch_xla::XLATensor::Create(std::move(node), *common_device));
}

@@ -351,8 +361,17 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward(
&xla_cpu_fallback,
ATEN_OP(_adaptive_avg_pool3d_backward)>::call(grad_output, self);
}
-  return bridge::AtenFromXlaTensor(XLATensor::adaptive_avg_pool3d_backward(
-      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
+  auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self);
+  XLA_CHECK(common_device);
+  auto shapes = torch::lazy::compute_shape__adaptive_avg_pool3d_backward(
+      grad_output, self);
+  XLA_CHECK(shapes.size() == 1);
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveAvgPool3dBackward>(
+      bridge::GetXlaTensor(grad_output)->GetIrValue(),
+      bridge::GetXlaTensor(self)->GetIrValue(), std::move(shapes));
+
+  return torch_xla::bridge::AtenFromXlaTensor(
+      torch_xla::XLATensor::Create(std::move(node), *common_device));
}

at::Tensor XLANativeFunctions::_adaptive_avg_pool2d(
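The MakeNode<AdaptiveAvgPool3d> and MakeNode<AdaptiveAvgPool3dBackward> calls above construct node classes that the codegen emits into torch_xla/csrc/generated/LazyIr.h, which is not part of this diff. A minimal sketch of the generated declaration, assuming it mirrors the hand-written node this PR deletes and the analogous 2d codegen path:

// Sketch only: the real declaration is produced by the codegen and is not
// shown in this diff.
class AdaptiveAvgPool3d : public XlaNode {
 public:
  AdaptiveAvgPool3d(const torch::lazy::Value& self,
                    const std::vector<int64_t>& output_size,
                    std::vector<torch::lazy::Shape>&& shapes);

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  // Captured attribute; read by the hand-written Lower() in ops_lower_fn.cpp.
  std::vector<int64_t> output_size;
};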
48 changes: 0 additions & 48 deletions torch_xla/csrc/ops/adaptive_avg_pool3d.cpp

This file was deleted.

28 changes: 0 additions & 28 deletions torch_xla/csrc/ops/adaptive_avg_pool3d.h

This file was deleted.

26 changes: 0 additions & 26 deletions torch_xla/csrc/ops/ops.cpp
@@ -382,32 +382,6 @@ torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
std::move(lower_fn));
}

-torch::lazy::NodePtr AdaptiveAvgPool3dBackward(
-    const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
-  auto lower_fn = [](const XlaNode& node,
-                     LoweringContext* loctx) -> XlaOpVector {
-    xla::XlaOp grad_output = loctx->GetOutputOp(node.operand(0));
-    xla::XlaOp input = loctx->GetOutputOp(node.operand(1));
-    xla::XlaOp xla_output = BuildAdaptiveAvgPool3dBackward(
-        /*out_backprop=*/grad_output, /*input=*/input);
-    return node.ReturnOp(xla_output, loctx);
-  };
-  auto lower_for_shape_fn =
-      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    XLA_CHECK_EQ(operands.size(), 2);
-    return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0],
-                                          /*input=*/operands[1]);
-  };
-  return GenericOp(torch::lazy::OpKind(at::aten::adaptive_avg_pool3d_backward),
-                   {grad_output, input},
-                   [&]() {
-                     return InferOutputShape(
-                         {GetXlaShape(grad_output), GetXlaShape(input)},
-                         lower_for_shape_fn);
-                   },
-                   std::move(lower_fn));
-}
-
torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
const torch::lazy::Value& input,
const torch::lazy::Value& other) {
3 changes: 0 additions & 3 deletions torch_xla/csrc/ops/ops.h
@@ -160,9 +160,6 @@ torch::lazy::NodePtr MatMul(const torch::lazy::Value& lhs,
torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);

-torch::lazy::NodePtr AdaptiveAvgPool3dBackward(
-    const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
-
torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
const torch::lazy::Value& input,
const torch::lazy::Value& other);
14 changes: 14 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -38,6 +38,20 @@ torch_xla::XlaOpVector AdaptiveAvgPool2dBackward::Lower(
loctx);
}

+torch_xla::XlaOpVector AdaptiveAvgPool3d::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildAdaptiveAvgPool3d(input, output_size), loctx);
+}
+
+torch_xla::XlaOpVector AdaptiveAvgPool3dBackward::Lower(
+    LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_output = BuildAdaptiveAvgPool3dBackward(
+      /*out_backprop=*/grad_output, /*input=*/input);
+  return ReturnOp(xla_output, loctx);
+}
+
torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::Asin(xla_input), loctx);
22 changes: 22 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -42,6 +42,28 @@ xla::Shape AdaptiveAvgPool2dBackwardOutputShape(
lower_for_shape_fn);
}

+xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
+                                        absl::Span<const int64_t> output_size) {
+  auto lower_for_shape_fn =
+      [output_size](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 1);
+    return BuildAdaptiveAvgPool3d(operands[0], output_size);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
+xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
+    const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 2);
+    return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0],
+                                          /*input=*/operands[1]);
+  };
+  return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
+                          lower_for_shape_fn);
+}
+
xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}
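Nothing in this diff calls these helpers directly; they are referenced from the generated node constructors in LazyIr.cpp. A rough sketch of that wiring, assuming the generated constructor follows the same pattern as the other ir_gen ops (the generated file is not shown here):

// Sketch only: the actual constructor is emitted by the codegen.
AdaptiveAvgPool3d::AdaptiveAvgPool3d(const torch::lazy::Value& self,
                                     const std::vector<int64_t>& output_size,
                                     std::vector<torch::lazy::Shape>&& shapes)
    : XlaNode(torch::lazy::OpKind(at::aten::_adaptive_avg_pool3d), {self},
              std::move(shapes),
              // Assumed to defer XLA shape inference to the hand-written
              // helper added above.
              [&]() { return AdaptiveAvgPool3dOutputShape(self, output_size); },
              /*num_outputs=*/1, torch::lazy::MHash(output_size)),
      output_size(output_size) {}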
6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -15,6 +15,12 @@ xla::Shape AdaptiveAvgPool2dOutputShape(const torch::lazy::Value& input,
xla::Shape AdaptiveAvgPool2dBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input);

+xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
+                                        absl::Span<const int64_t> output_size);
+
+xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
+    const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
+
xla::Shape AsinOutputShape(const torch::lazy::Value& input);

xla::Shape AsinhOutputShape(const torch::lazy::Value& input);
6 changes: 0 additions & 6 deletions torch_xla/csrc/tensor.h
@@ -328,12 +328,6 @@ class XLATensor : public c10::intrusive_ptr_target {
static XLATensorPtr adaptive_max_pool2d_backward(
const XLATensorPtr& grad_output, const XLATensorPtr& input);

-  static XLATensorPtr adaptive_avg_pool3d(const XLATensorPtr& input,
-                                          std::vector<int64_t> output_size);
-
-  static XLATensorPtr adaptive_avg_pool3d_backward(
-      const XLATensorPtr& grad_output, const XLATensorPtr& input);
-
static XLATensorPtr _adaptive_avg_pool2d(
const XLATensorPtr& input, std::vector<int64_t> output_size,
std::vector<torch::lazy::Shape>&& shapes);
13 changes: 0 additions & 13 deletions torch_xla/csrc/tensor_methods.cpp
@@ -21,7 +21,6 @@
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/adam_optimizer_step.h"
#include "torch_xla/csrc/ops/adaptive_avg_pool3d.h"
#include "torch_xla/csrc/ops/adaptive_max_pool2d.h"
#include "torch_xla/csrc/ops/all.h"
#include "torch_xla/csrc/ops/all_gather.h"
@@ -605,18 +604,6 @@ XLATensorPtr XLATensor::adaptive_max_pool2d_backward(
input->GetIrValue()));
}

-XLATensorPtr XLATensor::adaptive_avg_pool3d(const XLATensorPtr& input,
-                                            std::vector<int64_t> output_size) {
-  return input->CreateFrom(torch::lazy::MakeNode<AdaptiveAvgPool3d>(
-      input->GetIrValue(), std::move(output_size)));
-}
-
-XLATensorPtr XLATensor::adaptive_avg_pool3d_backward(
-    const XLATensorPtr& grad_output, const XLATensorPtr& input) {
-  return input->CreateFrom(AdaptiveAvgPool3dBackward(grad_output->GetIrValue(),
-                                                     input->GetIrValue()));
-}
-
XLATensorPtr XLATensor::_adaptive_avg_pool2d(
const XLATensorPtr& input, std::vector<int64_t> output_size,
std::vector<torch::lazy::Shape>&& shapes) {
2 changes: 2 additions & 0 deletions xla_native_functions.yaml
@@ -49,6 +49,8 @@ full_codegen:
ir_gen:
  - _adaptive_avg_pool2d
  - _adaptive_avg_pool2d_backward
+  - _adaptive_avg_pool3d
+  - _adaptive_avg_pool3d_backward
supported:
  - __ilshift__.Scalar
  - __ilshift__.Tensor
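Listing the two ops under ir_gen rather than full_codegen is presumably what the "partially" in the PR title refers to: only the IR node classes are generated, while the XLANativeFunctions overrides, Lower methods, and shape functions above stay hand-written. With the yaml entries in place, the new paths can be exercised end to end from the ATen surface; a sketch of such a check, modeled on the existing C++ tests (AtenXlaTensorTest, ForEachDevice, CopyToDevice, and AllClose are assumed to match the helpers in test/cpp):

// Sketch of a forward check: on an XLA device, adaptive_avg_pool3d should now
// route through the MakeNode<AdaptiveAvgPool3d> path added in aten_xla_type.cpp.
TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool3d) {
  torch::Tensor input =
      torch::rand({2, 3, 8, 8, 8}, torch::TensorOptions(torch::kFloat));
  torch::Tensor output = torch::adaptive_avg_pool3d(input, {4, 4, 4});
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_input = CopyToDevice(input, device);
    torch::Tensor xla_output = torch::adaptive_avg_pool3d(xla_input, {4, 4, 4});
    AllClose(output, xla_output);
  });
}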