Skip to content

Commit 7c91dec

Browse files
committed
feat(//core/conversion/converters/impl/plugins): template for interpolate plugin
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 57143c2 commit 7c91dec

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#include <torch/extension.h>
2+
#include <torch/script.h>
3+
#include <string>
4+
#include <iostream>
5+
#include <sstream>
6+
#include <NvInfer.h>
7+
#include <ATen/ATen.h>
8+
#include <ATen/cuda/CUDAEvent.h>
9+
#include <torch/torch.h>
10+
#include <cuda_runtime_api.h>
11+
#include "NvInferVersion.h"
12+
#include <vector>
13+
#include <cudnn.h>
14+
#include <NVInferRuntime.h>
15+
#include <NVInferRuntimeCommon.h>
16+
17+
namespace trtorch {
18+
namespace core {
19+
namespace conversion {
20+
namespace converters {
21+
namespace impl {
22+
namespace plugins {
23+
namespace {
24+
25+
class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
26+
private:
27+
at::TensorOptions tensor_options;
28+
std::vector<int64_t> input_sizes;
29+
std::vector<int64_t> output_sizes;
30+
DataType dtype;
31+
32+
std::vector<int64_t> size;
33+
std::string mode;
34+
bool align_corners;
35+
36+
public:
37+
InterpolatePlugin(const char* name, std::vector<int64_t> in_shape,
38+
std::vector<int64_t> out_shape,
39+
std::string mode,
40+
bool align_corners) : name(name), in_shape(in_shape), out_shape(out_shape), mode(mode), align_corners(align_corners) {}
41+
42+
43+
44+
const char* getPluginType() const override {
45+
return "Interpolate";
46+
}
47+
48+
const char* getPluginVersion() const override {
49+
return "1";
50+
}
51+
52+
const char* getPluginNamespace() const override {
53+
return "trtorch";
54+
}
55+
56+
void setPluginNamespace(const char* pluginNamespace) {}
57+
58+
int getTensorRTVersion() const override {
59+
return NV_TENSORRT_MAJOR;
60+
}
61+
62+
nvinfer1::IPluginV2DynamicExt* clone() const override {
63+
auto* plugin = new InterpolatePlugin(*this);
64+
return plugin;
65+
}
66+
67+
nvinfer::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) const override {
68+
69+
}
70+
71+
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override {
72+
73+
}
74+
75+
int getNbOutputs() const override {
76+
return 1;
77+
}
78+
79+
int initialize() override {
80+
81+
}
82+
83+
void terminate() override {
84+
85+
}
86+
87+
void serialize(void* buffer) const override {
88+
89+
}
90+
91+
void size_t getSerializationSize() const override {
92+
93+
}
94+
95+
void destroy() override {
96+
97+
}
98+
99+
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
100+
101+
}
102+
103+
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override {
104+
105+
}
106+
107+
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override {
108+
109+
}
110+
111+
void attachToContext(nvinfer1::cudnnContext*, nvinfer1::cublasContext*, nvinfer1::IGpuAllocator*) override {}
112+
113+
void detachFromContext() override {}
114+
115+
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
116+
void *const *outputs, void *workspace,
117+
cudaStream_t stream) override {
118+
119+
}
120+
121+
122+
123+
124+
private:
125+
std::string name;
126+
std::vector<int64_t> in_shape;
127+
std::vector<int64_t> out_shape;
128+
std::string mode;
129+
bool align_corners;
130+
131+
nvinfer1::DataType dtype;
132+
}
133+
134+
135+
class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
136+
public:
137+
InterpolatePluginCreator() {}
138+
139+
int getTensorRTVersion() const override {
140+
return NV_TENSORRT_MAJOR;
141+
}
142+
143+
const char *getPluginNamespace() const override {
144+
return "trtorch";
145+
}
146+
147+
void setPluginNamespace(const char* libNamespace) override {}
148+
149+
const char *getPluginName() const override {
150+
return "interpolate";
151+
}
152+
153+
const char *getPluginVersion() const override {
154+
return "1";
155+
}
156+
157+
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override {
158+
return nullptr;
159+
}
160+
161+
nvinfer1::IPluginV2* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::string mode, bool align_corners) {
162+
return new InterpolatePlugin(name, in_shape, out_shape, mode, align_corners);
163+
}
164+
165+
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override {
166+
return nullptr;
167+
}
168+
169+
const nvinfer1::PluginFieldCollection* getFieldNames() override {
170+
return nullptr;
171+
}
172+
}
173+
174+
REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);
175+
176+
} // namespace
177+
} // namespace plugins
178+
} // namespace impl
179+
} // namespace converters
180+
} // namespace conversion
181+
} // namespace core
182+
} // namespace trtorch
183+

0 commit comments

Comments
 (0)