22
22
#include < vector>
23
23
24
24
using namespace nvinfer1 ;
25
- using nvinfer1::plugin::GridAnchorGenerator;
26
- using nvinfer1::plugin::GridAnchorPluginCreator;
27
25
28
26
namespace
29
27
{
30
- const char * GRID_ANCHOR_PLUGIN_VERSION{ " 1 " };
31
- const char * GRID_ANCHOR_PLUGIN_NAME{ " GridAnchor_TRT " } ;
28
+ std::string GRID_ANCHOR_PLUGIN_NAMES[] = { " GridAnchor_TRT " , " GridAnchorRect_TRT " };
29
+ const char * GRID_ANCHOR_PLUGIN_VERSION = " 1 " ;
32
30
} // namespace
33
- PluginFieldCollection GridAnchorPluginCreator::mFC {};
34
- std::vector<PluginField> GridAnchorPluginCreator::mPluginAttributes ;
35
31
36
- GridAnchorGenerator::GridAnchorGenerator (const GridAnchorParameters* paramIn, int mNumLayers )
37
- : mNumLayers(mNumLayers )
32
+ PluginFieldCollection GridAnchorBasePluginCreator::mFC {};
33
+ std::vector<PluginField> GridAnchorBasePluginCreator::mPluginAttributes ;
34
+
35
+ GridAnchorGenerator::GridAnchorGenerator (const GridAnchorParameters* paramIn, int numLayers, const char *name)
36
+ : mNumLayers(numLayers), mPluginName(name)
38
37
{
39
38
CUASSERT (cudaMallocHost ((void **) &mNumPriors , mNumLayers * sizeof (int )));
40
39
CUASSERT (cudaMallocHost ((void **) &mDeviceWidths , mNumLayers * sizeof (Weights)));
@@ -121,7 +120,8 @@ GridAnchorGenerator::GridAnchorGenerator(const GridAnchorParameters* paramIn, in
121
120
}
122
121
}
123
122
124
- GridAnchorGenerator::GridAnchorGenerator (const void * data, size_t length)
123
+ GridAnchorGenerator::GridAnchorGenerator (const void * data, size_t length, const char *name) :
124
+ mPluginName(name)
125
125
{
126
126
const char *d = reinterpret_cast <const char *>(data), *a = d;
127
127
mNumLayers = read <int >(d);
@@ -276,14 +276,15 @@ bool GridAnchorGenerator::supportsFormat(DataType type, PluginFormat format) con
276
276
277
277
const char * GridAnchorGenerator::getPluginType () const
278
278
{
279
- return GRID_ANCHOR_PLUGIN_NAME ;
279
+ return mPluginName . c_str () ;
280
280
}
281
281
282
282
const char * GridAnchorGenerator::getPluginVersion () const
283
283
{
284
284
return GRID_ANCHOR_PLUGIN_VERSION;
285
285
}
286
286
287
+
287
288
// Set plugin namespace
288
289
void GridAnchorGenerator::setPluginNamespace (const char * pluginNamespace)
289
290
{
@@ -341,12 +342,12 @@ void GridAnchorGenerator::destroy()
341
342
342
343
IPluginV2Ext* GridAnchorGenerator::clone () const
343
344
{
344
- IPluginV2Ext* plugin = new GridAnchorGenerator (mParam .data (), mNumLayers );
345
+ IPluginV2Ext* plugin = new GridAnchorGenerator (mParam .data (), mNumLayers , mPluginName . c_str () );
345
346
plugin->setPluginNamespace (mPluginNamespace .c_str ());
346
347
return plugin;
347
348
}
348
349
349
- GridAnchorPluginCreator::GridAnchorPluginCreator ()
350
+ GridAnchorBasePluginCreator::GridAnchorBasePluginCreator ()
350
351
{
351
352
mPluginAttributes .emplace_back (PluginField (" minSize" , nullptr , PluginFieldType::kFLOAT32 , 1 ));
352
353
mPluginAttributes .emplace_back (PluginField (" maxSize" , nullptr , PluginFieldType::kFLOAT32 , 1 ));
@@ -359,29 +360,31 @@ GridAnchorPluginCreator::GridAnchorPluginCreator()
359
360
mFC .fields = mPluginAttributes .data ();
360
361
}
361
362
362
- const char * GridAnchorPluginCreator ::getPluginName () const
363
+ const char * GridAnchorBasePluginCreator ::getPluginName () const
363
364
{
364
- return GRID_ANCHOR_PLUGIN_NAME ;
365
+ return mPluginName . c_str () ;
365
366
}
366
367
367
- const char * GridAnchorPluginCreator ::getPluginVersion () const
368
+ const char * GridAnchorBasePluginCreator ::getPluginVersion () const
368
369
{
369
370
return GRID_ANCHOR_PLUGIN_VERSION;
370
371
}
371
372
372
- const PluginFieldCollection* GridAnchorPluginCreator ::getFieldNames ()
373
+ const PluginFieldCollection* GridAnchorBasePluginCreator ::getFieldNames ()
373
374
{
374
375
return &mFC ;
375
376
}
376
377
377
- IPluginV2Ext* GridAnchorPluginCreator ::createPlugin (const char * name, const PluginFieldCollection* fc)
378
+ IPluginV2Ext* GridAnchorBasePluginCreator ::createPlugin (const char * name, const PluginFieldCollection* fc)
378
379
{
379
380
float minScale = 0 .2F , maxScale = 0 .95F ;
380
381
int numLayers = 6 ;
381
382
std::vector<float > aspectRatios;
382
383
std::vector<int > fMapShapes ;
383
384
std::vector<float > layerVariances;
384
385
const PluginField* fields = fc->fields ;
386
+
387
+ const bool isFMapRect = (GRID_ANCHOR_PLUGIN_NAMES[1 ] == mPluginName );
385
388
for (int i = 0 ; i < fc->nbFields ; ++i)
386
389
{
387
390
const char * attrName = fields[i].name ;
@@ -428,6 +431,7 @@ IPluginV2Ext* GridAnchorPluginCreator::createPlugin(const char* name, const Plug
428
431
{
429
432
ASSERT (fields[i].type == PluginFieldType::kINT32 );
430
433
int size = fields[i].length ;
434
+ ASSERT (!isFMapRect || (size % 2 == 0 ));
431
435
fMapShapes .reserve (size);
432
436
const int * fMap = static_cast <const int *>(fields[i].data );
433
437
for (int j = 0 ; j < size; j++)
@@ -442,7 +446,8 @@ IPluginV2Ext* GridAnchorPluginCreator::createPlugin(const char* name, const Plug
442
446
std::vector<float > firstLayerAspectRatios;
443
447
444
448
ASSERT (numLayers > 0 );
445
- ASSERT ((int ) fMapShapes .size () == numLayers);
449
+ const int numExpectedLayers = static_cast <int >(fMapShapes .size ()) >> (isFMapRect ? 1 : 0 );
450
+ ASSERT (numExpectedLayers == numLayers);
446
451
447
452
int numFirstLayerARs = 3 ;
448
453
// First layer only has the first 3 aspect ratios from aspectRatios
@@ -457,30 +462,42 @@ IPluginV2Ext* GridAnchorPluginCreator::createPlugin(const char* name, const Plug
457
462
// One set of box parameters for one layer
458
463
for (int i = 0 ; i < numLayers; i++)
459
464
{
465
+ int hOffset = (isFMapRect ? i * 2 : i);
466
+ int wOffset = (isFMapRect ? i * 2 + 1 : i);
460
467
// Only the first layer is different
461
468
if (i == 0 )
462
469
{
463
470
boxParams[i] = {minScale, maxScale, firstLayerAspectRatios.data (), (int ) firstLayerAspectRatios.size (),
464
- fMapShapes [i ], fMapShapes [i ],
471
+ fMapShapes [hOffset ], fMapShapes [wOffset ],
465
472
{layerVariances[0 ], layerVariances[1 ], layerVariances[2 ], layerVariances[3 ]}};
466
473
}
467
474
else
468
475
{
469
- boxParams[i] = {minScale, maxScale, aspectRatios.data (), (int ) aspectRatios.size (), fMapShapes [i ],
470
- fMapShapes [i ], {layerVariances[0 ], layerVariances[1 ], layerVariances[2 ], layerVariances[3 ]}};
476
+ boxParams[i] = {minScale, maxScale, aspectRatios.data (), (int ) aspectRatios.size (), fMapShapes [hOffset ],
477
+ fMapShapes [wOffset ], {layerVariances[0 ], layerVariances[1 ], layerVariances[2 ], layerVariances[3 ]}};
471
478
}
472
479
}
473
480
474
- GridAnchorGenerator* obj = new GridAnchorGenerator (boxParams.data (), numLayers);
481
+ GridAnchorGenerator* obj = new GridAnchorGenerator (boxParams.data (), numLayers, mPluginName . c_str () );
475
482
obj->setPluginNamespace (mNamespace .c_str ());
476
483
return obj;
477
484
}
478
485
479
- IPluginV2Ext* GridAnchorPluginCreator ::deserializePlugin (const char * name, const void * serialData, size_t serialLength)
486
+ IPluginV2Ext* GridAnchorBasePluginCreator ::deserializePlugin (const char * name, const void * serialData, size_t serialLength)
480
487
{
481
488
// This object will be deleted when the network is destroyed, which will
482
489
// call GridAnchor::destroy()
483
- GridAnchorGenerator* obj = new GridAnchorGenerator (serialData, serialLength);
490
+ GridAnchorGenerator* obj = new GridAnchorGenerator (serialData, serialLength, mPluginName . c_str () );
484
491
obj->setPluginNamespace (mNamespace .c_str ());
485
492
return obj;
486
493
}
494
+
495
+ GridAnchorPluginCreator::GridAnchorPluginCreator ()
496
+ {
497
+ mPluginName = GRID_ANCHOR_PLUGIN_NAMES[0 ];
498
+ }
499
+
500
+ GridAnchorRectPluginCreator::GridAnchorRectPluginCreator ()
501
+ {
502
+ mPluginName = GRID_ANCHOR_PLUGIN_NAMES[1 ];
503
+ }
0 commit comments