Skip to content

Commit f0fefaa

Browse files
committed
fix(): trying to resolve interpolate plugin problems
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 58dbaef commit f0fefaa

File tree

5 files changed

+319
-206
lines changed

5 files changed

+319
-206
lines changed

Diff for: core/conversion/converters/BUILD

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ config_setting(
1010
cc_library(
1111
name = "converters",
1212
hdrs = [
13-
"converters.h",
13+
"converters.h"
1414
],
1515
srcs = [
1616
"NodeConverterRegistry.cpp",

Diff for: core/conversion/converters/impl/interpolate.cpp

+24-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "torch/torch.h"
22
#include "core/util/prelude.h"
33
#include "core/conversion/converters/converters.h"
4+
#include "NvInfer.h"
5+
#include "plugins/interpolate_plugin.h"
46

57
#include <csignal>
68

@@ -108,7 +110,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
108110
auto in = args[0].ITensor();
109111
auto in_shape = util::toVec(in->getDimensions());
110112

111-
bool align_corners = args[2].IValue()->to<bool>();
113+
bool align_corners = args[2].unwrapToBool();
112114

113115
// Case 1: user uses output size and not scales
114116
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone()) {
@@ -119,16 +121,29 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
119121
auto out_shape = in_shape;
120122
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
121123

122-
auto resize_layer = ctx->net->addResize(*in);
123-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
124+
if (!align_corners) {
125+
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
126+
//auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
127+
auto creator = new plugins::InterpolatePluginCreator();
124128

125-
resize_layer->setOutputDimensions(util::toDims(out_shape));
126-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
127-
resize_layer->setAlignCorners(align_corners);
128-
resize_layer->setName(util::node_info(n).c_str());
129+
auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
129130

130-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
131-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
131+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(in), 1, *plugin);
132+
133+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
134+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
135+
} else {
136+
auto resize_layer = ctx->net->addResize(*in);
137+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
138+
139+
resize_layer->setOutputDimensions(util::toDims(out_shape));
140+
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
141+
resize_layer->setAlignCorners(align_corners);
142+
resize_layer->setName(util::node_info(n).c_str());
143+
144+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
145+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
146+
}
132147
} else {
133148
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
134149
}

Diff for: core/conversion/converters/impl/plugins/BUILD

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ config_setting(
99

1010
cc_library(
1111
name = "plugins",
12-
hdrs = [],
12+
hdrs = [
13+
"interpolate_plugin.h"
14+
],
1315
srcs = [
1416
"interpolate_plugin.cpp"
1517
],
@@ -29,5 +31,5 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
2931
pkg_tar(
3032
name = "include",
3133
package_dir = "core/conversion/converters/impl/plugins",
32-
srcs = [],
34+
srcs = ["interpolate_plugin.h"],
3335
)

0 commit comments

Comments
 (0)