@@ -347,13 +347,13 @@ bool convertOnnxWeights(
347
347
{
348
348
continue ;
349
349
}
350
- else
350
+ else
351
351
{
352
352
LOG_ERROR (" Key value of: " << keyName << " was not expected!" );
353
353
return false ;
354
354
}
355
355
}
356
-
356
+
357
357
// Buffer to hold the data read from the file
358
358
std::vector<char > dataBuf;
359
359
// Will update dataBuf and nbytes by reference.
@@ -1315,7 +1315,7 @@ nvinfer1::Dims insertDimension(const nvinfer1::Dims& dims, const int axis, const
1315
1315
bool parseExternalWeights (IImporterContext* ctx, std::string file, std::string path, int offset, int length,
1316
1316
std::vector<char >& weightsBuf, size_t & size)
1317
1317
{
1318
- // The weight paths in the ONNX model are relative paths to the main ONNX file.
1318
+ // The weight paths in the ONNX model are relative paths to the main ONNX file.
1319
1319
#ifdef _MSC_VER
1320
1320
size_t slash = path.rfind (" \\ " );
1321
1321
#else
@@ -1486,71 +1486,47 @@ nvinfer1::ITensor* reshapeTensor(IImporterContext* ctx, nvinfer1::ITensor& tenso
1486
1486
return layer->getOutput (0 );
1487
1487
}
1488
1488
1489
- NodeImportResult scaleHelper (IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_, nvinfer1::ScaleMode mode,
1490
- nvinfer1::Weights shift, nvinfer1::Weights scale, nvinfer1::Weights power, std::string shiftName, std::string scaleName)
1489
+ NodeImportResult scaleHelper (IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_,
1490
+ nvinfer1::ScaleMode mode, const nvinfer1::Weights& shift, const nvinfer1::Weights& scale,
1491
+ const nvinfer1::Weights& power, const char * shiftName, const char * scaleName)
1491
1492
{
1492
- nvinfer1::ITensor* tensor_ptr = &tensor_;
1493
- nvinfer1::Dims dims = tensor_ptr-> getDimensions ( );
1493
+ nvinfer1::ITensor* tensorPtr = &tensor_;
1494
+ const ShapeTensor origShape = shapeOf (*tensorPtr );
1494
1495
1495
1496
// TensorRT scale layers support 4D(NCHW) or 5D(NCDHW) input.
1496
- // For input other than 4D or 5D will be expanded to 4D.
1497
- int expectedNbDims = 4 ;
1498
- bool needToExpandDims = (dims.nbDims != 4 && dims.nbDims != 5 );
1499
- nvinfer1::Dims orig_shape = dims;
1500
- if (needToExpandDims)
1497
+ // For input other than 4D or 5D will be expanded or squeezed to 4D.
1498
+ bool needToReshape = (origShape.size () != 4 && origShape.size () != 5 );
1499
+ if (needToReshape)
1501
1500
{
1502
- // Expand or squash dims to 4D
1503
- nvinfer1::Dims new_shape = dims;
1504
- while (new_shape.nbDims < expectedNbDims)
1501
+ if (origShape.size () < 4 )
1505
1502
{
1506
- new_shape.d [new_shape.nbDims ++] = 1 ;
1503
+ std::vector<int > expandAxes (4 - origShape.size ());
1504
+ std::iota (expandAxes.begin (), expandAxes.end (), origShape.size ());
1505
+ tensorPtr = unsqueezeTensor (ctx, node, *tensorPtr, expandAxes);
1507
1506
}
1508
- while (new_shape. nbDims > expectedNbDims)
1507
+ else
1509
1508
{
1510
- new_shape.d [3 ] *= new_shape.d [--new_shape.nbDims ];
1509
+ // Collapse trailing dimensions if origShape.size() > 5
1510
+ const ShapeTensor collapsedDim = product (ctx, origShape, 3 , origShape.size (), 1 );
1511
+ const ShapeTensor collapsedShape = concat (ctx, gather (ctx, origShape, iotaShapeVector (3 )), collapsedDim);
1512
+ tensorPtr = &reshape (ctx, *tensorPtr, collapsedShape);
1511
1513
}
1512
- tensor_ptr = reshapeTensor (ctx, *tensor_ptr, new_shape);
1513
- ASSERT (tensor_ptr, ErrorCode::kUNSUPPORTED_NODE );
1514
- dims = tensor_ptr->getDimensions ();
1515
1514
}
1516
1515
1517
- ASSERT (dims.nbDims == 4 || dims.nbDims == 5 , ErrorCode::kUNSUPPORTED_NODE );
1518
-
1519
- // Fill in dtype for any unused (dummy) weights
1520
- nvinfer1::DataType* dtype_ptr = nullptr ;
1521
- if (shift.count )
1522
- {
1523
- dtype_ptr = &shift.type ;
1524
- }
1525
- if (scale.count )
1526
- {
1527
- ASSERT (!dtype_ptr || *dtype_ptr == scale.type , ErrorCode::kUNSUPPORTED_NODE );
1528
- dtype_ptr = &scale.type ;
1529
- }
1530
- if (power.count )
1531
- {
1532
- ASSERT (!dtype_ptr || *dtype_ptr == power.type , ErrorCode::kUNSUPPORTED_NODE );
1533
- dtype_ptr = &power.type ;
1534
- }
1535
- ASSERT (dtype_ptr, ErrorCode::kINTERNAL_ERROR );
1536
- shift.type = *dtype_ptr;
1537
- scale.type = *dtype_ptr;
1538
- power.type = *dtype_ptr;
1539
- auto * layer = ctx->network ()->addScaleNd (*tensor_ptr, mode, shift, scale, power, 1 );
1540
- ASSERT (layer, ErrorCode::kUNSUPPORTED_NODE );
1516
+ auto * layer = ctx->network ()->addScaleNd (*tensorPtr, mode, shift, scale, power, 1 );
1517
+ ASSERT (layer && " Failed to add a Scale layer." , ErrorCode::kUNSUPPORTED_NODE );
1541
1518
// Register layer name, and shift and scale weight names for the refit map.
1542
1519
ctx->registerLayer (layer, getNodeName (node));
1543
- ctx->insertRefitMap (shiftName, getNodeName (node), nvinfer1::WeightsRole::kSHIFT );
1544
- ctx->insertRefitMap (scaleName, getNodeName (node), nvinfer1::WeightsRole::kSCALE );
1545
- tensor_ptr = layer->getOutput (0 );
1546
1520
1547
- if (needToExpandDims)
1521
+ tensorPtr = layer->getOutput (0 );
1522
+
1523
+ if (needToReshape)
1548
1524
{
1549
- tensor_ptr = reshapeTensor (ctx, *tensor_ptr, orig_shape );
1550
- ASSERT (tensor_ptr , ErrorCode::kUNSUPPORTED_NODE );
1525
+ tensorPtr = & reshape (ctx, *tensorPtr, origShape );
1526
+ ASSERT (tensorPtr && " Failed to reshape tensor. " , ErrorCode::kUNSUPPORTED_NODE );
1551
1527
}
1552
1528
1553
- return {{tensor_ptr }};
1529
+ return {{tensorPtr }};
1554
1530
}
1555
1531
1556
1532
void setAttr (
0 commit comments