Skip to content

Commit 887f644

Browse files
committed
Copy weights file to epctx output directory
1 parent a8527b9 commit 887f644

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

Diff for: onnxruntime/core/providers/openvino/backend_manager.cc

+8-5
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,20 @@ BackendManager::BackendManager(SessionContext& session_context,
8585

8686
auto& sw = shared_context_.shared_weights;
8787
if (session_context_.so_share_ep_contexts) {
88-
std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path();
8988
if (sw.external_weight_filename.empty() && !sw.metadata.empty()) {
9089
// Reasonable assumption that all metadata entries have the same external file location
9190
sw.external_weight_filename = sw.metadata.begin()->second.location;
9291
}
93-
weight_filename /= sw.external_weight_filename;
94-
std::ifstream weight_file(weight_filename);
9592

93+
auto weight_path = session_context_.GetNewWeightsFilePath(sw.external_weight_filename);
94+
if (!std::filesystem::exists(weight_path)) {
95+
weight_path = session_context_.GetModelDirectory() / sw.external_weight_filename;
96+
}
97+
98+
std::ifstream weight_file(weight_path);
9699
if (weight_file) {
97100
if (!sw.mapped_weights) {
98-
sw.mapped_weights = std::make_unique<SharedContext::SharedWeights::WeightsFile>(weight_filename);
101+
sw.mapped_weights = std::make_unique<SharedContext::SharedWeights::WeightsFile>(weight_path);
99102
}
100103
backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights);
101104
}
@@ -324,7 +327,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
324327
static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
325328
[[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
326329
[[maybe_unused]] const onnxruntime::Node& fused_node) {
327-
#ifdef NOT_RELEASE
330+
#ifdef NOT_RELEASE
328331
if (openvino_ep::backend_utils::IsDebugEnabled()) {
329332
auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename();
330333

Diff for: onnxruntime/core/providers/openvino/contexts.h

+14
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,20 @@ struct SessionContext : ProviderInfo {
118118
mutable bool has_external_weights = false; // Value is set to mutable to modify from capability
119119
const std::vector<uint32_t> OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR};
120120
const std::string openvino_sdk_version = std::to_string(OPENVINO_VERSION_MAJOR) + "." + std::to_string(OPENVINO_VERSION_MINOR);
121+
122+
fs::path GetModelDirectory() const {
123+
return onnx_model_path_name.parent_path();
124+
}
125+
126+
fs::path GetEpContextOutputDirectory() const {
127+
return so_context_file_path.empty() ? GetModelDirectory() : so_context_file_path;
128+
}
129+
130+
fs::path GetNewWeightsFilePath(fs::path external_weights_filename) const {
131+
ORT_ENFORCE(!external_weights_filename.empty(), "External weights filename should not be empty.");
132+
// Otherwise, use the provided external weights filename.
133+
return GetEpContextOutputDirectory() / fs::path(external_weights_filename.filename().string() + "_weights.bin");
134+
}
121135
};
122136

123137
// Holds context specific to subgraph.

Diff for: onnxruntime/core/providers/openvino/openvino_execution_provider.cc

+9-13
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ common::Status OpenVINOExecutionProvider::Compile(
102102
graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain);
103103
}
104104

105+
const auto metadata_path = session_context_.GetEpContextOutputDirectory() / "metadata.bin";
106+
105107
// Temporary code to read metadata before it moves to the .bin
106108
auto& metadata = shared_context_->shared_weights.metadata;
107109
if (session_context_.so_share_ep_contexts && metadata.empty()) {
108110
// Metadata is always read from model location, this could be a source or epctx model
109-
fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
110-
std::ifstream file(metadata_filename, std::ios::binary);
111+
std::ifstream file(metadata_path, std::ios::binary);
111112
if (file) {
112113
file >> metadata;
113114
}
@@ -174,18 +175,13 @@ common::Status OpenVINOExecutionProvider::Compile(
174175
}
175176

176177
if (session_context_.so_share_ep_contexts) {
177-
fs::path metadata_filename;
178-
if (session_context_.so_context_file_path.empty()) {
179-
metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
180-
} else {
181-
metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin";
182-
}
178+
const auto& sw_path_filename = shared_context_->shared_weights.external_weight_filename;
179+
fs::path new_weights_file_path = session_context_.GetNewWeightsFilePath(sw_path_filename);
180+
fs::path original_weights_path = session_context_.GetModelDirectory() / sw_path_filename;
181+
182+
std::filesystem::copy_file(original_weights_path, new_weights_file_path, std::filesystem::copy_options::skip_existing);
183183

184-
// Metadata is generated only for shared contexts
185-
// If saving metadata then save it to the provided path or ose the original model path
186-
// Multiple calls to Compile() will update the metadata and for the last call
187-
// the resulting file will contain the aggregated content
188-
std::ofstream file(metadata_filename, std::ios::binary);
184+
std::ofstream file(metadata_path, std::ios::binary);
189185
if (file) {
190186
file << metadata;
191187
}

0 commit comments

Comments
 (0)