Skip to content

Commit 58dbaef

Browse files
committed
feat(//core/conversion/converters/impl/plugins): interpolate plugin compiles now. time to test it.
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 7c91dec commit 58dbaef

File tree

3 files changed

+165
-50
lines changed

3 files changed

+165
-50
lines changed

Diff for: core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ cc_library(
3535
"@tensorrt//:nvinfer",
3636
"//core/util:prelude",
3737
"//core/conversion/conversionctx",
38+
"//core/conversion/converters/impl/plugins"
3839
] + select({
3940
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
4041
"//conditions:default": ["@libtorch//:libtorch"],

Diff for: core/conversion/converters/impl/plugins/BUILD

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
10+
cc_library(
11+
name = "plugins",
12+
hdrs = [],
13+
srcs = [
14+
"interpolate_plugin.cpp"
15+
],
16+
deps = [
17+
"@tensorrt//:nvinfer",
18+
"//core/util:prelude",
19+
"//core/conversion/conversionctx",
20+
] + select({
21+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
22+
"//conditions:default": ["@libtorch//:libtorch"],
23+
}),
24+
alwayslink = True,
25+
)
26+
27+
load("@rules_pkg//:pkg.bzl", "pkg_tar")
28+
29+
pkg_tar(
30+
name = "include",
31+
package_dir = "core/conversion/converters/impl/plugins",
32+
srcs = [],
33+
)

Diff for: core/conversion/converters/impl/plugins/interpolate_plugin.cpp

+131-50
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
#include <torch/extension.h>
2-
#include <torch/script.h>
31
#include <string>
42
#include <iostream>
53
#include <sstream>
6-
#include <NvInfer.h>
74
#include <ATen/ATen.h>
85
#include <ATen/cuda/CUDAEvent.h>
9-
#include <torch/torch.h>
106
#include <cuda_runtime_api.h>
11-
#include "NvInferVersion.h"
127
#include <vector>
138
#include <cudnn.h>
14-
#include <NVInferRuntime.h>
15-
#include <NVInferRuntimeCommon.h>
9+
10+
#include "core/util/prelude.h"
11+
#include "torch/torch.h"
12+
#include "NvInfer.h"
13+
14+
using namespace nvinfer1;
1615

1716
namespace trtorch {
1817
namespace core {
@@ -25,24 +24,58 @@ namespace {
2524
class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
2625
private:
2726
at::TensorOptions tensor_options;
28-
std::vector<int64_t> input_sizes;
29-
std::vector<int64_t> output_sizes;
3027
DataType dtype;
3128

29+
std::vector<int64_t> in_shape;
30+
std::vector<int64_t> out_shape;
3231
std::vector<int64_t> size;
3332
std::string mode;
3433
bool align_corners;
3534

3635
public:
37-
InterpolatePlugin(const char* name, std::vector<int64_t> in_shape,
38-
std::vector<int64_t> out_shape,
39-
std::string mode,
40-
bool align_corners) : name(name), in_shape(in_shape), out_shape(out_shape), mode(mode), align_corners(align_corners) {}
41-
36+
InterpolatePlugin(std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) :
37+
in_shape(in_shape), out_shape(out_shape), size(size), mode(mode), align_corners(align_corners)
38+
{}
39+
40+
InterpolatePlugin(const char *data, size_t length) {
41+
std::istringstream data_stream(std::string(data, length));
42+
43+
torch::serialize::InputArchive input_archive;
44+
input_archive.load_from(data_stream);
45+
46+
{
47+
torch::IValue value;
48+
input_archive.read("in_shape", value);
49+
in_shape = value.toIntVector();
50+
}
51+
{
52+
torch::IValue value;
53+
input_archive.read("out_shape", value);
54+
out_shape = value.toIntVector();
55+
}
56+
{
57+
torch::IValue value;
58+
input_archive.read("size", value);
59+
size = value.toIntVector();
60+
}
61+
{
62+
torch::IValue value;
63+
input_archive.read("mode", value);
64+
mode = value.toStringRef();
65+
}
66+
{
67+
torch::IValue value;
68+
input_archive.read("align_corners", value);
69+
align_corners = value.toBool();
70+
}
71+
}
4272

73+
int getNbOutputs() const override {
74+
return 1;
75+
}
4376

4477
const char* getPluginType() const override {
45-
return "Interpolate";
78+
return "Interpolate_TRTorch";
4679
}
4780

4881
const char* getPluginVersion() const override {
@@ -60,79 +93,125 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
6093
}
6194

6295
nvinfer1::IPluginV2DynamicExt* clone() const override {
63-
auto* plugin = new InterpolatePlugin(*this);
64-
return plugin;
96+
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
6597
}
6698

67-
nvinfer::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) const override {
68-
69-
}
99+
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) override {
100+
//nvinfer1::DimsExprs output(inputs[0]);
70101

71-
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override {
102+
// output.nbDims = out_shape.size();
103+
104+
// for (int i = 0; i < out_shape.size(); i++) {
105+
// output.d[i] = exprBuilder.getConstantValue(out_shape[i]);
106+
// }
72107

108+
// return output;
109+
nvinfer1::DimsExprs empty;
110+
return empty;
73111
}
74112

75-
int getNbOutputs() const override {
76-
return 1;
113+
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override {
114+
return DataType::kFLOAT;
77115
}
78116

79117
int initialize() override {
118+
tensor_options = tensor_options.device(c10::kCUDA);
119+
tensor_options = tensor_options.dtype(c10::kFloat);
80120

121+
return 0;
81122
}
82123

83-
void terminate() override {
84-
85-
}
124+
void terminate() override {}
86125

87126
void serialize(void* buffer) const override {
127+
std::string data = serializeToString();
128+
size_t size = getSerializationSize();
88129

130+
data.copy((char *) buffer, size);
89131
}
90132

91-
void size_t getSerializationSize() const override {
133+
std::string serializeToString() const {
134+
torch::serialize::OutputArchive output_archive;
92135

93-
}
136+
output_archive.write("in_shape", torch::IValue(in_shape));
137+
output_archive.write("out_shape", torch::IValue(out_shape));
138+
output_archive.write("size", torch::IValue(size));
139+
output_archive.write("mode", torch::IValue(mode));
140+
output_archive.write("align_corners", torch::IValue(align_corners));
94141

95-
void destroy() override {
142+
std::ostringstream data_str;
143+
output_archive.save_to(data_str);
96144

145+
return data_str.str();
97146
}
98147

99-
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
100-
148+
size_t getSerializationSize() const override {
149+
return serializeToString().size();
101150
}
102151

103-
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override {
152+
void destroy() override {}
104153

105-
}
154+
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
155+
if (inOut->format != nvinfer1::TensorFormat::kLINEAR) {
156+
return false;
157+
}
106158

107-
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override {
159+
if (inOut->type == DataType::kINT32 || inOut->type == DataType::kINT8) {
160+
return false;
161+
}
108162

163+
return true;
109164
}
110165

111-
void attachToContext(nvinfer1::cudnnContext*, nvinfer1::cublasContext*, nvinfer1::IGpuAllocator*) override {}
166+
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override {
167+
dtype = DataType::kFLOAT;
168+
}
112169

113-
void detachFromContext() override {}
170+
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override {
171+
return 0;
172+
}
114173

115174
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
116175
void *const *outputs, void *workspace,
117176
cudaStream_t stream) override {
118-
119-
}
177+
at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
178+
at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);
120179

180+
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
181+
at::cuda::CUDAStreamGuard torch_guard(torch_stream);
121182

183+
cudaEvent_t event;
184+
cudaEventCreate(&event);
185+
cudaEventRecord(event, stream);
122186

187+
cudaStreamWaitEvent(torch_stream.stream(), event, 0);
123188

124-
private:
125-
std::string name;
126-
std::vector<int64_t> in_shape;
127-
std::vector<int64_t> out_shape;
128-
std::string mode;
129-
bool align_corners;
189+
if (mode == "linear") {
190+
at::upsample_linear1d_out(output, input, {size[0]}, align_corners);
191+
} else if (mode == "bilinear") {
192+
at::upsample_bilinear2d_out(output, input, {size[0], size[1]}, align_corners);
193+
} else if (mode == "trilinear") {
194+
at::upsample_trilinear3d_out(output, input, {size[0], size[1], size[2]}, align_corners);
195+
}
196+
197+
cudaEvent_t torch_event;
198+
cudaEventCreate(&torch_event);
199+
cudaEventRecord(torch_event, torch_stream.stream());
200+
201+
cudaStreamWaitEvent(stream, torch_event, 0);
202+
203+
cudaEventDestroy(event);
204+
cudaEventDestroy(torch_event);
130205

131-
nvinfer1::DataType dtype;
132-
}
206+
return 0;
207+
}
208+
};
133209

134210

135211
class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
212+
private:
213+
std::string name;
214+
136215
public:
137216
InterpolatePluginCreator() {}
138217

@@ -158,18 +237,20 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
158237
return nullptr;
159238
}
160239

161-
nvinfer1::IPluginV2* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::string mode, bool align_corners) {
162-
return new InterpolatePlugin(name, in_shape, out_shape, mode, align_corners);
240+
nvinfer1::IPluginV2* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
241+
name = name;
242+
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
163243
}
164244

165245
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override {
166-
return nullptr;
246+
name = name;
247+
return new InterpolatePlugin((const char*) serialData, serialLength);
167248
}
168249

169250
const nvinfer1::PluginFieldCollection* getFieldNames() override {
170251
return nullptr;
171252
}
172-
}
253+
};
173254

174255
REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);
175256

0 commit comments

Comments
 (0)