Skip to content

Commit 5374283

Browse files
pranavm-nvidia and kevinch-nv
authored and committed
Fixes batchnorm importer in non-4D/5D cases (#569)
Signed-off-by: pranavm <[email protected]>
Signed-off-by: Kevin Chen <[email protected]>
1 parent 96064cb commit 5374283

File tree

3 files changed

+42
-90
lines changed

3 files changed

+42
-90
lines changed

builtin_op_importers.cpp

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ NodeImportResult batchnormFallback(
186186
nvinfer1::ITensor* mean = &convertToTensor(inputs.at(3), ctx);
187187
nvinfer1::ITensor* variance = &convertToTensor(inputs.at(4), ctx);
188188

189-
const bool hasCDimension = rank > 1;
190-
if (hasCDimension)
189+
// Reshape batchnorm weights from [C] to [N, C, ...]
190+
const bool needsExpandDims = rank > 1;
191+
if (needsExpandDims)
191192
{
192193
std::vector<int> axes(rank - 1);
193194
axes[0] = 0;
@@ -223,7 +224,7 @@ NodeImportResult batchnormFallback(
223224
->getOutput(0),
224225
*bias, eOp::kSUM);
225226

226-
ctx->registerLayer(layer, node.name());
227+
ctx->registerLayer(layer, getNodeName(node));
227228

228229
RETURN_FIRST_OUTPUT(layer);
229230
}
@@ -254,25 +255,13 @@ DEFINE_BUILTIN_OP_IMPORTER(BatchNormalization)
254255
OnnxAttrs attrs(node, ctx);
255256
float eps = attrs.get<float>("epsilon", 1e-5f);
256257

257-
nvinfer1::Dims dims = tensorPtr->getDimensions();
258-
259-
bool needToExpandDims = (dims.nbDims == 3);
260-
if (needToExpandDims)
261-
{
262-
// Expand spatial dims from 1D to 2D
263-
std::vector<int> axes{3};
264-
tensorPtr = unsqueezeTensor(ctx, node, *tensorPtr, axes);
265-
ASSERT(tensorPtr, ErrorCode::kUNSUPPORTED_NODE);
266-
dims = tensorPtr->getDimensions();
267-
}
268-
269258
// Number of channels is equal to the length of scale_weights.
270259
int nchan = scale_weights.shape.d[0];
271260
nvinfer1::Dims weights_shape{1, {nchan}};
272-
ASSERT(scale_weights.shape == weights_shape, ErrorCode::kINVALID_NODE);
273-
ASSERT(bias_weights.shape == weights_shape, ErrorCode::kINVALID_NODE);
274-
ASSERT(mean_weights.shape == weights_shape, ErrorCode::kINVALID_NODE);
275-
ASSERT(variance_weights.shape == weights_shape, ErrorCode::kINVALID_NODE);
261+
ASSERT((scale_weights.shape == weights_shape) && "The shape of input scale must be (C)", ErrorCode::kINVALID_NODE);
262+
ASSERT((bias_weights.shape == weights_shape) && "The shape of input bias must be (C)", ErrorCode::kINVALID_NODE);
263+
ASSERT((mean_weights.shape == weights_shape) && "The shape of input mean must be (C)", ErrorCode::kINVALID_NODE);
264+
ASSERT((variance_weights.shape == weights_shape) && "The shape of input var must be (C)", ErrorCode::kINVALID_NODE);
276265
auto combined_scale_weights = ctx->createTempWeights(scale_weights.type, scale_weights.shape);
277266
auto combined_bias_weights = ctx->createTempWeights(bias_weights.type, bias_weights.shape);
278267
size_t nweight = nchan;
@@ -289,23 +278,9 @@ DEFINE_BUILTIN_OP_IMPORTER(BatchNormalization)
289278
combined_bias_ref = bias - mean * combined_scale_ref;
290279
}
291280

292-
// If dimensions were not expanded return the output of the scale operation
293-
if (!needToExpandDims)
294-
{
295-
return scaleHelper(
296-
ctx, node, *tensorPtr, nvinfer1::ScaleMode::kCHANNEL, combined_bias_weights, combined_scale_weights, {}, bias_weights.getName(), scale_weights.getName());
297-
}
298-
else
299-
{
300-
auto scaledResult = scaleHelper(
301-
ctx, node, *tensorPtr, nvinfer1::ScaleMode::kCHANNEL, combined_bias_weights, combined_scale_weights, {}, bias_weights.getName(), scale_weights.getName());
302-
// Squeeze spatial dims back to 1D
303-
tensorPtr = &convertToTensor(scaledResult.value().at(0), ctx);
304-
std::vector<int> axes{3};
305-
tensorPtr = squeezeTensor(ctx, node, *tensorPtr, axes);
306-
ASSERT(tensorPtr, ErrorCode::kUNSUPPORTED_NODE);
307-
return {{tensorPtr}};
308-
}
281+
return scaleHelper(ctx, node, *tensorPtr, nvinfer1::ScaleMode::kCHANNEL, combined_bias_weights,
282+
combined_scale_weights, ShapedWeights::empty(scale_weights.type), bias_weights.getName(),
283+
scale_weights.getName());
309284
}
310285

311286
DEFINE_BUILTIN_OP_IMPORTER(Cast)

onnx2trt_utils.cpp

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -347,13 +347,13 @@ bool convertOnnxWeights(
347347
{
348348
continue;
349349
}
350-
else
350+
else
351351
{
352352
LOG_ERROR("Key value of: " << keyName << " was not expected!");
353353
return false;
354354
}
355355
}
356-
356+
357357
// Buffer to hold the data read from the file
358358
std::vector<char> dataBuf;
359359
// Will update dataBuf and nbytes by reference.
@@ -1315,7 +1315,7 @@ nvinfer1::Dims insertDimension(const nvinfer1::Dims& dims, const int axis, const
13151315
bool parseExternalWeights(IImporterContext* ctx, std::string file, std::string path, int offset, int length,
13161316
std::vector<char>& weightsBuf, size_t& size)
13171317
{
1318-
// The weight paths in the ONNX model are relative paths to the main ONNX file.
1318+
// The weight paths in the ONNX model are relative paths to the main ONNX file.
13191319
#ifdef _MSC_VER
13201320
size_t slash = path.rfind("\\");
13211321
#else
@@ -1486,71 +1486,47 @@ nvinfer1::ITensor* reshapeTensor(IImporterContext* ctx, nvinfer1::ITensor& tenso
14861486
return layer->getOutput(0);
14871487
}
14881488

1489-
NodeImportResult scaleHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_, nvinfer1::ScaleMode mode,
1490-
nvinfer1::Weights shift, nvinfer1::Weights scale, nvinfer1::Weights power, std::string shiftName, std::string scaleName)
1489+
NodeImportResult scaleHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_,
1490+
nvinfer1::ScaleMode mode, const nvinfer1::Weights& shift, const nvinfer1::Weights& scale,
1491+
const nvinfer1::Weights& power, const char* shiftName, const char* scaleName)
14911492
{
1492-
nvinfer1::ITensor* tensor_ptr = &tensor_;
1493-
nvinfer1::Dims dims = tensor_ptr->getDimensions();
1493+
nvinfer1::ITensor* tensorPtr = &tensor_;
1494+
const ShapeTensor origShape = shapeOf(*tensorPtr);
14941495

14951496
// TensorRT scale layers support 4D(NCHW) or 5D(NCDHW) input.
1496-
// For input other than 4D or 5D will be expanded to 4D.
1497-
int expectedNbDims = 4;
1498-
bool needToExpandDims = (dims.nbDims != 4 && dims.nbDims != 5);
1499-
nvinfer1::Dims orig_shape = dims;
1500-
if (needToExpandDims)
1497+
// For input other than 4D or 5D will be expanded or squeezed to 4D.
1498+
bool needToReshape = (origShape.size() != 4 && origShape.size() != 5);
1499+
if (needToReshape)
15011500
{
1502-
// Expand or squash dims to 4D
1503-
nvinfer1::Dims new_shape = dims;
1504-
while (new_shape.nbDims < expectedNbDims)
1501+
if (origShape.size() < 4)
15051502
{
1506-
new_shape.d[new_shape.nbDims++] = 1;
1503+
std::vector<int> expandAxes(4 - origShape.size());
1504+
std::iota(expandAxes.begin(), expandAxes.end(), origShape.size());
1505+
tensorPtr = unsqueezeTensor(ctx, node, *tensorPtr, expandAxes);
15071506
}
1508-
while (new_shape.nbDims > expectedNbDims)
1507+
else
15091508
{
1510-
new_shape.d[3] *= new_shape.d[--new_shape.nbDims];
1509+
// Collapse trailing dimensions if origShape.size() > 5
1510+
const ShapeTensor collapsedDim = product(ctx, origShape, 3, origShape.size(), 1);
1511+
const ShapeTensor collapsedShape = concat(ctx, gather(ctx, origShape, iotaShapeVector(3)), collapsedDim);
1512+
tensorPtr = &reshape(ctx, *tensorPtr, collapsedShape);
15111513
}
1512-
tensor_ptr = reshapeTensor(ctx, *tensor_ptr, new_shape);
1513-
ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE);
1514-
dims = tensor_ptr->getDimensions();
15151514
}
15161515

1517-
ASSERT(dims.nbDims == 4 || dims.nbDims == 5, ErrorCode::kUNSUPPORTED_NODE);
1518-
1519-
// Fill in dtype for any unused (dummy) weights
1520-
nvinfer1::DataType* dtype_ptr = nullptr;
1521-
if (shift.count)
1522-
{
1523-
dtype_ptr = &shift.type;
1524-
}
1525-
if (scale.count)
1526-
{
1527-
ASSERT(!dtype_ptr || *dtype_ptr == scale.type, ErrorCode::kUNSUPPORTED_NODE);
1528-
dtype_ptr = &scale.type;
1529-
}
1530-
if (power.count)
1531-
{
1532-
ASSERT(!dtype_ptr || *dtype_ptr == power.type, ErrorCode::kUNSUPPORTED_NODE);
1533-
dtype_ptr = &power.type;
1534-
}
1535-
ASSERT(dtype_ptr, ErrorCode::kINTERNAL_ERROR);
1536-
shift.type = *dtype_ptr;
1537-
scale.type = *dtype_ptr;
1538-
power.type = *dtype_ptr;
1539-
auto* layer = ctx->network()->addScaleNd(*tensor_ptr, mode, shift, scale, power, 1);
1540-
ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
1516+
auto* layer = ctx->network()->addScaleNd(*tensorPtr, mode, shift, scale, power, 1);
1517+
ASSERT(layer && "Failed to add a Scale layer.", ErrorCode::kUNSUPPORTED_NODE);
15411518
// Register layer name, and shift and scale weight names for the refit map.
15421519
ctx->registerLayer(layer, getNodeName(node));
1543-
ctx->insertRefitMap(shiftName, getNodeName(node), nvinfer1::WeightsRole::kSHIFT);
1544-
ctx->insertRefitMap(scaleName, getNodeName(node), nvinfer1::WeightsRole::kSCALE);
1545-
tensor_ptr = layer->getOutput(0);
15461520

1547-
if (needToExpandDims)
1521+
tensorPtr = layer->getOutput(0);
1522+
1523+
if (needToReshape)
15481524
{
1549-
tensor_ptr = reshapeTensor(ctx, *tensor_ptr, orig_shape);
1550-
ASSERT(tensor_ptr, ErrorCode::kUNSUPPORTED_NODE);
1525+
tensorPtr = &reshape(ctx, *tensorPtr, origShape);
1526+
ASSERT(tensorPtr && "Failed to reshape tensor.", ErrorCode::kUNSUPPORTED_NODE);
15511527
}
15521528

1553-
return {{tensor_ptr}};
1529+
return {{tensorPtr}};
15541530
}
15551531

15561532
void setAttr(

onnx2trt_utils.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,9 @@ NodeImportResult reduceTensor(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto
286286
nvinfer1::ITensor* reshapeTensor(IImporterContext* ctx, nvinfer1::ITensor& tensor, nvinfer1::Dims shape);
287287

288288
// Helper function to map attributes to a TRT scale layer
289-
NodeImportResult scaleHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_, nvinfer1::ScaleMode mode,
290-
nvinfer1::Weights shift, nvinfer1::Weights scale, nvinfer1::Weights power, std::string shiftName, std::string scaleName);
289+
NodeImportResult scaleHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor_,
290+
nvinfer1::ScaleMode mode, const nvinfer1::Weights& shift, const nvinfer1::Weights& scale,
291+
const nvinfer1::Weights& power, const char* shiftName, const char* scaleName);
291292

292293
// Helper function to set an ONNX attribute
293294
void setAttr(

0 commit comments

Comments (0)