Skip to content

Commit 6e17571

Browse files
authored
Fix issue that the generated context cache model inputs/outputs order is not guaranteed (#19195)
Fix issue that the generated context cache model inputs/outputs order is not guaranteed ### Description Currently, QNN EP generates the context cache model in the Compile() method, which only has access to the partitioned graph, and the inputs/outputs order for the partitioned graph is not guaranteed. The EP also doesn't have a view of the input user model. We have to move the context cache model generation to a higher level, in GraphPartitioner, which has a view of the partitioned model. This is also a break-down PR for multi-partition support. #18865
1 parent a3ecb63 commit 6e17571

13 files changed

+210
-26
lines changed

include/onnxruntime/core/framework/execution_provider.h

+9
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,15 @@ class IExecutionProvider {
326326
*/
327327
virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); };
328328

329+
/**
330+
* Get the array of pointers for EPContext nodes
331+
* EP needs to implement this if it has the requirement to generate the context cache model. Otherwise leave it.
332+
* Returns an empty vector by default if not provided by the Execution Provider
333+
*/
334+
virtual const InlinedVector<const Node*> GetEpContextNodes() const {
335+
return InlinedVector<const Node*>();
336+
}
337+
329338
private:
330339
const std::string type_;
331340

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil
236236
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
237237
"session.optimized_model_external_initializers_min_size_in_bytes";
238238

239-
// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file.
239+
// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
240240
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
241241
// "0": disable. (default)
242242
// "1": enable.

onnxruntime/core/framework/graph_partitioner.cc

+105
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "core/graph/function_utils.h"
1717
#include "core/graph/graph_viewer.h"
1818
#include "core/graph/model.h"
19+
#include "core/session/onnxruntime_session_options_config_keys.h"
1920

2021
// uncomment this line to count non-CUDA ops in ONNX domain
2122
// #define COUNT_NON_CUDA_OPS
@@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
634635
return Status::OK();
635636
}
636637

638+
static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
639+
const Graph& graph,
640+
const std::string& ep_context_path,
641+
const logging::Logger& logger) {
642+
InlinedVector<const Node*> all_ep_context_nodes;
643+
for (const auto& ep : execution_providers) {
644+
const InlinedVector<const Node*> ep_context_nodes = ep->GetEpContextNodes();
645+
all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end());
646+
}
647+
648+
auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
649+
for (auto& node : all_ep_context_nodes) {
650+
if (node_name == node->Name()) {
651+
return std::make_pair(true, node);
652+
}
653+
}
654+
return std::make_pair(false, static_cast<const Node*>(nullptr));
655+
};
656+
657+
onnxruntime::PathString context_cache_path;
658+
PathString model_pathstring = graph.ModelPath().ToPathString();
659+
if (all_ep_context_nodes.size() > 0) {
660+
if (!ep_context_path.empty()) {
661+
context_cache_path = ToPathString(ep_context_path);
662+
} else if (!model_pathstring.empty()) {
663+
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
664+
}
665+
666+
{
667+
#ifdef _WIN32
668+
std::wifstream fs(context_cache_path);
669+
#else
670+
std::ifstream fs(context_cache_path);
671+
#endif
672+
ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
673+
}
674+
675+
Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
676+
graph.DomainToVersionMap(), {}, logger);
677+
auto& ep_graph = ep_context_model.MainGraph();
678+
ep_graph.SetDescription(graph.Description());
679+
680+
// Set inputs outputs explicitly to make sure the order is same as the user model.
681+
auto inputs = graph.GetInputs();
682+
auto outputs = graph.GetOutputs();
683+
684+
InlinedVector<const NodeArg*> ep_graph_inputs;
685+
ep_graph_inputs.reserve(inputs.size());
686+
for (auto& input : inputs) {
687+
auto input_arg = graph.GetNodeArg(input->Name());
688+
auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
689+
ep_graph_inputs.push_back(&ep_graph_input_arg);
690+
}
691+
692+
InlinedVector<const NodeArg*> ep_graph_outputs;
693+
ep_graph_outputs.reserve(outputs.size());
694+
for (auto& output : outputs) {
695+
auto output_arg = graph.GetNodeArg(output->Name());
696+
auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
697+
ep_graph_outputs.push_back(&ep_graph_output_arg);
698+
}
699+
700+
ep_graph.SetInputs(ep_graph_inputs);
701+
ep_graph.SetOutputs(ep_graph_outputs);
702+
703+
for (const auto& node : graph.Nodes()) {
704+
// the fused node and the EPContext node have the same node name
705+
auto ep_context_node = get_ep_context_node(node.Name());
706+
// Use EpContext node created by the EPs if name matched, otherwise use node from original model
707+
if (ep_context_node.first) {
708+
ep_graph.AddNode(*ep_context_node.second);
709+
} else {
710+
ep_graph.AddNode(node);
711+
}
712+
}
713+
714+
// handle initializers
715+
for (const auto& input : graph.GetInputsIncludingInitializers()) {
716+
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
717+
if (graph.GetInitializedTensor(input->Name(), initializer)) {
718+
// The initializer could have duplicates so make sure we only add it once
719+
const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
720+
if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) {
721+
ep_graph.AddInitializedTensor(*initializer);
722+
}
723+
}
724+
}
725+
726+
ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));
727+
}
728+
729+
return Status::OK();
730+
}
731+
637732
static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
638733
const ExecutionProviders& execution_providers,
639734
KernelRegistryManager& kernel_registry_manager) {
@@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
840935

841936
Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
842937
const layout_transformation::TransformLayoutFunction& transform_layout_function,
938+
const ConfigOptions& config_options,
939+
const logging::Logger& logger,
843940
Mode mode,
844941
const layout_transformation::DebugGraphFn& debug_graph_fn) const {
845942
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
@@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
886983
#if !defined(ORT_MINIMAL_BUILD)
887984
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
888985
providers_, kernel_registry_mgr_));
986+
987+
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
988+
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
989+
if (ep_context_enabled) {
990+
ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger));
991+
}
889992
#else
993+
ORT_UNUSED_PARAMETER(config_options);
994+
ORT_UNUSED_PARAMETER(logger);
890995
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
891996
#endif //! defined(ORT_MINIMAL_BUILD)
892997
} else {

onnxruntime/core/framework/graph_partitioner.h

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ namespace onnxruntime {
1313
class ExecutionProviders;
1414
class KernelRegistryManager;
1515
class Model;
16+
struct ConfigOptions;
1617

1718
class GraphPartitioner {
1819
public:
@@ -31,6 +32,8 @@ class GraphPartitioner {
3132
// Run partitioning.
3233
Status Partition(Graph& graph, FuncManager& func_mgr,
3334
const layout_transformation::TransformLayoutFunction& transform_layout_function,
35+
const ConfigOptions& config_options,
36+
const logging::Logger& logger,
3437
Mode mode = Mode::kNormal,
3538
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;
3639

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

+3-10
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,7 @@ Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path
230230
return Status::OK();
231231
}
232232

233-
Status GenerateCtxCacheOnnxModel(const std::string model_name,
234-
const std::string model_description,
233+
Status GenerateCtxCacheOnnxModel(Model* model,
235234
unsigned char* buffer,
236235
uint64_t buffer_size,
237236
const std::string& sdk_build_version,
@@ -240,11 +239,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
240239
const onnxruntime::PathString& context_cache_path,
241240
bool qnn_context_embed_mode,
242241
const logging::Logger& logger) {
243-
std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
244-
Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
245-
domain_to_version, {}, logger);
246-
auto& graph = model.MainGraph();
247-
graph.SetDescription(model_description);
242+
auto& graph = model->MainGraph();
248243

249244
using namespace ONNX_NAMESPACE;
250245
int index = 0;
@@ -270,7 +265,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
270265
nullptr,
271266
kMSDomain);
272267

273-
// Only dump the context buffer once since all QNN graph are in one single context
268+
// Only dump the context buffer once since all QNN graphs are in one single context
274269
if (0 == index) {
275270
if (qnn_context_embed_mode) {
276271
std::string cache_payload(buffer, buffer + buffer_size);
@@ -296,8 +291,6 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
296291
ep_node.AddAttribute(SOURCE, kQnnExecutionProvider);
297292
++index;
298293
}
299-
ORT_RETURN_IF_ERROR(graph.Resolve());
300-
ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path));
301294

302295
return Status::OK();
303296
}

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_mod
7373
std::string& cache_source,
7474
const logging::Logger& logger);
7575

76-
Status GenerateCtxCacheOnnxModel(const std::string model_name,
77-
const std::string model_description,
76+
Status GenerateCtxCacheOnnxModel(Model* model,
7877
unsigned char* buffer,
7978
uint64_t buffer_size,
8079
const std::string& sdk_build_version,

onnxruntime/core/providers/qnn/qnn_execution_provider.cc

+14-2
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
613613
ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
614614
uint64_t buffer_size(0);
615615
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
616-
ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name,
617-
model_description,
616+
qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
617+
ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(),
618618
context_buffer.get(),
619619
buffer_size,
620620
qnn_backend_manager_->GetSdkVersion(),
@@ -626,4 +626,16 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
626626
}
627627
return Status::OK();
628628
}
629+
630+
const InlinedVector<const Node*> QNNExecutionProvider::GetEpContextNodes() const {
631+
InlinedVector<const Node*> ep_context_nodes;
632+
if (qnn_ep_context_model_) {
633+
const auto& graph = qnn_ep_context_model_->MainGraph();
634+
for (const auto& node : graph.Nodes()) {
635+
ep_context_nodes.push_back(graph.GetNode(node.Index()));
636+
}
637+
}
638+
639+
return ep_context_nodes;
640+
}
629641
} // namespace onnxruntime

onnxruntime/core/providers/qnn/qnn_execution_provider.h

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "core/providers/qnn/builder/qnn_backend_manager.h"
1010
#include "core/providers/qnn/builder/qnn_model.h"
1111
#include "core/providers/qnn/builder/qnn_graph_configs_helper.h"
12+
#include "core/graph/model.h"
1213

1314
namespace onnxruntime {
1415

@@ -35,6 +36,8 @@ class QNNExecutionProvider : public IExecutionProvider {
3536

3637
DataLayout GetPreferredLayout() const override;
3738

39+
const InlinedVector<const Node*> GetEpContextNodes() const override;
40+
3841
private:
3942
bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
4043
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
@@ -66,6 +69,7 @@ class QNNExecutionProvider : public IExecutionProvider {
6669
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
6770
bool qnn_context_embed_mode_ = true;
6871
int32_t vtcm_size_in_mb_ = 0;
72+
std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
6973
};
7074

7175
} // namespace onnxruntime

onnxruntime/core/session/inference_session.cc

+7-2
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
11641164

11651165
// Do partitioning based on execution providers' capabilities.
11661166
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn,
1167+
session_options_.config_options, *session_logger_,
11671168
mode, debug_graph_fn));
11681169

11691170
// apply Level2 and higher transformers.
@@ -1458,7 +1459,9 @@ namespace {
14581459
Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
14591460
const ExecutionProviders& providers,
14601461
KernelRegistryManager& kernel_registry_manager,
1461-
SessionState& session_state) {
1462+
SessionState& session_state,
1463+
const ConfigOptions& config_options,
1464+
const logging::Logger& logger) {
14621465
layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr;
14631466

14641467
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1479,6 +1482,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
14791482
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
14801483
session_state.GetMutableFuncMgr(),
14811484
transform_layout_fn,
1485+
config_options,
1486+
logger,
14821487
GraphPartitioner::Mode::kOrtFormatLoad));
14831488

14841489
return Status::OK();
@@ -1833,7 +1838,7 @@ common::Status InferenceSession::Initialize() {
18331838
#endif // !defined(ORT_MINIMAL_BUILD)
18341839
} else {
18351840
ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_,
1836-
*session_state_));
1841+
*session_state_, session_options_.config_options, *session_logger_));
18371842

18381843
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
18391844
const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);

onnxruntime/test/framework/session_state_test.cc

+16-9
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,16 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
171171

172172
GraphPartitioner partitioner(krm, execution_providers);
173173
ASSERT_STATUS_OK(
174-
partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
175-
[](Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
176-
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
177-
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
178-
return layout_transformation::TransformLayoutForEP(
179-
graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn);
180-
}));
174+
partitioner.Partition(
175+
graph, session_state.GetMutableFuncMgr(),
176+
[](Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
177+
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
178+
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
179+
return layout_transformation::TransformLayoutForEP(
180+
graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn);
181+
},
182+
sess_options.config_options,
183+
DefaultLoggingManager().DefaultLogger()));
181184

182185
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
183186

@@ -257,7 +260,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
257260
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
258261
return layout_transformation::TransformLayoutForEP(graph, modified, execution_provider,
259262
cpu_allocator, debug_graph_fn);
260-
}));
263+
},
264+
sess_options.config_options,
265+
DefaultLoggingManager().DefaultLogger()));
261266

262267
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
263268

@@ -314,7 +319,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
314319
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
315320
return layout_transformation::TransformLayoutForEP(
316321
graph, modified, execution_provider, cpu_allocator, debug_graph_fn);
317-
}));
322+
},
323+
sess_options.config_options,
324+
DefaultLoggingManager().DefaultLogger()));
318325

319326
// Finalize the session state
320327
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));

0 commit comments

Comments
 (0)