-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
Copy pathinstanceNormalizationPlugin.cpp
404 lines (361 loc) · 14.6 KB
/
instanceNormalizationPlugin.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "instanceNormalizationPlugin.h"

#include <cstring>
#include <stdexcept>

#include <cuda_fp16.h>
using namespace nvinfer1;
using nvinfer1::plugin::InstanceNormalizationPlugin;
using nvinfer1::plugin::InstanceNormalizationPluginCreator;
// Evaluate a CUDA runtime call exactly once; if it does not yield cudaSuccess,
// return that cudaError_t from the enclosing function.
// The temporary has a distinctive name so that an argument expression which
// mentions a caller variable (e.g. `status`) is not captured by the macro's own
// declaration. The do/while(0) wrapper makes the macro behave as one statement.
#define CHECK_CUDA(call) \
    do \
    { \
        cudaError_t checkCudaResult = (call); \
        if (checkCudaResult != cudaSuccess) \
        { \
            return checkCudaResult; \
        } \
    } while (0)

// Same contract as CHECK_CUDA, for cuDNN calls returning cudnnStatus_t.
#define CHECK_CUDNN(call) \
    do \
    { \
        cudnnStatus_t checkCudnnResult = (call); \
        if (checkCudnnResult != CUDNN_STATUS_SUCCESS) \
        { \
            return checkCudnnResult; \
        } \
    } while (0)
// True when `dims` has exactly three dimensions typed, in order, as
// channel, spatial, spatial.
inline bool is_CHW(nvinfer1::Dims const& dims)
{
    if (dims.nbDims != 3)
    {
        return false;
    }
    using DT = nvinfer1::DimensionType;
    bool const channelFirst = dims.type[0] == DT::kCHANNEL;
    bool const spatialRest = dims.type[1] == DT::kSPATIAL && dims.type[2] == DT::kSPATIAL;
    return channelFirst && spatialRest;
}
// Translate a TensorRT element type into the equivalent cuDNN data type.
// Only kFLOAT and kHALF are mapped. For any other type, *cudnn_dtype is left
// untouched and CUDNN_STATUS_BAD_PARAM is returned.
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype)
{
    if (trt_dtype == nvinfer1::DataType::kFLOAT)
    {
        *cudnn_dtype = CUDNN_DATA_FLOAT;
        return CUDNN_STATUS_SUCCESS;
    }
    if (trt_dtype == nvinfer1::DataType::kHALF)
    {
        *cudnn_dtype = CUDNN_DATA_HALF;
        return CUDNN_STATUS_SUCCESS;
    }
    return CUDNN_STATUS_BAD_PARAM;
}
namespace
{
// Registry identity shared by the plugin and its creator: the type/name
// string and the version string returned by their getters.
constexpr char const* INSTANCE_PLUGIN_NAME{"InstanceNormalization_TRT"};
constexpr char const* INSTANCE_PLUGIN_VERSION{"1"};
} // namespace
// Static storage for the creator's attribute table. The creator constructor
// fills mPluginAttributes and points mFC at its contents.
PluginFieldCollection InstanceNormalizationPluginCreator::mFC{};
std::vector<PluginField> InstanceNormalizationPluginCreator::mPluginAttributes{};
InstanceNormalizationPlugin::InstanceNormalizationPlugin(
float epsilon, const std::vector<float>& scale, const std::vector<float>& bias)
: _epsilon(epsilon)
, _nchan(scale.size())
, _h_scale(scale)
, _h_bias(bias)
, _d_scale(nullptr)
, _d_bias(nullptr)
, _d_bytes(0)
{
ASSERT(scale.size() == bias.size());
}
InstanceNormalizationPlugin::InstanceNormalizationPlugin(
float epsilon, nvinfer1::Weights const& scale, nvinfer1::Weights const& bias)
: _epsilon(epsilon)
, _nchan(scale.count)
, _d_scale(nullptr)
, _d_bias(nullptr)
, _d_bytes(0)
{
ASSERT(scale.count == bias.count);
if (scale.type == nvinfer1::DataType::kFLOAT)
{
_h_scale.assign((float*) scale.values, (float*) scale.values + scale.count);
}
else if (scale.type == nvinfer1::DataType::kHALF)
{
_h_scale.reserve(_nchan);
for (int c = 0; c < _nchan; ++c)
{
unsigned short value = ((unsigned short*) scale.values)[c];
_h_scale.push_back(__internal_half2float(value));
}
}
else
{
throw std::runtime_error("Unsupported scale dtype");
}
if (bias.type == nvinfer1::DataType::kFLOAT)
{
_h_bias.assign((float*) bias.values, (float*) bias.values + bias.count);
}
else if (bias.type == nvinfer1::DataType::kHALF)
{
_h_bias.reserve(_nchan);
for (int c = 0; c < _nchan; ++c)
{
unsigned short value = ((unsigned short*) bias.values)[c];
_h_bias.push_back(__internal_half2float(value));
}
}
else
{
throw std::runtime_error("Unsupported bias dtype");
}
}
// Rebuild a plugin from the bytes written by serialize(). Fields are read in the
// order they are written: epsilon, channel count, scale vector, bias vector.
// The device-side members are not part of the stream, so they are reset here,
// as the other constructors do. Otherwise, unless the class declaration gives
// them default initializers, the destructor's cudaFree and enqueue()'s size
// check would act on indeterminate values.
InstanceNormalizationPlugin::InstanceNormalizationPlugin(void const* serialData, size_t serialLength)
    : _d_scale(nullptr)
    , _d_bias(nullptr)
    , _d_bytes(0)
{
    deserialize_value(&serialData, &serialLength, &_epsilon);
    deserialize_value(&serialData, &serialLength, &_nchan);
    deserialize_value(&serialData, &serialLength, &_h_scale);
    deserialize_value(&serialData, &serialLength, &_h_bias);
    // Same invariant the other constructors enforce.
    ASSERT(_h_scale.size() == _h_bias.size());
}
// Release whatever device buffers this instance still owns. The call is
// qualified to make explicit that this class's terminate() runs (inside a
// destructor, a virtual call would resolve to it anyway).
InstanceNormalizationPlugin::~InstanceNormalizationPlugin()
{
    InstanceNormalizationPlugin::terminate();
}
// The plugin produces a single output tensor: the normalized input.
int InstanceNormalizationPlugin::getNbOutputs() const
{
    constexpr int kNumOutputs = 1;
    return kNumOutputs;
}
// Normalization is shape-preserving: the output has exactly the (symbolic)
// dimensions of the first input, regardless of outputIndex.
DimsExprs InstanceNormalizationPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
{
    return inputs[0];
}
// Nothing to set up eagerly: the device copies of scale/bias are (re)allocated
// on demand in configurePlugin() and enqueue(). Always reports success.
int InstanceNormalizationPlugin::initialize()
{
    int const success = 0;
    return success;
}
void InstanceNormalizationPlugin::terminate()
{
cudaFree(_d_bias);
cudaFree(_d_scale);
}
// The batch-replicated scale/bias buffers are owned by the plugin itself
// (_d_scale/_d_bias), so no TensorRT workspace is requested.
size_t InstanceNormalizationPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const
{
    constexpr size_t kNoWorkspace = 0;
    return kNoWorkspace;
}
// Run instance normalization of inputs[0] into outputs[0] on `stream`.
//
// The (N, C, spatial...) input is viewed as a (1, N*C, H, W) tensor and passed
// to one cuDNN spatial batch-norm call. Every (n, c) slice then becomes its own
// "channel" with its own mean/variance, which is instance normalization.
// Spatial dimensions past the third are folded into W, so (N,C,H), (N,C,H,W)
// and (N,C,D,H,W) inputs all map onto the same packed 4-D descriptor.
//
// Returns 0 on success, or the first failing CUDA/cuDNN status.
int InstanceNormalizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream)
{
    nvinfer1::Dims const& input_dims = inputDesc[0].dims;
    int const n = input_dims.d[0];
    int const c = input_dims.d[1];
    // Only dimensions below nbDims are read. A 3-D input therefore gets w == 1
    // instead of whatever value happens to sit in d[3].
    int const h = input_dims.nbDims > 2 ? input_dims.d[2] : 1;
    int w = 1;
    for (int i = 3; i < input_dims.nbDims; ++i)
    {
        w *= input_dims.d[i];
    }

    // Never read past the host parameter arrays if the runtime channel count
    // disagrees with the weights this plugin was built with.
    if (static_cast<size_t>(c) > _h_scale.size() || static_cast<size_t>(c) > _h_bias.size())
    {
        return cudaErrorInvalidValue;
    }

    size_t const nchan_bytes = c * sizeof(float);
    size_t const needed_bytes = n * nchan_bytes;

    // The per-channel parameters are repeated for each batch entry, so the full
    // computation fits in a single cuDNN call.
    if (_d_bytes < needed_bytes)
    {
        cudaFree(_d_bias);
        cudaFree(_d_scale);
        _d_bias = nullptr;
        _d_scale = nullptr;
        _d_bytes = 0;
        // Record the new size only once both allocations have succeeded, so a
        // failure here is retried on the next call instead of using bad pointers.
        cudaError_t allocStatus = cudaMalloc((void**) &_d_scale, needed_bytes);
        if (allocStatus != cudaSuccess)
        {
            _d_scale = nullptr;
            return allocStatus;
        }
        allocStatus = cudaMalloc((void**) &_d_bias, needed_bytes);
        if (allocStatus != cudaSuccess)
        {
            _d_bias = nullptr;
            return allocStatus;
        }
        _d_bytes = needed_bytes;
    }

    // Copies are issued on `stream`, so they are ordered before the cuDNN call
    // below without blocking the host. The host source vectors are members,
    // so they outlive the copies.
    for (int i = 0; i < n; ++i)
    {
        CHECK_CUDA(cudaMemcpyAsync(_d_scale + i * c, _h_scale.data(), nchan_bytes, cudaMemcpyHostToDevice, stream));
        CHECK_CUDA(cudaMemcpyAsync(_d_bias + i * c, _h_bias.data(), nchan_bytes, cudaMemcpyHostToDevice, stream));
    }

    CHECK_CUDNN(cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1));
    cudnnDataType_t cudnn_dtype{};
    CHECK_CUDNN(convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w));

    float const alpha = 1.F;
    float const beta = 0.F;
    void const* x_ptr = inputs[0];
    void* y_ptr = outputs[0];
    CHECK_CUDNN(cudnnSetStream(_cudnn_handle, stream));
    // The training-mode entry point normalizes with statistics computed from the
    // data itself. No running mean/variance buffers are supplied (nullptr).
    // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
    // overflows (NaNs) for fp32 data in some circumstances. The lower-
    // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
    // acceptable.
    CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta,
        _x_desc, x_ptr, _y_desc, y_ptr, _b_desc, _d_scale, _d_bias, 1., nullptr, nullptr, _epsilon, nullptr, nullptr));
    return 0;
}
// Total number of bytes serialize() writes: epsilon, channel count, then the
// scale and bias vectors.
size_t InstanceNormalizationPlugin::getSerializationSize() const
{
    size_t total = serialized_size(_epsilon);
    total += serialized_size(_nchan);
    total += serialized_size(_h_scale);
    total += serialized_size(_h_bias);
    return total;
}
// Write the plugin state in the exact order the deserializing constructor reads
// it back: epsilon, channel count, scale vector, bias vector.
void InstanceNormalizationPlugin::serialize(void* buffer) const
{
    void* cursor = buffer;
    serialize_value(&cursor, _epsilon);
    serialize_value(&cursor, _nchan);
    serialize_value(&cursor, _h_scale);
    serialize_value(&cursor, _h_bias);
}
// Accept a tensor at `pos` only if it is fp32 or fp16, uses the kNCHW layout,
// and has the same data type as the first input (so input and output agree).
bool InstanceNormalizationPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs)
{
    ASSERT(inOut && pos < (nbInputs + nbOutputs));
    nvinfer1::PluginTensorDesc const& desc = inOut[pos];
    bool const isFloatOrHalf = desc.type == nvinfer1::DataType::kFLOAT || desc.type == nvinfer1::DataType::kHALF;
    bool const isNchw = desc.format == nvinfer1::PluginFormat::kNCHW;
    bool const matchesFirst = desc.type == inOut[0].type;
    return isFloatOrHalf && isNchw && matchesFirst;
}
// Registry name of this plugin; matches the creator's getPluginName().
const char* InstanceNormalizationPlugin::getPluginType() const
{
    const char* const type = INSTANCE_PLUGIN_NAME;
    return type;
}
// Version string; matches the creator's getPluginVersion().
const char* InstanceNormalizationPlugin::getPluginVersion() const
{
    const char* const version = INSTANCE_PLUGIN_VERSION;
    return version;
}
// Self-delete. Instances are heap-allocated with `new` by clone(),
// createPlugin() and deserializePlugin(); the destructor releases device buffers.
void InstanceNormalizationPlugin::destroy()
{
    delete this;
}
// Produce an independent copy carrying the same epsilon, host parameters and
// namespace. Device buffers are not shared; the copy allocates its own on demand.
IPluginV2DynamicExt* InstanceNormalizationPlugin::clone() const
{
    auto* copy = new InstanceNormalizationPlugin(_epsilon, _h_scale, _h_bias);
    copy->setPluginNamespace(mPluginNamespace.c_str());
    return copy;
}
// Set plugin namespace. A null pointer is treated as the empty namespace rather
// than being assigned to the string member, which would be undefined behaviour.
void InstanceNormalizationPlugin::setPluginNamespace(const char* pluginNamespace)
{
    mPluginNamespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
}
// Namespace last set via setPluginNamespace(). The pointer stays valid until
// the namespace is changed or the plugin is destroyed.
const char* InstanceNormalizationPlugin::getPluginNamespace() const
{
    return mPluginNamespace.data();
}
// The single output (index 0) always has the same data type as the first input.
nvinfer1::DataType InstanceNormalizationPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
{
    ASSERT(inputTypes && nbInputs > 0 && index == 0);
    nvinfer1::DataType const outputType = inputTypes[0];
    return outputType;
}
// Attach to an execution context: keep the supplied cuDNN handle and create the
// bias/scale, input and output tensor descriptors that enqueue() configures on
// every call. Creation statuses are not checked (this interface returns void).
// NOTE(review): attaching twice without detachFromContext() would leak the
// previously created descriptors.
void InstanceNormalizationPlugin::attachToContext(
    cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
{
    _cudnn_handle = cudnnContext;
    cudnnTensorDescriptor_t* const descriptors[] = {&_b_desc, &_x_desc, &_y_desc};
    for (cudnnTensorDescriptor_t* descriptor : descriptors)
    {
        cudnnCreateTensorDescriptor(descriptor);
    }
}
// Detach the plugin object from its execution context.
void InstanceNormalizationPlugin::detachFromContext()
{
cudnnDestroyTensorDescriptor(_y_desc);
cudnnDestroyTensorDescriptor(_x_desc);
cudnnDestroyTensorDescriptor(_b_desc);
}
// Validate the configured shapes and pre-size the device scale/bias buffers.
// Every input must have fully specified dimensions (-1 is rejected). The
// buffers are sized from the first input's N and C (one fp32 value per
// channel, repeated per batch entry), so enqueue() normally does not allocate.
// This method cannot report errors. On allocation failure, _d_bytes is left at
// 0 so that enqueue() retries the allocation and returns the CUDA error.
void InstanceNormalizationPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs)
{
    for (int i = 0; i < nbInputs; i++)
    {
        nvinfer1::Dims const& dims = in[i].desc.dims;
        for (int j = 0; j < dims.nbDims; j++)
        {
            // Do not support dynamic dimensions
            ASSERT(dims.d[j] != -1);
        }
    }
    nvinfer1::Dims const& input_dims = in[0].desc.dims;
    int const n = input_dims.d[0];
    int const c = input_dims.d[1];
    size_t const nchan_bytes = c * sizeof(float);
    size_t const needed_bytes = n * nchan_bytes;
    if (_d_bytes < needed_bytes)
    {
        cudaFree(_d_bias);
        cudaFree(_d_scale);
        _d_bias = nullptr;
        _d_scale = nullptr;
        _d_bytes = 0;
        if (cudaMalloc((void**) &_d_scale, needed_bytes) != cudaSuccess)
        {
            _d_scale = nullptr;
            return;
        }
        if (cudaMalloc((void**) &_d_bias, needed_bytes) != cudaSuccess)
        {
            _d_bias = nullptr;
            return;
        }
        _d_bytes = needed_bytes;
    }
}
// InstanceNormalizationPluginCreator methods

// Describe the attributes createPlugin() understands: "epsilon", "scales" and
// "bias", all fp32. mPluginAttributes is a static member shared by every
// creator instance, so it is cleared first. Otherwise each additional creator
// would append another copy of the three fields.
InstanceNormalizationPluginCreator::InstanceNormalizationPluginCreator()
{
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("scales", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32, 1));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
// Name under which this creator registers; matches the plugin's getPluginType().
const char* InstanceNormalizationPluginCreator::getPluginName() const
{
    const char* const name = INSTANCE_PLUGIN_NAME;
    return name;
}
// Version this creator registers; matches the plugin's getPluginVersion().
const char* InstanceNormalizationPluginCreator::getPluginVersion() const
{
    const char* const version = INSTANCE_PLUGIN_VERSION;
    return version;
}
// Expose the static attribute description built by the constructor.
const PluginFieldCollection* InstanceNormalizationPluginCreator::getFieldNames()
{
    const PluginFieldCollection* const collection = &mFC;
    return collection;
}
// Build a plugin from a field collection.
// Recognised fields (all must be kFLOAT32):
//   "epsilon" - single value; defaults to 0 when absent.
//   "scales"  - per-channel scale values.
//   "bias"    - per-channel bias values.
// Unknown field names are ignored. A repeated "scales"/"bias" field appends to
// the values already collected. The returned plugin carries this creator's
// namespace and is owned by the caller.
IPluginV2DynamicExt* InstanceNormalizationPluginCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc)
{
    ASSERT(fc != nullptr);
    std::vector<float> scaleValues;
    std::vector<float> biasValues;
    float epsilon{};

    // Append the fp32 payload of `field` to `dst`.
    auto appendFloats = [](const PluginField& field, std::vector<float>& dst) {
        ASSERT(field.type == PluginFieldType::kFLOAT32);
        ASSERT(field.length >= 0);
        const auto* begin = static_cast<const float*>(field.data);
        dst.insert(dst.end(), begin, begin + field.length);
    };

    const PluginField* fields = fc->fields;
    for (int i = 0; i < fc->nbFields; ++i)
    {
        const PluginField& field = fields[i];
        const char* attrName = field.name;
        if (!strcmp(attrName, "epsilon"))
        {
            ASSERT(field.type == PluginFieldType::kFLOAT32);
            ASSERT(field.data != nullptr);
            epsilon = *static_cast<const float*>(field.data);
        }
        else if (!strcmp(attrName, "scales"))
        {
            appendFloats(field, scaleValues);
        }
        else if (!strcmp(attrName, "bias"))
        {
            appendFloats(field, biasValues);
        }
    }

    Weights scaleWeights{DataType::kFLOAT, scaleValues.data(), static_cast<int64_t>(scaleValues.size())};
    Weights biasWeights{DataType::kFLOAT, biasValues.data(), static_cast<int64_t>(biasValues.size())};
    InstanceNormalizationPlugin* obj = new InstanceNormalizationPlugin(epsilon, scaleWeights, biasWeights);
    obj->setPluginNamespace(mNamespace.c_str());
    return obj;
}
// Recreate a plugin from bytes produced by InstanceNormalizationPlugin::serialize()
// and tag it with this creator's namespace. The caller owns the result.
IPluginV2DynamicExt* InstanceNormalizationPluginCreator::deserializePlugin(
    const char* name, const void* serialData, size_t serialLength)
{
    auto* plugin = new InstanceNormalizationPlugin(serialData, serialLength);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}