Skip to content

Commit 0914fec

Browse files
authored
Add configurable input size for TLT MaskRCNN Plugin (#986)
Signed-off-by: Tyler Zhu <[email protected]> Co-authored-by: Tyler Zhu <[email protected]>
1 parent 1565fe7 commit 0914fec

File tree

6 files changed

+79
-48
lines changed

6 files changed

+79
-48
lines changed

plugin/generateDetectionPlugin/generateDetectionPlugin.cpp

+26-14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "generateDetectionPlugin.h"
1717
#include "plugin.h"
1818
#include <cuda_runtime_api.h>
19+
#include <algorithm>
1920

2021
using namespace nvinfer1;
2122
using namespace plugin;
@@ -40,6 +41,7 @@ GenerateDetectionPluginCreator::GenerateDetectionPluginCreator()
4041
mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
4142
mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
4243
mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
44+
mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3));
4345

4446
mFC.nbFields = mPluginAttributes.size();
4547
mFC.fields = mPluginAttributes.data();
@@ -62,6 +64,7 @@ const PluginFieldCollection* GenerateDetectionPluginCreator::getFieldNames()
6264

6365
IPluginV2Ext* GenerateDetectionPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
6466
{
67+
auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE;
6568
const PluginField* fields = fc->fields;
6669
for (int i = 0; i < fc->nbFields; ++i)
6770
{
@@ -86,20 +89,27 @@ IPluginV2Ext* GenerateDetectionPluginCreator::createPlugin(const char* name, con
8689
assert(fields[i].type == PluginFieldType::kFLOAT32);
8790
mIOUThreshold = *(static_cast<const float*>(fields[i].data));
8891
}
92+
if (!strcmp(attrName, "image_size"))
93+
{
94+
assert(fields[i].type == PluginFieldType::kINT32);
95+
const auto dims = static_cast<const int32_t*>(fields[i].data);
96+
std::copy_n(dims, 3, image_size.d);
97+
}
8998
}
90-
return new GenerateDetection(mNbClasses, mKeepTopK, mScoreThreshold, mIOUThreshold);
99+
return new GenerateDetection(mNbClasses, mKeepTopK, mScoreThreshold, mIOUThreshold, image_size);
91100
};
92101

93102
IPluginV2Ext* GenerateDetectionPluginCreator::deserializePlugin(const char* name, const void* data, size_t length)
94103
{
95104
return new GenerateDetection(data, length);
96105
};
97106

98-
GenerateDetection::GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold)
107+
GenerateDetection::GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold, const nvinfer1::Dims& image_size)
99108
: mNbClasses(num_classes)
100109
, mKeepTopK(keep_topk)
101110
, mScoreThreshold(score_threshold)
102111
, mIOUThreshold(iou_threshold)
112+
, mImageSize(image_size)
103113
{
104114
mBackgroundLabel = 0;
105115
assert(mNbClasses > 0);
@@ -178,7 +188,7 @@ const char* GenerateDetection::getPluginNamespace() const
178188

179189
size_t GenerateDetection::getSerializationSize() const
180190
{
181-
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * 2;
191+
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * 2 + sizeof(nvinfer1::Dims);
182192
};
183193

184194
void GenerateDetection::serialize(void* buffer) const
@@ -190,6 +200,7 @@ void GenerateDetection::serialize(void* buffer) const
190200
write(d, mIOUThreshold);
191201
write(d, mMaxBatchSize);
192202
write(d, mAnchorsCnt);
203+
write(d, mImageSize);
193204
ASSERT(d == a + getSerializationSize());
194205
};
195206

@@ -202,6 +213,7 @@ GenerateDetection::GenerateDetection(const void* data, size_t length)
202213
float iou_threshold = read<float>(d);
203214
mMaxBatchSize = read<int>(d);
204215
mAnchorsCnt = read<int>(d);
216+
mImageSize = read<nvinfer1::Dims3>(d);
205217
ASSERT(d == a + length);
206218

207219
mNbClasses = num_classes;
@@ -264,17 +276,17 @@ int GenerateDetection::enqueue(
264276

265277
// refine detection
266278
RefineDetectionWorkSpace refDetcWorkspace(batch_size, mAnchorsCnt, mParam, mType);
267-
cudaError_t status = DetectionPostProcess(stream, batch_size, mAnchorsCnt,
268-
static_cast<float*>(mRegWeightDevice->mPtr),
269-
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]), // Image Height
270-
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]), // Image Width
271-
DataType::kFLOAT, // mType,
272-
mParam, refDetcWorkspace, workspace,
273-
inputs[1], // inputs[InScore]
274-
inputs[0], // inputs[InDelta],
275-
mValidCnt->mPtr, // inputs[InCountValid],
276-
inputs[2], // inputs[ROI]
277-
detections);
279+
cudaError_t status
280+
= DetectionPostProcess(stream, batch_size, mAnchorsCnt, static_cast<float*>(mRegWeightDevice->mPtr),
281+
static_cast<float>(mImageSize.d[1]), // Image Height
282+
static_cast<float>(mImageSize.d[2]), // Image Width
283+
DataType::kFLOAT, // mType,
284+
mParam, refDetcWorkspace, workspace,
285+
inputs[1], // inputs[InScore]
286+
inputs[0], // inputs[InDelta],
287+
mValidCnt->mPtr, // inputs[InCountValid],
288+
inputs[2], // inputs[ROI]
289+
detections);
278290

279291
assert(status == cudaSuccess);
280292
return status;

plugin/generateDetectionPlugin/generateDetectionPlugin.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace plugin
3535
class GenerateDetection : public IPluginV2Ext
3636
{
3737
public:
38-
GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold);
38+
GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold, const nvinfer1::Dims& image_size);
3939

4040
GenerateDetection(const void* data, size_t length);
4141

@@ -103,6 +103,8 @@ class GenerateDetection : public IPluginV2Ext
103103
RefineNMSParameters mParam;
104104
std::shared_ptr<CudaBind<float>> mRegWeightDevice;
105105

106+
nvinfer1::Dims mImageSize;
107+
106108
std::string mNameSpace;
107109
};
108110

plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.cpp

+15-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "multilevelCropAndResizePlugin.h"
1717
#include "plugin.h"
1818
#include <cuda_runtime_api.h>
19+
#include <algorithm>
1920

2021
#include <fstream>
2122

@@ -36,6 +37,7 @@ std::vector<PluginField> MultilevelCropAndResizePluginCreator::mPluginAttributes
3637
MultilevelCropAndResizePluginCreator::MultilevelCropAndResizePluginCreator()
3738
{
3839
mPluginAttributes.emplace_back(PluginField("pooled_size", nullptr, PluginFieldType::kINT32, 1));
40+
mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3));
3941

4042
mFC.nbFields = mPluginAttributes.size();
4143
mFC.fields = mPluginAttributes.data();
@@ -58,6 +60,7 @@ const PluginFieldCollection* MultilevelCropAndResizePluginCreator::getFieldNames
5860

5961
IPluginV2Ext* MultilevelCropAndResizePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
6062
{
63+
auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE;
6164
const PluginField* fields = fc->fields;
6265
for (int i = 0; i < fc->nbFields; ++i)
6366
{
@@ -67,25 +70,31 @@ IPluginV2Ext* MultilevelCropAndResizePluginCreator::createPlugin(const char* nam
6770
assert(fields[i].type == PluginFieldType::kINT32);
6871
mPooledSize = *(static_cast<const int*>(fields[i].data));
6972
}
73+
if (!strcmp(attrName, "image_size"))
74+
{
75+
assert(fields[i].type == PluginFieldType::kINT32);
76+
const auto dims = static_cast<const int32_t*>(fields[i].data);
77+
std::copy_n(dims, 3, image_size.d);
78+
}
7079
}
71-
return new MultilevelCropAndResize(mPooledSize);
80+
return new MultilevelCropAndResize(mPooledSize, image_size);
7281
};
7382

7483
IPluginV2Ext* MultilevelCropAndResizePluginCreator::deserializePlugin(const char* name, const void* data, size_t length)
7584
{
7685
return new MultilevelCropAndResize(data, length);
7786
};
7887

79-
MultilevelCropAndResize::MultilevelCropAndResize(int pooled_size)
88+
MultilevelCropAndResize::MultilevelCropAndResize(int pooled_size, const nvinfer1::Dims& image_size)
8089
: mPooledSize({pooled_size, pooled_size})
8190
{
8291

8392
assert(pooled_size > 0);
8493
// shape
85-
mInputHeight = TLTMaskRCNNConfig::IMAGE_SHAPE.d[1];
86-
mInputWidth = TLTMaskRCNNConfig::IMAGE_SHAPE.d[2];
87-
//Threshold to P3: Smaller -> P2
88-
mThresh = (224*224) / (4.0f);
94+
mInputHeight = image_size.d[1];
95+
mInputWidth = image_size.d[2];
96+
// Threshold to P3: Smaller -> P2
97+
mThresh = (224 * 224) / (4.0f);
8998
};
9099

91100
int MultilevelCropAndResize::getNbOutputs() const

plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace plugin
3535
class MultilevelCropAndResize : public IPluginV2Ext
3636
{
3737
public:
38-
MultilevelCropAndResize(int pooled_size);
38+
MultilevelCropAndResize(int pooled_size, const nvinfer1::Dims& image_size);
3939

4040
MultilevelCropAndResize(const void* data, size_t length);
4141

plugin/multilevelProposeROI/multilevelProposeROIPlugin.cpp

+31-24
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "plugin.h"
1919
#include <cuda_runtime_api.h>
2020
#include <iostream>
21+
#include <algorithm>
2122
#include <math.h>
2223

2324
#include <fstream>
@@ -43,6 +44,7 @@ MultilevelProposeROIPluginCreator::MultilevelProposeROIPluginCreator()
4344
mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
4445
mPluginAttributes.emplace_back(PluginField("fg_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
4546
mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
47+
mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3));
4648

4749
mFC.nbFields = mPluginAttributes.size();
4850
mFC.fields = mPluginAttributes.data();
@@ -65,6 +67,7 @@ const PluginFieldCollection* MultilevelProposeROIPluginCreator::getFieldNames()
6567

6668
IPluginV2Ext* MultilevelProposeROIPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
6769
{
70+
auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE;
6871
const PluginField* fields = fc->fields;
6972
for (int i = 0; i < fc->nbFields; ++i)
7073
{
@@ -89,20 +92,27 @@ IPluginV2Ext* MultilevelProposeROIPluginCreator::createPlugin(const char* name,
8992
assert(fields[i].type == PluginFieldType::kFLOAT32);
9093
mIOUThreshold = *(static_cast<const float*>(fields[i].data));
9194
}
95+
if (!strcmp(attrName, "image_size"))
96+
{
97+
assert(fields[i].type == PluginFieldType::kINT32);
98+
const auto dims = static_cast<const int32_t*>(fields[i].data);
99+
std::copy_n(dims, 3, image_size.d);
100+
}
92101
}
93-
return new MultilevelProposeROI(mPreNMSTopK, mKeepTopK, mFGThreshold, mIOUThreshold);
102+
return new MultilevelProposeROI(mPreNMSTopK, mKeepTopK, mFGThreshold, mIOUThreshold, image_size);
94103
};
95104

96105
IPluginV2Ext* MultilevelProposeROIPluginCreator::deserializePlugin(const char* name, const void* data, size_t length)
97106
{
98107
return new MultilevelProposeROI(data, length);
99108
};
100109

101-
MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold)
110+
MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold, const nvinfer1::Dims image_size)
102111
: mPreNMSTopK(prenms_topk)
103112
, mKeepTopK(keep_topk)
104113
, mFGThreshold(fg_threshold)
105114
, mIOUThreshold(iou_threshold)
115+
, mImageSize(image_size)
106116
{
107117
mBackgroundLabel = -1;
108118
assert(mPreNMSTopK > 0);
@@ -121,7 +131,7 @@ MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float
121131

122132
mFeatureCnt = TLTMaskRCNNConfig::MAX_LEVEL - TLTMaskRCNNConfig::MIN_LEVEL + 1;
123133

124-
generate_pyramid_anchors();
134+
generate_pyramid_anchors(mImageSize);
125135
};
126136

127137
int MultilevelProposeROI::getNbOutputs() const
@@ -224,7 +234,7 @@ const char* MultilevelProposeROI::getPluginNamespace() const
224234

225235
size_t MultilevelProposeROI::getSerializationSize() const
226236
{
227-
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * (mFeatureCnt + 1);
237+
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * (mFeatureCnt + 1) + sizeof(nvinfer1::Dims);
228238
};
229239

230240
void MultilevelProposeROI::serialize(void* buffer) const
@@ -239,6 +249,7 @@ void MultilevelProposeROI::serialize(void* buffer) const
239249
{
240250
write(d, mAnchorsCnt[i]);
241251
}
252+
write(d, mImageSize);
242253
ASSERT(d == a + getSerializationSize());
243254
};
244255

@@ -257,6 +268,7 @@ MultilevelProposeROI::MultilevelProposeROI(const void* data, size_t length)
257268
{
258269
mAnchorsCnt.push_back(read<int>(d));
259270
}
271+
mImageSize = read<nvinfer1::Dims3>(d);
260272
ASSERT(d == a + length);
261273

262274
mBackgroundLabel = -1;
@@ -273,7 +285,7 @@ MultilevelProposeROI::MultilevelProposeROI(const void* data, size_t length)
273285

274286
mType = DataType::kFLOAT;
275287

276-
generate_pyramid_anchors();
288+
generate_pyramid_anchors(mImageSize);
277289
};
278290

279291
void MultilevelProposeROI::check_valid_inputs(const nvinfer1::Dims* inputs, int nbInputDims)
@@ -329,9 +341,9 @@ Dims MultilevelProposeROI::getOutputDimensions(int index, const Dims* inputs, in
329341
return proposals;
330342
}
331343

332-
void MultilevelProposeROI::generate_pyramid_anchors()
344+
void MultilevelProposeROI::generate_pyramid_anchors(const nvinfer1::Dims& image_size)
333345
{
334-
const auto image_dims = TLTMaskRCNNConfig::IMAGE_SHAPE;
346+
const auto image_dims = image_size;
335347

336348
const auto& anchor_scale = TLTMaskRCNNConfig::RPN_ANCHOR_SCALE;
337349
const auto& min_level = TLTMaskRCNNConfig::MIN_LEVEL;
@@ -388,23 +400,18 @@ int MultilevelProposeROI::enqueue(
388400
{
389401

390402
MultilevelProposeROIWorkSpace proposal_ws(batch_size, mAnchorsCnt[i], mPreNMSTopK, mParam, mType);
391-
status = MultilevelPropose(stream,
392-
batch_size,
393-
mAnchorsCnt[i],
394-
mPreNMSTopK,
395-
static_cast<float*>(mRegWeightDevice->mPtr),
396-
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]), //Input Height
397-
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]),
398-
DataType::kFLOAT, // mType,
399-
mParam,
400-
proposal_ws,
401-
workspace + kernel_workspace_offset,
402-
inputs[2*i + 1], // inputs[object_score],
403-
inputs[2*i], // inputs[bbox_delta]
404-
mValidCnt->mPtr,
405-
mAnchorBoxesDevice[i]->mPtr, // inputs[anchors]
406-
mTempScores[i]->mPtr, //temp scores [batch_size, topk, 1]
407-
mTempBboxes[i]->mPtr); //temp
403+
status = MultilevelPropose(stream, batch_size, mAnchorsCnt[i], mPreNMSTopK,
404+
static_cast<float*>(mRegWeightDevice->mPtr),
405+
static_cast<float>(mImageSize.d[1]), // Input Height
406+
static_cast<float>(mImageSize.d[2]),
407+
DataType::kFLOAT, // mType,
408+
mParam, proposal_ws, workspace + kernel_workspace_offset,
409+
inputs[2 * i + 1], // inputs[object_score],
410+
inputs[2 * i], // inputs[bbox_delta]
411+
mValidCnt->mPtr,
412+
mAnchorBoxesDevice[i]->mPtr, // inputs[anchors]
413+
mTempScores[i]->mPtr, // temp scores [batch_size, topk, 1]
414+
mTempBboxes[i]->mPtr); // temp
408415
assert(status == cudaSuccess);
409416
kernel_workspace_offset += proposal_ws.totalSize;
410417
}

plugin/multilevelProposeROI/multilevelProposeROIPlugin.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace plugin
3434
class MultilevelProposeROI : public IPluginV2Ext
3535
{
3636
public:
37-
MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold);
37+
MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold, const nvinfer1::Dims image_size);
3838

3939
MultilevelProposeROI(const void* data, size_t length);
4040

@@ -88,7 +88,7 @@ class MultilevelProposeROI : public IPluginV2Ext
8888

8989
private:
9090
void check_valid_inputs(const nvinfer1::Dims* inputs, int nbInputDims);
91-
void generate_pyramid_anchors();
91+
void generate_pyramid_anchors(const nvinfer1::Dims& image_size);
9292

9393
int mBackgroundLabel;
9494
int mPreNMSTopK;
@@ -111,6 +111,7 @@ class MultilevelProposeROI : public IPluginV2Ext
111111
float** mDeviceBboxes;
112112
std::shared_ptr<CudaBind<float>> mRegWeightDevice;
113113

114+
nvinfer1::Dims mImageSize;
114115
nvinfer1::DataType mType;
115116
RefineNMSParameters mParam;
116117

0 commit comments

Comments
 (0)