|
16 | 16 | #include "core/graph/function_utils.h"
|
17 | 17 | #include "core/graph/graph_viewer.h"
|
18 | 18 | #include "core/graph/model.h"
|
| 19 | +#include "core/session/onnxruntime_session_options_config_keys.h" |
19 | 20 |
|
20 | 21 | // uncomment this line to count non-CUDA ops in ONNX domain
|
21 | 22 | // #define COUNT_NON_CUDA_OPS
|
@@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
|
634 | 635 | return Status::OK();
|
635 | 636 | }
|
636 | 637 |
|
| 638 | +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, |
| 639 | + const Graph& graph, |
| 640 | + const std::string& ep_context_path, |
| 641 | + const logging::Logger& logger) { |
| 642 | + InlinedVector<const Node*> all_ep_context_nodes; |
| 643 | + for (const auto& ep : execution_providers) { |
| 644 | + const InlinedVector<const Node*> ep_context_nodes = ep->GetEpContextNodes(); |
| 645 | + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); |
| 646 | + } |
| 647 | + |
| 648 | + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> { |
| 649 | + for (auto& node : all_ep_context_nodes) { |
| 650 | + if (node_name == node->Name()) { |
| 651 | + return std::make_pair(true, node); |
| 652 | + } |
| 653 | + } |
| 654 | + return std::make_pair(false, static_cast<const Node*>(nullptr)); |
| 655 | + }; |
| 656 | + |
| 657 | + onnxruntime::PathString context_cache_path; |
| 658 | + PathString model_pathstring = graph.ModelPath().ToPathString(); |
| 659 | + if (all_ep_context_nodes.size() > 0) { |
| 660 | + if (!ep_context_path.empty()) { |
| 661 | + context_cache_path = ToPathString(ep_context_path); |
| 662 | + } else if (!model_pathstring.empty()) { |
| 663 | + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); |
| 664 | + } |
| 665 | + |
| 666 | + { |
| 667 | +#ifdef _WIN32 |
| 668 | + std::wifstream fs(context_cache_path); |
| 669 | +#else |
| 670 | + std::ifstream fs(context_cache_path); |
| 671 | +#endif |
| 672 | + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exists already."); |
| 673 | + } |
| 674 | + |
| 675 | + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), |
| 676 | + graph.DomainToVersionMap(), {}, logger); |
| 677 | + auto& ep_graph = ep_context_model.MainGraph(); |
| 678 | + ep_graph.SetDescription(graph.Description()); |
| 679 | + |
| 680 | + // Set inputs and outputs explicitly to make sure the order is the same as in the user model. |
| 681 | + auto inputs = graph.GetInputs(); |
| 682 | + auto outputs = graph.GetOutputs(); |
| 683 | + |
| 684 | + InlinedVector<const NodeArg*> ep_graph_inputs; |
| 685 | + ep_graph_inputs.reserve(inputs.size()); |
| 686 | + for (auto& input : inputs) { |
| 687 | + auto input_arg = graph.GetNodeArg(input->Name()); |
| 688 | + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); |
| 689 | + ep_graph_inputs.push_back(&ep_graph_input_arg); |
| 690 | + } |
| 691 | + |
| 692 | + InlinedVector<const NodeArg*> ep_graph_outputs; |
| 693 | + ep_graph_outputs.reserve(outputs.size()); |
| 694 | + for (auto& output : outputs) { |
| 695 | + auto output_arg = graph.GetNodeArg(output->Name()); |
| 696 | + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); |
| 697 | + ep_graph_outputs.push_back(&ep_graph_output_arg); |
| 698 | + } |
| 699 | + |
| 700 | + ep_graph.SetInputs(ep_graph_inputs); |
| 701 | + ep_graph.SetOutputs(ep_graph_outputs); |
| 702 | + |
| 703 | + for (const auto& node : graph.Nodes()) { |
| 704 | + // The fused node and the EPContext node have the same name. |
| 705 | + auto ep_context_node = get_ep_context_node(node.Name()); |
| 706 | + // Use the EPContext node created by the EP if the name matches; otherwise use the node from the original model. |
| 707 | + if (ep_context_node.first) { |
| 708 | + ep_graph.AddNode(*ep_context_node.second); |
| 709 | + } else { |
| 710 | + ep_graph.AddNode(node); |
| 711 | + } |
| 712 | + } |
| 713 | + |
| 714 | + // handle initializers |
| 715 | + for (const auto& input : graph.GetInputsIncludingInitializers()) { |
| 716 | + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; |
| 717 | + if (graph.GetInitializedTensor(input->Name(), initializer)) { |
| 718 | + // The initializer could have duplicates, so make sure we only add it once. |
| 719 | + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; |
| 720 | + if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { |
| 721 | + ep_graph.AddInitializedTensor(*initializer); |
| 722 | + } |
| 723 | + } |
| 724 | + } |
| 725 | + |
| 726 | + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); |
| 727 | + } |
| 728 | + |
| 729 | + return Status::OK(); |
| 730 | +} |
| 731 | + |
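When EP context generation is enabled but kOrtSessionOptionEpContextFilePath is not set, the helper above writes the context model next to the source model as "<source model path>_ctx.onnx" and aborts if that file already exists. Below is a minimal standalone sketch of just that path handling, using a hypothetical DeriveEpContextPath helper and std::filesystem rather than ONNX Runtime's internal PathString utilities.

#include <filesystem>
#include <stdexcept>
#include <string>

// Hypothetical helper mirroring CreateEpContextModel's path handling above:
// prefer the user-supplied path, otherwise append "_ctx.onnx" to the source
// model path, and refuse to overwrite an existing file.
std::string DeriveEpContextPath(const std::string& model_path,
                                const std::string& user_path) {
  const std::string out = user_path.empty() ? model_path + "_ctx.onnx" : user_path;
  if (std::filesystem::exists(out)) {
    throw std::runtime_error("EP context model already exists: " + out);
  }
  return out;
}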
637 | 732 | static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
|
638 | 733 | const ExecutionProviders& execution_providers,
|
639 | 734 | KernelRegistryManager& kernel_registry_manager) {
|
@@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
|
840 | 935 |
|
841 | 936 | Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
842 | 937 | const layout_transformation::TransformLayoutFunction& transform_layout_function,
|
| 938 | + const ConfigOptions& config_options, |
| 939 | + const logging::Logger& logger, |
843 | 940 | Mode mode,
|
844 | 941 | const layout_transformation::DebugGraphFn& debug_graph_fn) const {
|
845 | 942 | // Currently this is a greedy partitioning algorithm based on the provider preferences the user provided when calling ONNX Runtime.
|
@@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
|
886 | 983 | #if !defined(ORT_MINIMAL_BUILD)
|
887 | 984 | ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
|
888 | 985 | providers_, kernel_registry_mgr_));
|
| 986 | + |
| 987 | + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; |
| 988 | + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); |
| 989 | + if (ep_context_enabled) { |
| 990 | + ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger)); |
| 991 | + } |
889 | 992 | #else
|
| 993 | + ORT_UNUSED_PARAMETER(config_options); |
| 994 | + ORT_UNUSED_PARAMETER(logger); |
890 | 995 | return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
|
891 | 996 | #endif //! defined(ORT_MINIMAL_BUILD)
|
892 | 997 | } else {
|
|
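For reference, here is a minimal sketch of how an application could exercise this path through the public C++ API, assuming the config key constants from onnxruntime_session_options_config_keys.h (the header included at the top of this file) and an execution provider that implements GetEpContextNodes; the model path, output path, and the commented-out provider registration are placeholders.

#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "epctx_demo");
  Ort::SessionOptions so;
  // Enable EP context model generation and (optionally) choose the output path;
  // if no path is given, it defaults to "<source model path>_ctx.onnx".
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "model_ctx.onnx");
  // Register an execution provider that reports EPContext nodes, e.g. (placeholder options):
  // so.AppendExecutionProvider("QNN", {{"backend_path", "libQnnHtp.so"}});
  // Creating the session runs GraphPartitioner::Partition, which now writes the
  // EP context model after partitioning succeeds.
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  return 0;
}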