1
- #include < torch/extension.h>
2
- #include < torch/script.h>
3
1
#include < string>
4
2
#include < iostream>
5
3
#include < sstream>
6
- #include < NvInfer.h>
7
4
#include < ATen/ATen.h>
8
5
#include < ATen/cuda/CUDAEvent.h>
9
- #include < torch/torch.h>
10
6
#include < cuda_runtime_api.h>
11
- #include " NvInferVersion.h"
12
7
#include < vector>
13
8
#include < cudnn.h>
14
- #include < NVInferRuntime.h>
15
- #include < NVInferRuntimeCommon.h>
9
+
10
+ #include " core/util/prelude.h"
11
+ #include " torch/torch.h"
12
+ #include " NvInfer.h"
13
+
14
+ using namespace nvinfer1 ;
16
15
17
16
namespace trtorch {
18
17
namespace core {
@@ -25,24 +24,58 @@ namespace {
25
24
class InterpolatePlugin : public nvinfer1 ::IPluginV2DynamicExt {
26
25
private:
27
26
at::TensorOptions tensor_options;
28
- std::vector<int64_t > input_sizes;
29
- std::vector<int64_t > output_sizes;
30
27
DataType dtype;
31
28
29
+ std::vector<int64_t > in_shape;
30
+ std::vector<int64_t > out_shape;
32
31
std::vector<int64_t > size;
33
32
std::string mode;
34
33
bool align_corners;
35
34
36
35
public:
37
- InterpolatePlugin (const char * name, std::vector<int64_t > in_shape,
38
- std::vector<int64_t > out_shape,
39
- std::string mode,
40
- bool align_corners) : name(name), in_shape(in_shape), out_shape(out_shape), mode(mode), align_corners(align_corners) {}
41
-
36
+ InterpolatePlugin (std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) :
37
+ in_shape (in_shape), out_shape(out_shape), size(size), mode(mode), align_corners(align_corners)
38
+ {}
39
+
40
+ InterpolatePlugin (const char *data, size_t length) {
41
+ std::istringstream data_stream (std::string (data, length));
42
+
43
+ torch::serialize::InputArchive input_archive;
44
+ input_archive.load_from (data_stream);
45
+
46
+ {
47
+ torch::IValue value;
48
+ input_archive.read (" in_shape" , value);
49
+ in_shape = value.toIntVector ();
50
+ }
51
+ {
52
+ torch::IValue value;
53
+ input_archive.read (" out_shape" , value);
54
+ out_shape = value.toIntVector ();
55
+ }
56
+ {
57
+ torch::IValue value;
58
+ input_archive.read (" size" , value);
59
+ size = value.toIntVector ();
60
+ }
61
+ {
62
+ torch::IValue value;
63
+ input_archive.read (" mode" , value);
64
+ mode = value.toStringRef ();
65
+ }
66
+ {
67
+ torch::IValue value;
68
+ input_archive.read (" align_corners" , value);
69
+ align_corners = value.toBool ();
70
+ }
71
+ }
42
72
73
+ int getNbOutputs () const override {
74
+ return 1 ;
75
+ }
43
76
44
77
// Registry key for this plugin; must match the name the creator reports so
// deserialization can find the right factory.
const char *getPluginType() const override {
  return "Interpolate_TRTorch";
}
47
80
48
81
const char * getPluginVersion () const override {
@@ -60,79 +93,125 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
60
93
}
61
94
62
95
// TensorRT takes ownership of the returned copy and releases it via destroy().
nvinfer1::IPluginV2DynamicExt *clone() const override {
  auto *copy = new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
  return copy;
}
66
98
67
- nvinfer::DimsExprs getOutputDimensions (int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) const override {
68
-
69
- }
99
+ nvinfer1::DimsExprs getOutputDimensions (int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) override {
100
+ // nvinfer1::DimsExprs output(inputs[0]);
70
101
71
- nvinfer1::DataType getOutputDataType (int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override {
102
+ // output.nbDims = out_shape.size();
103
+
104
+ // for (int i = 0; i < out_shape.size(); i++) {
105
+ // output.d[i] = exprBuilder.getConstantValue(out_shape[i]);
106
+ // }
72
107
108
+ // return output;
109
+ nvinfer1::DimsExprs empty;
110
+ return empty;
73
111
}
74
112
75
- int getNbOutputs ( ) const override {
76
- return 1 ;
113
+ nvinfer1::DataType getOutputDataType ( int index, const nvinfer1::DataType* inputTypes, int nbInputs ) const override {
114
+ return DataType:: kFLOAT ;
77
115
}
78
116
79
117
// Pin the ATen tensor options used by enqueue(): fp32 tensors on the CUDA
// device. Returns 0 to signal success to TensorRT.
int initialize() override {
  tensor_options = tensor_options.device(c10::kCUDA).dtype(c10::kFloat);
  return 0;
}
82
123
83
- void terminate () override {
84
-
85
- }
124
+ void terminate () override {}
86
125
87
126
void serialize (void * buffer) const override {
127
+ std::string data = serializeToString ();
128
+ size_t size = getSerializationSize ();
88
129
130
+ data.copy ((char *) buffer, size);
89
131
}
90
132
91
- void size_t getSerializationSize () const override {
133
+ std::string serializeToString () const {
134
+ torch::serialize::OutputArchive output_archive;
92
135
93
- }
136
+ output_archive.write (" in_shape" , torch::IValue (in_shape));
137
+ output_archive.write (" out_shape" , torch::IValue (out_shape));
138
+ output_archive.write (" size" , torch::IValue (size));
139
+ output_archive.write (" mode" , torch::IValue (mode));
140
+ output_archive.write (" align_corners" , torch::IValue (align_corners));
94
141
95
- void destroy () override {
142
+ std::ostringstream data_str;
143
+ output_archive.save_to (data_str);
96
144
145
+ return data_str.str ();
97
146
}
98
147
99
- bool supportsFormatCombination ( int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
100
-
148
+ size_t getSerializationSize () const override {
149
+ return serializeToString (). size ();
101
150
}
102
151
103
- void configurePlugin ( const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs ) override {
152
+ void destroy ( ) override {}
104
153
105
- }
154
+ bool supportsFormatCombination (int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
155
+ if (inOut->format != nvinfer1::TensorFormat::kLINEAR ) {
156
+ return false ;
157
+ }
106
158
107
- size_t getWorkspaceSize (const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override {
159
+ if (inOut->type == DataType::kINT32 || inOut->type == DataType::kINT8 ) {
160
+ return false ;
161
+ }
108
162
163
+ return true ;
109
164
}
110
165
111
- void attachToContext (nvinfer1::cudnnContext*, nvinfer1::cublasContext*, nvinfer1::IGpuAllocator*) override {}
166
+ void configurePlugin (const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override {
167
+ dtype = DataType::kFLOAT ;
168
+ }
112
169
113
- void detachFromContext () override {}
170
+ size_t getWorkspaceSize (const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override {
171
+ return 0 ;
172
+ }
114
173
115
174
// Run the interpolation with ATen kernels on a side CUDA stream, fenced
// against TensorRT's stream with events so neither stream races the other.
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
            void *const *outputs, void *workspace,
            cudaStream_t stream) override {
  // Wrap the raw TensorRT buffers as ATen tensors without taking ownership
  // (no-op deleters); shapes are the ones captured at construction.
  at::Tensor input = at::from_blob((void *)inputs[0], in_shape, [](void *) {}, tensor_options);
  at::Tensor output = at::from_blob(outputs[0], out_shape, [](void *) {}, tensor_options);

  at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
  at::cuda::CUDAStreamGuard torch_guard(torch_stream);

  // Fence 1: ATen stream waits until TensorRT's stream has produced inputs[0].
  cudaEvent_t trt_ready;
  cudaEventCreate(&trt_ready);
  cudaEventRecord(trt_ready, stream);

  cudaStreamWaitEvent(torch_stream.stream(), trt_ready, 0);

  // Dispatch on the interpolation mode; `size` holds the target spatial dims.
  if (mode == "linear") {
    at::upsample_linear1d_out(output, input, {size[0]}, align_corners);
  } else if (mode == "bilinear") {
    at::upsample_bilinear2d_out(output, input, {size[0], size[1]}, align_corners);
  } else if (mode == "trilinear") {
    at::upsample_trilinear3d_out(output, input, {size[0], size[1], size[2]}, align_corners);
  }
  // NOTE(review): an unrecognized mode silently leaves `output` untouched —
  // confirm the converter guarantees mode ∈ {linear, bilinear, trilinear}.

  // Fence 2: TensorRT's stream waits until the ATen work has finished.
  cudaEvent_t aten_done;
  cudaEventCreate(&aten_done);
  cudaEventRecord(aten_done, torch_stream.stream());

  cudaStreamWaitEvent(stream, aten_done, 0);

  cudaEventDestroy(trt_ready);
  cudaEventDestroy(aten_done);

  return 0;
}
208
+ };
133
209
134
210
135
211
class InterpolatePluginCreator : public nvinfer1 ::IPluginCreator {
212
+ private:
213
+ std::string name;
214
+
136
215
public:
137
216
InterpolatePluginCreator () {}
138
217
@@ -158,18 +237,20 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
158
237
return nullptr ;
159
238
}
160
239
161
- nvinfer1::IPluginV2* createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::string mode, bool align_corners) {
162
- return new InterpolatePlugin (name, in_shape, out_shape, mode, align_corners);
240
+ nvinfer1::IPluginV2* createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) {
241
+ name = name;
242
+ return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
163
243
}
164
244
165
245
nvinfer1::IPluginV2* deserializePlugin (const char * name, const void *serialData, size_t serialLength) override {
166
- return nullptr ;
246
+ name = name;
247
+ return new InterpolatePlugin ((const char *) serialData, serialLength);
167
248
}
168
249
169
250
// No PluginField schema is published: creation goes through the typed
// createPlugin() overload rather than field collections.
const nvinfer1::PluginFieldCollection *getFieldNames() override {
  return nullptr;
}
172
- }
253
+ };
173
254
174
255
REGISTER_TENSORRT_PLUGIN (InterpolatePluginCreator);
175
256
0 commit comments