From 802c07acf83d138711a54f0b108c6e6936c3c60d Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Mon, 11 Jan 2021 11:52:35 -0800 Subject: [PATCH] Update APIs with model_path for ONNXRT Signed-off-by: Kevin Chen --- ModelImporter.cpp | 14 ++++++++++++-- ModelImporter.hpp | 4 ++-- NvOnnxParser.h | 11 +++++++---- OnnxAttrs.hpp | 2 +- onnx2trt_utils.cpp | 4 ++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/ModelImporter.cpp b/ModelImporter.cpp index eab7ce85..b241d213 100644 --- a/ModelImporter.cpp +++ b/ModelImporter.cpp @@ -293,7 +293,8 @@ Status deserialize_onnx_model(int fd, bool is_serialized_as_text, ::ONNX_NAMESPA } bool ModelImporter::supportsModel( - void const* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection) + void const* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection, + const char* model_path) { ::ONNX_NAMESPACE::ModelProto model; @@ -307,6 +308,11 @@ bool ModelImporter::supportsModel( return false; } + if (model_path) + { + _importer_ctx.setOnnxFileLocation(model_path); + } + bool allSupported{true}; // Parse the graph and see if we hit any parsing errors @@ -454,8 +460,12 @@ bool ModelImporter::parseWithWeightDescriptors(void const* serialized_onnx_model return true; } -bool ModelImporter::parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size) +bool ModelImporter::parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path) { + if (model_path) + { + _importer_ctx.setOnnxFileLocation(model_path); + } return this->parseWithWeightDescriptors(serialized_onnx_model, serialized_onnx_model_size, 0, nullptr); } diff --git a/ModelImporter.hpp b/ModelImporter.hpp index 699dc693..d961266b 100644 --- a/ModelImporter.hpp +++ b/ModelImporter.hpp @@ -56,9 +56,9 @@ class ModelImporter : public nvonnxparser::IParser } bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size, uint32_t weight_count, onnxTensorDescriptorV1 const* weight_descriptors) override; - bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override; + bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override; bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size, - SubGraphCollection_t& sub_graph_collection) override; + SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override; bool supportsOperator(const char* op_name) const override; void destroy() override diff --git a/NvOnnxParser.h b/NvOnnxParser.h index c7f618c7..48a768f6 100644 --- a/NvOnnxParser.h +++ b/NvOnnxParser.h @@ -132,13 +132,14 @@ class IParser * To obtain a better diagnostic, use the parseFromFile method below. * * \param serialized_onnx_model Pointer to the serialized ONNX model - * \param serialized_onnx_model_size Size of the serialized ONNX model - * in bytes + * \param serialized_onnx_model_size Size of the serialized ONNX model in bytes + * \param model_path Absolute path to the model file for loading external weights if required * \return true if the model was parsed successfully * \see getNbErrors() getError() */ virtual bool parse(void const* serialized_onnx_model, - size_t serialized_onnx_model_size) + size_t serialized_onnx_model_size, + const char* model_path = nullptr) = 0; /** \brief Parse an onnx model file, can be a binary protobuf or a text onnx model @@ -158,11 +159,13 @@ class IParser * \param serialized_onnx_model_size Size of the serialized ONNX model * in bytes * \param sub_graph_collection Container to hold supported subgraphs + * \param model_path Absolute path to the model file for loading external weights if required * \return true if the model is supported */ virtual bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size, - SubGraphCollection_t& sub_graph_collection) + SubGraphCollection_t& sub_graph_collection, + const char* model_path = nullptr) = 0; /** \brief Parse a serialized ONNX model into the TensorRT network diff --git a/OnnxAttrs.hpp b/OnnxAttrs.hpp index 69358ea6..ada3aa94 100644 --- a/OnnxAttrs.hpp +++ b/OnnxAttrs.hpp @@ -61,7 +61,7 @@ class OnnxAttrs return _attrs.at(key); } - const ::ONNX_NAMESPACE::AttributeProto::AttributeType type(const std::string& key) const + ::ONNX_NAMESPACE::AttributeProto::AttributeType type(const std::string& key) const { return this->at(key)->type(); } diff --git a/onnx2trt_utils.cpp b/onnx2trt_utils.cpp index 491ee95e..572061ea 100644 --- a/onnx2trt_utils.cpp +++ b/onnx2trt_utils.cpp @@ -1346,10 +1346,10 @@ bool parseExternalWeights(IImporterContext* ctx, std::string file, std::string p relPathFile.seekg(offset, std::ios::beg); int weightsBufSize = length == 0 ? fileSize : length; weightsBuf.resize(weightsBufSize); - LOG_VERBOSE("Reading weights from external file: " << file); + LOG_VERBOSE("Reading weights from external file: " << path); if (!relPathFile.read(weightsBuf.data(), weightsBuf.size())) { - LOG_ERROR("Failed to read weights from external file: " << file); + LOG_ERROR("Failed to read weights from external file: " << path); return false; } size = weightsBuf.size();