Skip to content

Commit 96e7811

Browse files
authored
ONNX-TensorRT 10.1 GA release (#975)
Signed-off-by: Akhil Goel <[email protected]>
1 parent 06adf44 commit 96e7811

21 files changed

+933
-398
lines changed

.gitmodules

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[submodule "third_party/onnx"]
22
path = third_party/onnx
33
url = https://github.com/onnx/onnx.git
4-
branch = rel-1.16.0
4+
branch = v1.16.0

CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ add_definitions("-DSOURCE_LENGTH=${SOURCE_LENGTH}")
2828
# Version information
2929
#--------------------------------------------------
3030
set(ONNX2TRT_MAJOR 10)
31-
set(ONNX2TRT_MINOR 0)
32-
set(ONNX2TRT_PATCH 1)
31+
set(ONNX2TRT_MINOR 1)
32+
set(ONNX2TRT_PATCH 0)
3333
set(ONNX2TRT_VERSION "${ONNX2TRT_MAJOR}.${ONNX2TRT_MINOR}.${ONNX2TRT_PATCH}" CACHE STRING "ONNX2TRT version")
3434

3535
#--------------------------------------------------

ModelImporter.cpp

+297-179
Large diffs are not rendered by default.

ModelImporter.hpp

+63-22
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
#include "ImporterContext.hpp"
88
#include "NvInferPlugin.h"
99
#include "NvOnnxParser.h"
10+
#include "errorHelpers.hpp"
1011
#include "onnxOpCheckers.hpp"
1112
#include "onnxOpImporters.hpp"
13+
#include <stdexcept>
1214

1315
namespace onnx2trt
1416
{
@@ -24,32 +26,49 @@ Status parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& grap
2426

2527
class ModelImporter : public nvonnxparser::IParser
2628
{
29+
using SubGraphSupport_t = std::pair<std::vector<int64_t>, bool>;
30+
using SubGraphSupportVector_t = std::vector<SubGraphSupport_t>;
31+
2732
protected:
2833
StringMap<NodeImporter> _op_importers;
29-
virtual Status importModel(::ONNX_NAMESPACE::ModelProto const& model);
34+
virtual Status importModel(::ONNX_NAMESPACE::ModelProto const& model) noexcept;
3035

3136
private:
3237
ImporterContext mImporterCtx;
3338
std::vector<std::string> mPluginLibraryList; // Array of strings containing plugin libs
3439
std::vector<char const*>
3540
mPluginLibraryListCStr; // Array of C-strings corresponding to the strings in mPluginLibraryList
3641
std::list<::ONNX_NAMESPACE::ModelProto> mONNXModels; // Needed for ownership of weights
42+
SubGraphSupportVector_t mSubGraphSupportVector;
3743
int mCurrentNode;
38-
std::vector<Status> mErrors;
39-
nvonnxparser::OnnxParserFlags mOnnxParserFlags{1U << static_cast<uint32_t>(nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM)}; // kNATIVE_INSTANCENORM is ON by default.
44+
mutable std::vector<Status> mErrors; // Marked as mutable so that errors could be reported from const functions
45+
nvonnxparser::OnnxParserFlags mOnnxParserFlags{
46+
1U << static_cast<uint32_t>(
47+
nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM)}; // kNATIVE_INSTANCENORM is ON by default.
48+
std::pair<bool, SubGraphSupportVector_t> doSupportsModel(
49+
void const* serialized_onnx_model, size_t serialized_onnx_model_size, char const* model_path = nullptr);
4050

4151
public:
42-
ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
52+
ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger) noexcept
4353
: _op_importers(getBuiltinOpImporterMap())
4454
, mImporterCtx(network, logger)
4555
{
4656
}
47-
bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override;
48-
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override;
57+
bool parseWithWeightDescriptors(
58+
void const* serialized_onnx_model, size_t serialized_onnx_model_size) noexcept override;
59+
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
60+
const char* model_path = nullptr) noexcept override;
61+
4962
bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
50-
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override;
63+
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) noexcept override;
64+
bool supportsModelV2(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
65+
char const* model_path = nullptr) noexcept override;
66+
67+
int64_t getNbSubgraphs() noexcept override;
68+
bool isSubgraphSupported(int64_t const index) noexcept override;
69+
int64_t* getSubgraphNodes(int64_t const index, int64_t& subgraphLength) noexcept override;
5170

52-
bool supportsOperator(const char* op_name) const override;
71+
bool supportsOperator(const char* op_name) const noexcept override;
5372

5473
void setFlags(nvonnxparser::OnnxParserFlags onnxParserFlags) noexcept override
5574
{
@@ -62,44 +81,66 @@ class ModelImporter : public nvonnxparser::IParser
6281

6382
void clearFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) noexcept override
6483
{
65-
mOnnxParserFlags &= ~(1U << static_cast<uint32_t>(onnxParserFlag));
84+
ONNXTRT_TRY
85+
{
86+
mOnnxParserFlags &= ~(1U << static_cast<uint32_t>(onnxParserFlag));
87+
}
88+
ONNXTRT_CATCH_RECORD
6689
}
6790

6891
void setFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) noexcept override
6992
{
70-
mOnnxParserFlags |= 1U << static_cast<uint32_t>(onnxParserFlag);
93+
ONNXTRT_TRY
94+
{
95+
mOnnxParserFlags |= 1U << static_cast<uint32_t>(onnxParserFlag);
96+
}
97+
ONNXTRT_CATCH_RECORD
7198
}
7299

73100
bool getFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) const noexcept override
74101
{
75-
auto flag = 1U << static_cast<uint32_t>(onnxParserFlag);
76-
return static_cast<bool>(mOnnxParserFlags & flag);
102+
ONNXTRT_TRY
103+
{
104+
auto flag = 1U << static_cast<uint32_t>(onnxParserFlag);
105+
return static_cast<bool>(mOnnxParserFlags & flag);
106+
}
107+
ONNXTRT_CATCH_RECORD
108+
return false;
77109
}
78110

79-
int32_t getNbErrors() const override
111+
int32_t getNbErrors() const noexcept override
80112
{
81113
return mErrors.size();
82114
}
83-
nvonnxparser::IParserError const* getError(int32_t index) const override
115+
nvonnxparser::IParserError const* getError(int32_t index) const noexcept override
84116
{
85-
assert(0 <= index && index < (int32_t) mErrors.size());
86-
return &mErrors[index];
117+
ONNXTRT_TRY
118+
{
119+
return &mErrors.at(index);
120+
}
121+
ONNXTRT_CATCH_RECORD
122+
return nullptr;
87123
}
88-
void clearErrors() override
124+
void clearErrors() noexcept override
89125
{
90126
mErrors.clear();
91127
}
92128

93-
nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i)
129+
nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i) noexcept override
94130
{
95-
if (!name)
131+
ONNXTRT_TRY
96132
{
97-
return nullptr;
133+
if (!name)
134+
{
135+
throw std::invalid_argument("name is a nullptr");
136+
}
137+
return mImporterCtx.findLayerOutputTensor(name, i);
98138
}
99-
return mImporterCtx.findLayerOutputTensor(name, i);
139+
ONNXTRT_CATCH_RECORD
140+
return nullptr;
100141
}
101142

102-
bool parseFromFile(char const* onnxModelFile, int32_t verbosity) override;
143+
bool parseFromFile(char const* onnxModelFile, int32_t verbosity) noexcept override;
103144

104145
virtual char const* const* getUsedVCPluginLibraries(int64_t& nbPluginLibs) const noexcept override;
105146
};

ModelRefitter.cpp

+42-32
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Status deserializeOnnxModelFile(char const* onnxModelFile, ::ONNX_NAMESPACE::Mod
2424
{
2525
// Define S_ISREG macro for Windows
2626
#if !defined(S_ISREG)
27-
#define S_ISREG(mode) (((mode) &S_IFMT) == S_IFREG)
27+
#define S_ISREG(mode) (((mode) & S_IFMT) == S_IFREG)
2828
#endif
2929

3030
struct stat sb;
@@ -393,52 +393,62 @@ Status ModelRefitter::refitOnnxScanNode(::ONNX_NAMESPACE::NodeProto const& node)
393393
bool ModelRefitter::refitFromBytes(
394394
void const* serializedOnnxModel, size_t serializedOnnxModelSize, char const* modelPath) noexcept
395395
{
396-
if (modelPath)
396+
ONNXTRT_TRY
397397
{
398-
// Keep track of the absolute path to the ONNX file.
399-
mWeightsContext.setOnnxFileLocation(modelPath);
400-
}
398+
if (modelPath)
399+
{
400+
// Keep track of the absolute path to the ONNX file.
401+
mWeightsContext.setOnnxFileLocation(modelPath);
402+
}
401403

402-
Status status
403-
= deserializeOnnxModel(serializedOnnxModel, serializedOnnxModelSize, &onnx_model);
404-
if (status.is_error())
405-
{
406-
mErrors.push_back(status);
407-
return false;
408-
}
404+
Status status = deserializeOnnxModel(serializedOnnxModel, serializedOnnxModelSize, &onnx_model);
405+
if (status.is_error())
406+
{
407+
mErrors.push_back(status);
408+
return false;
409+
}
409410

410-
refittableWeights = getRefittableWeights();
411-
status = refitOnnxWeights(onnx_model);
412-
if (status.is_error())
413-
{
414-
mErrors.push_back(status);
415-
return false;
411+
refittableWeights = getRefittableWeights();
412+
status = refitOnnxWeights(onnx_model);
413+
if (status.is_error())
414+
{
415+
mErrors.push_back(status);
416+
return false;
417+
}
418+
return true;
416419
}
417-
return true;
420+
ONNXTRT_CATCH_LOG(mLogger)
421+
return false;
418422
}
419423

420424
bool ModelRefitter::refitFromFile(char const* onnxModelFile) noexcept
421425
{
422-
// Keep track of the absolute path to the ONNX file.
423-
mWeightsContext.setOnnxFileLocation(onnxModelFile);
424-
425-
Status status = deserializeOnnxModelFile(onnxModelFile, onnx_model);
426-
if (status.is_error())
426+
ONNXTRT_TRY
427427
{
428-
mErrors.push_back(status);
429-
return false;
430-
}
428+
// Keep track of the absolute path to the ONNX file.
429+
mWeightsContext.setOnnxFileLocation(onnxModelFile);
431430

432-
refittableWeights = getRefittableWeights();
433-
if (!refittableWeights.empty())
434-
{
435-
status = refitOnnxWeights(onnx_model);
431+
Status status = deserializeOnnxModelFile(onnxModelFile, onnx_model);
436432
if (status.is_error())
437433
{
438434
mErrors.push_back(status);
439435
return false;
440436
}
437+
438+
refittableWeights = getRefittableWeights();
439+
if (!refittableWeights.empty())
440+
{
441+
status = refitOnnxWeights(onnx_model);
442+
if (status.is_error())
443+
{
444+
mErrors.push_back(status);
445+
return false;
446+
}
447+
}
448+
return true;
441449
}
442-
return true;
450+
ONNXTRT_CATCH_LOG(mLogger)
451+
452+
return false;
443453
}
444454
} // namespace onnx2trt

ModelRefitter.hpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "NvInferRuntime.h"
88
#include "Status.hpp"
99
#include "WeightsContext.hpp"
10+
#include "errorHelpers.hpp"
1011
#include <onnx/onnx_pb.h>
1112
#include <string>
1213
#include <unordered_set>
@@ -51,7 +52,7 @@ class ModelRefitter : public nvonnxparser::IParserRefitter
5152
std::unordered_set<std::string> refittableWeights;
5253
std::unordered_set<std::string> refittedWeights;
5354

54-
std::vector<Status> mErrors;
55+
mutable std::vector<Status> mErrors;
5556

5657
std::unordered_set<std::string> getRefittableWeights();
5758

@@ -90,8 +91,12 @@ class ModelRefitter : public nvonnxparser::IParserRefitter
9091

9192
nvonnxparser::IParserError const* getError(int32_t index) const noexcept override
9293
{
93-
assert(0 <= index && index < (int32_t) mErrors.size());
94-
return &mErrors[index];
94+
ONNXTRT_TRY
95+
{
96+
return &mErrors.at(index);
97+
}
98+
ONNXTRT_CATCH_LOG(mLogger)
99+
return nullptr;
95100
}
96101

97102
void clearErrors() noexcept override

NvOnnxParser.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,21 @@
77
#include "ModelRefitter.hpp"
88
#include "NvInferRuntime.h"
99

10-
extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version)
10+
extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version) noexcept
1111
{
1212
auto network = static_cast<nvinfer1::INetworkDefinition*>(network_);
1313
auto logger = static_cast<nvinfer1::ILogger*>(logger_);
1414
return new onnx2trt::ModelImporter(network, logger);
1515
}
1616

17-
extern "C" void* createNvOnnxParserRefitter_INTERNAL(void* refitter_, void* logger_, int32_t version)
17+
extern "C" void* createNvOnnxParserRefitter_INTERNAL(void* refitter_, void* logger_, int32_t version) noexcept
1818
{
1919
auto refitter = static_cast<nvinfer1::IRefitter*>(refitter_);
2020
auto logger = static_cast<nvinfer1::ILogger*>(logger_);
2121
return new onnx2trt::ModelRefitter(refitter, logger);
2222
}
2323

24-
extern "C" int getNvOnnxParserVersion()
24+
extern "C" int getNvOnnxParserVersion() noexcept
2525
{
2626
return NV_ONNX_PARSER_VERSION;
2727
}

0 commit comments

Comments
 (0)