Skip to content

Commit 293a03f

Browse files
committed
Update APIs with model_path for ONNXRT (onnx#621)
Signed-off-by: Kevin Chen <[email protected]>
1 parent 452c9d9 commit 293a03f

File tree

5 files changed

+24
-11
lines changed

5 files changed

+24
-11
lines changed

Diff for: ModelImporter.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ Status deserialize_onnx_model(int fd, bool is_serialized_as_text, ::ONNX_NAMESPA
293293
}
294294

295295
bool ModelImporter::supportsModel(
296-
void const* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection)
296+
void const* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection,
297+
const char* model_path)
297298
{
298299

299300
::ONNX_NAMESPACE::ModelProto model;
@@ -307,6 +308,11 @@ bool ModelImporter::supportsModel(
307308
return false;
308309
}
309310

311+
if (model_path)
312+
{
313+
_importer_ctx.setOnnxFileLocation(model_path);
314+
}
315+
310316
bool allSupported{true};
311317

312318
// Parse the graph and see if we hit any parsing errors
@@ -454,8 +460,12 @@ bool ModelImporter::parseWithWeightDescriptors(void const* serialized_onnx_model
454460
return true;
455461
}
456462

457-
bool ModelImporter::parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size)
463+
bool ModelImporter::parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path)
458464
{
465+
if (model_path)
466+
{
467+
_importer_ctx.setOnnxFileLocation(model_path);
468+
}
459469
return this->parseWithWeightDescriptors(serialized_onnx_model, serialized_onnx_model_size, 0, nullptr);
460470
}
461471

Diff for: ModelImporter.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ class ModelImporter : public nvonnxparser::IParser
5656
}
5757
bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
5858
uint32_t weight_count, onnxTensorDescriptorV1 const* weight_descriptors) override;
59-
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override;
59+
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override;
6060
bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
61-
SubGraphCollection_t& sub_graph_collection) override;
61+
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override;
6262

6363
bool supportsOperator(const char* op_name) const override;
6464
void destroy() override

Diff for: NvOnnxParser.h

+7-4
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,14 @@ class IParser
132132
* To obtain a better diagnostic, use the parseFromFile method below.
133133
*
134134
* \param serialized_onnx_model Pointer to the serialized ONNX model
135-
* \param serialized_onnx_model_size Size of the serialized ONNX model
136-
* in bytes
135+
* \param serialized_onnx_model_size Size of the serialized ONNX model in bytes
136+
* \param model_path Absolute path to the model file for loading external weights if required
137137
* \return true if the model was parsed successfully
138138
* \see getNbErrors() getError()
139139
*/
140140
virtual bool parse(void const* serialized_onnx_model,
141-
size_t serialized_onnx_model_size)
141+
size_t serialized_onnx_model_size,
142+
const char* model_path = nullptr)
142143
= 0;
143144

144145
/** \brief Parse an onnx model file, can be a binary protobuf or a text onnx model
@@ -158,11 +159,13 @@ class IParser
158159
* \param serialized_onnx_model_size Size of the serialized ONNX model
159160
* in bytes
160161
* \param sub_graph_collection Container to hold supported subgraphs
162+
* \param model_path Absolute path to the model file for loading external weights if required
161163
* \return true if the model is supported
162164
*/
163165
virtual bool supportsModel(void const* serialized_onnx_model,
164166
size_t serialized_onnx_model_size,
165-
SubGraphCollection_t& sub_graph_collection)
167+
SubGraphCollection_t& sub_graph_collection,
168+
const char* model_path = nullptr)
166169
= 0;
167170

168171
/** \brief Parse a serialized ONNX model into the TensorRT network

Diff for: OnnxAttrs.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class OnnxAttrs
6161
return _attrs.at(key);
6262
}
6363

64-
const ::ONNX_NAMESPACE::AttributeProto::AttributeType type(const std::string& key) const
64+
::ONNX_NAMESPACE::AttributeProto::AttributeType type(const std::string& key) const
6565
{
6666
return this->at(key)->type();
6767
}

Diff for: onnx2trt_utils.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1346,10 +1346,10 @@ bool parseExternalWeights(IImporterContext* ctx, std::string file, std::string p
13461346
relPathFile.seekg(offset, std::ios::beg);
13471347
int weightsBufSize = length == 0 ? fileSize : length;
13481348
weightsBuf.resize(weightsBufSize);
1349-
LOG_VERBOSE("Reading weights from external file: " << file);
1349+
LOG_VERBOSE("Reading weights from external file: " << path);
13501350
if (!relPathFile.read(weightsBuf.data(), weightsBuf.size()))
13511351
{
1352-
LOG_ERROR("Failed to read weights from external file: " << file);
1352+
LOG_ERROR("Failed to read weights from external file: " << path);
13531353
return false;
13541354
}
13551355
size = weightsBuf.size();

0 commit comments

Comments
 (0)