diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index cce7eded7840d..465305d9903d3 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -730,6 +730,8 @@ if (onnxruntime_USE_TENSORRT) "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.h" "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs}) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 600e255bcdf9f..e7d0f9f03ade9 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -44,4 +44,5 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_profile_min_shapes; // Specify the range of the input shapes to build the engine with const char* trt_profile_max_shapes; // Specify the range of the input shapes to build the engine with const char* trt_profile_opt_shapes; // Specify the range of the input shapes to build the engine with + int trt_cuda_graph_enable; // Enable CUDA graph in ORT TRT }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 8af8eb327f79c..5ba95b1440627 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -703,6 +703,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv profile_min_shapes = info.profile_min_shapes; profile_max_shapes = info.profile_max_shapes; profile_opt_shapes = info.profile_opt_shapes; + cuda_graph_enable_ = info.cuda_graph_enable; } else { try { const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); @@ -842,6 +843,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes); profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes); profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes); + + const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable); + if (!cuda_graph_enable_env.empty()) { + cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true); + } } catch (const std::invalid_argument& ex) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what(); } catch (const std::out_of_range& ex) { @@ -895,6 +901,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } + if (cuda_graph_enable_) { + cuda_graph_ = std::make_unique<CUDAGraph>(); + } + /* * Parse explicit min/max/opt profile shapes from provider options.
* @@ -968,7 +978,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_tactic_sources: " << tactic_sources_ << ", trt_profile_min_shapes: " << profile_min_shapes << ", trt_profile_max_shapes: " << profile_max_shapes - << ", trt_profile_opt_shapes: " << profile_opt_shapes; + << ", trt_profile_opt_shapes: " << profile_opt_shapes + << ", trt_cuda_graph_enable: " << cuda_graph_enable_; } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -982,6 +993,43 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); } +bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { + return cuda_graph_enable_; +} + +bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void TensorrtExecutionProvider::CaptureBegin() { + cuda_graph_->Reset(); + cuda_graph_->CaptureBegin(); +} + +void TensorrtExecutionProvider::CaptureEnd() { + cuda_graph_->CaptureEnd(); + is_graph_captured_ = true; +} + +bool TensorrtExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status TensorrtExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + // Please note that CUDAGraph::Replay() is not thread safe. + // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(), + // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. + return cuda_graph_->Replay(); +} + +void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), + // therefore following increment is guaranteed to be thread safe. + ++regular_run_count_before_graph_capture_; +} + std::vector<AllocatorPtr> TensorrtExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, device_id_); @@ -999,6 +1047,10 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(); } +Status TensorrtExecutionProvider::OnRunStart() { + return Status::OK(); +} + Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); @@ -2737,6 +2789,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsetDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
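For context on the capture/replay calls above: CaptureBegin(), CaptureEnd() and ReplayGraph() delegate to the shared CUDAGraph helper, which roughly corresponds to the CUDA runtime stream-capture sequence sketched below. This standalone program is illustrative only and not part of this change; it needs a concrete cuda stream before capture can start (which is why ORT TRT defers capture to compute time), uses cudaMemsetAsync as a stand-in for the TRT enqueue work, and assumes the five-argument cudaGraphInstantiate signature of CUDA 11.x.

#include <cuda_runtime.h>
#include <cstdio>

#define CHECK(call)                                                          \
  do {                                                                       \
    cudaError_t err = (call);                                                \
    if (err != cudaSuccess) {                                                \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));     \
      return 1;                                                              \
    }                                                                        \
  } while (0)

int main() {
  cudaStream_t stream;
  CHECK(cudaStreamCreate(&stream));

  void* buf = nullptr;
  CHECK(cudaMalloc(&buf, 1 << 20));

  // "CaptureBegin": start recording all work submitted to this stream.
  CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  CHECK(cudaMemsetAsync(buf, 0, 1 << 20, stream));  // stand-in for the TRT enqueue

  // "CaptureEnd": stop recording and instantiate an executable graph.
  cudaGraph_t graph;
  CHECK(cudaStreamEndCapture(stream, &graph));
  cudaGraphExec_t graph_exec;
  CHECK(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));

  // "Replay": relaunch the recorded work without re-issuing individual calls.
  CHECK(cudaGraphLaunch(graph_exec, stream));
  CHECK(cudaStreamSynchronize(stream));

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaFree(buf);
  cudaStreamDestroy(stream);
  return 0;
}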
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_->SetStream(stream); + CaptureBegin(); + } + // Run TRT inference if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); @@ -2764,6 +2825,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; + Status OnRunStart() override; Status OnRunEnd(bool sync_stream) override; ProviderOptions GetProviderOptions() const override { @@ -167,6 +171,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; + private: TensorrtExecutionProviderInfo info_; bool external_stream_ = false; @@ -204,6 +212,12 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool timing_cache_enable_ = false; bool force_timing_cache_match_ = false; bool detailed_build_log_ = false; + bool cuda_graph_enable_ = false; + + std::unique_ptr cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; std::unordered_map> parsers_; @@ -254,5 +268,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { /**Check whether all the nodes of subgraph are supported*/ bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; + + bool IsGraphCaptureAllowed() const; + void CaptureBegin(); + void CaptureEnd(); + void IncrementRegularRunCountBeforeGraphCapture(); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 217ded637dcfb..d47af40b4c7a2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -43,6 +43,7 @@ constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; +constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; } // namespace provider_option_names } // namespace tensorrt @@ -91,6 +92,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes) .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) .Parse(options)); // add new provider option here. 
return info; @@ -129,6 +131,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, + {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, }; return options; } @@ -175,6 +178,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kProfilesMinShapes, kProfilesMinShapes_}, {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, + {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, }; return options; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 7235bb6940f9c..4fb9837e1c040 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -49,6 +49,7 @@ struct TensorrtExecutionProviderInfo { std::string profile_min_shapes{""}; std::string profile_max_shapes{""}; std::string profile_opt_shapes{""}; + bool cuda_graph_enable{false}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 5a1b662078e90..ed5cca93f74d4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -94,6 +94,7 @@ struct Tensorrt_Provider : Provider { info.profile_min_shapes = options.trt_profile_min_shapes == nullptr ? "" : options.trt_profile_min_shapes; info.profile_max_shapes = options.trt_profile_max_shapes == nullptr ? "" : options.trt_profile_max_shapes; info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? 
"" : options.trt_profile_opt_shapes; + info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; common::Status status = CreateTensorRTCustomOpDomainList(info); if (!status.IsOK()) { @@ -229,6 +230,8 @@ struct Tensorrt_Provider : Provider { dest[str_size] = '\0'; trt_options.trt_profile_opt_shapes = (const char*)dest; } + + trt_options.trt_cuda_graph_enable = internal_options.cuda_graph_enable; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 18dc911c7236f..99308e0df6e3b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -177,6 +177,18 @@ std::pair AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { return std::make_pair(true, static_cast(shape_nodes.size())); } +bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { + for (const auto& node : graph.Nodes()) { + const auto& node_provider = node.GetExecutionProviderType(); + + if (node_provider.empty() || node_provider != provider) { + return false; + } + } + + return true; +} + bool HasControlflowNodes(const Graph& graph) { for (const auto& node : graph.Nodes()) { if (node.ContainsSubgraph()) { @@ -1554,51 +1566,82 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently only the CUDA EP is considered. + // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // + // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND // The CUDA EP is configured to do a graph capture AND - // All the graph nodes have been assigned to the CUDA EP, + // All the "compute" graph nodes have been assigned to the CUDA EP, // Then the CUDA EP is cached for triggering a ReplayGraph() in Run(). - auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider); - if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) { - if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as the model has control flow nodes which can't be supported by CUDA Graphs."; - - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as the model has control flow nodes which can't be supported by CUDA Graphs.")); - } + // + // Check for TRT EP: + // If the TRT EP is part of the providers list for this session AND + // The TRT EP is configured to do a graph capture AND + // All the graph nodes have been assigned to the TRT EP, + // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). 
+ std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + + for (auto& it : cuda_graph_support_ep_list) { + auto* target_ep = execution_providers_.Get(it); + + if (target_ep && target_ep->IsGraphCaptureEnabled()) { + if (HasControlflowNodes(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + } - auto res = AreAllComputeNodesAssignedToCudaEp(graph); + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { + auto res = AreAllComputeNodesAssignedToCudaEp(graph); - if (!res.first) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + if (!res.first) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the CUDA EP."; - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); - } + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as all compute graph nodes have not been partitioned to the CUDA EP.")); + } - if (res.second > 0) { - LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " - << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " - << "will match the shapes produced in the model for other inputs " - << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; - } + if (res.second > 0) { + LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " + << "Use the CUDA Graph feature with caution. " + << "As long as the intermediate shapes produced in the model " + << "using the representative input used to capture the CUDA graph, " + << "will match the shapes produced in the model for other inputs " + << "of the same shape as the representative input (common case), " + << "it is safe to use the CUDA Graph feature."; + } + } else { + // Following code path is for TRT EP currently. 
+ if (!AreAllNodesInMainGraphAssignedToOneEp(graph, target_ep->Type())) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << "as all the graph nodes have not been assigned to " + << target_ep->Type(); + + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + "as all the graph nodes have not been assigned to " + + target_ep->Type())); + } + } - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; - cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep); + LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); + break; // Make sure only one ep can run CUDA graph. + } } const bool disable_cpu_ep_fallback = session_options_.config_options.GetConfigOrDefault( diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index c79de8105c039..c8508b1de6e96 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1371,6 +1371,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_profile_min_shapes = ""; trt_options_converted.trt_profile_max_shapes = ""; trt_options_converted.trt_profile_opt_shapes = ""; + trt_options_converted.trt_cuda_graph_enable = 0; return trt_options_converted; } @@ -1727,6 +1728,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT options->trt_profile_min_shapes = nullptr; options->trt_profile_max_shapes = nullptr; options->trt_profile_opt_shapes = nullptr; + options->trt_cuda_graph_enable = false; *out = options.release(); return nullptr; #else diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 2c6b4c8db3ff6..69eac91501cd5 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -384,7 +384,7 @@ std::unique_ptr CreateExecutionProviderInstance( nullptr, nullptr, nullptr, - nullptr}; + 0}; for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { @@ -602,6 +602,14 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n"); } + } else if (option.first == "trt_cuda_graph_enable") { + if (option.second == "True" || option.second == "true") { + params.trt_cuda_graph_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.trt_cuda_graph_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. 
Default value is 'False'.\n"); + } } else { ORT_THROW("Invalid TensorRT EP option: ", option.first); } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 62129a1f4d444..ace59982633b7 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -133,6 +133,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string trt_profile_min_shapes = ""; std::string trt_profile_max_shapes = ""; std::string trt_profile_opt_shapes = ""; + bool trt_cuda_graph_enable = false; #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -362,8 +363,16 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a non-empty string.\n"); } + } else if (key == "trt_cuda_graph_enable") { + if (value == "true" || value == "True") { + trt_cuda_graph_enable = true; + } else if (value == "false" || value == "False") { + trt_cuda_graph_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be a boolean i.e. true or false. Default value is false.\n"); + } } else { - ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_timing_cache_enable', 'trt_force_timing_cache', 'trt_detailed_build_log', 'trt_build_heuristics_enable', 'trt_sparsity_enable', 'trt_builder_optimization_level', 'trt_auxiliary_streams', 'trt_tactic_sources', 'trt_extra_plugin_lib_paths', 'trt_profile_min_shapes', 'trt_profile_max_shapes', 'trt_profile_opt_shapes'] \n"); + ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. 
['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_timing_cache_enable', 'trt_force_timing_cache', 'trt_detailed_build_log', 'trt_build_heuristics_enable', 'trt_sparsity_enable', 'trt_builder_optimization_level', 'trt_auxiliary_streams', 'trt_tactic_sources', 'trt_extra_plugin_lib_paths', 'trt_profile_min_shapes', 'trt_profile_max_shapes', 'trt_profile_opt_shapes', 'trt_cuda_graph_enable'] \n"); } } OrtTensorRTProviderOptionsV2 tensorrt_options; @@ -399,6 +408,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device tensorrt_options.trt_profile_min_shapes = trt_profile_min_shapes.c_str(); tensorrt_options.trt_profile_max_shapes = trt_profile_max_shapes.c_str(); tensorrt_options.trt_profile_opt_shapes = trt_profile_opt_shapes.c_str(); + tensorrt_options.trt_cuda_graph_enable = trt_cuda_graph_enable; session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 80b48c56b0ff6..ff9137ec0535d 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -710,7 +710,7 @@ TEST_P(ModelTest, Run) { OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30, 1, // enable fp16 0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0, - 3, -1, nullptr, nullptr, nullptr, nullptr, nullptr}; + 3, -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0}; ortso.AppendExecutionProvider_TensorRT_V2(params); } else { diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 32bd8f6ec144b..1601eeff6438e 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -164,7 +164,8 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; params.trt_engine_cache_enable = 1; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); @@ -246,7 +247,8 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; params.trt_engine_cache_enable = 1; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); @@ -399,7 +401,8 @@ TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); @@ -493,7 +496,8 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; if (cache_type.compare("engine") == 0) { /* Following code block tests the functionality of engine and optimization profile of ORT TRT, including: diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py 
index 30e299863f4f4..a322ebe93ad44 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -61,7 +61,7 @@ def get_output(self, output_name: str): class TestInferenceSessionWithCudaGraph(unittest.TestCase): - def testOrtValueUpdateInPlace(self): # noqa: N802 + def test_ort_value_update_in_place(self): x0 = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) ortvalue_cpu = onnxrt.OrtValue.ortvalue_from_numpy(x0) np.testing.assert_allclose(x0, ortvalue_cpu.numpy()) @@ -77,45 +77,52 @@ def testOrtValueUpdateInPlace(self): # noqa: N802 ortvalue_gpu.update_inplace(x1) np.testing.assert_allclose(x1, ortvalue_gpu.numpy()) - def testRunModelWithCudaGraph(self): # noqa: N802 - if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + def test_select_ep_to_run_cuda_graph(self): + if "TensorrtExecutionProvider" in onnxrt.get_available_providers(): + providers = [("TensorrtExecutionProvider", {"trt_cuda_graph_enable": True})] + self.run_model_with_cuda_graph(providers) + elif "CUDAExecutionProvider" in onnxrt.get_available_providers(): providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] - INPUT_SIZE = 1280 # noqa: N806 - x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] * INPUT_SIZE, dtype=np.float32) - y = np.array([[0.0], [0.0], [0.0]] * INPUT_SIZE, dtype=np.float32) - x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0) - y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0) - - session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) - io_binding = session.io_binding() - - # Bind the input and output - io_binding.bind_ortvalue_input("X", x_ortvalue) - io_binding.bind_ortvalue_output("Y", y_ortvalue) - - # One regular run for the necessary memory allocation and cuda graph capturing - session.run_with_iobinding(io_binding) - expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32) - np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - - # After capturing, CUDA graph replay happens from this Run onwards - session.run_with_iobinding(io_binding) - np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - - # Update input and then replay CUDA graph - x_ortvalue.update_inplace( - np.array( - [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]] * INPUT_SIZE, - dtype=np.float32, - ) - ) - session.run_with_iobinding(io_binding) - np.testing.assert_allclose( - np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), - y_ortvalue.numpy(), - rtol=1e-05, - atol=1e-05, + self.run_model_with_cuda_graph(providers) + + def run_model_with_cuda_graph(self, providers): + INPUT_SIZE = 1280 # noqa: N806 + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] * INPUT_SIZE, dtype=np.float32) + y = np.array([[0.0], [0.0], [0.0]] * INPUT_SIZE, dtype=np.float32) + x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0) + y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0) + + onnxrt.set_default_logger_severity(0) + session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) + io_binding = session.io_binding() + + # Bind the input and output + io_binding.bind_ortvalue_input("X", x_ortvalue) + io_binding.bind_ortvalue_output("Y", y_ortvalue) + + # One regular run for the necessary memory allocation and cuda graph capturing + session.run_with_iobinding(io_binding) + expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32) + 
np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) + + # After capturing, CUDA graph replay happens from this Run onwards + session.run_with_iobinding(io_binding) + np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) + + # Update input and then replay CUDA graph + x_ortvalue.update_inplace( + np.array( + [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]] * INPUT_SIZE, + dtype=np.float32, ) + ) + session.run_with_iobinding(io_binding) + np.testing.assert_allclose( + np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), + y_ortvalue.numpy(), + rtol=1e-05, + atol=1e-05, + ) def testArenaWithCudaGraph(self): # noqa: N802 if "CUDAExecutionProvider" in onnxrt.get_available_providers(): diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index bb438fd0edb55..ce55625c4b3eb 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1744,10 +1744,25 @@ TEST(CApiTest, io_binding_cuda) { } #endif -#if defined(USE_CUDA) +#if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, basic_cuda_graph) { const auto& api = Ort::GetApi(); + Ort::SessionOptions session_options; +#if defined(USE_TENSORRT) + // Enable cuda graph in TRT provider option. + OrtTensorRTProviderOptionsV2* trt_options; + ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); + std::unique_ptr + rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + std::vector keys{"trt_cuda_graph_enable"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options), + rel_trt_options.get()) == nullptr); +#else // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -1757,10 +1772,10 @@ TEST(CApiTest, basic_cuda_graph) { std::vector values{"1"}; ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); - Ort::SessionOptions session_options; ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( static_cast(session_options), rel_cuda_options.get()) == nullptr); +#endif Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
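For application code, the new option can be enabled through the public C/C++ API in the same way the basic_cuda_graph test above does. The following is a minimal sketch, not a definitive implementation: "model.onnx" is a placeholder path (an ORTCHAR_T/wide string on Windows), error handling is reduced to early returns, and the leaked OrtStatus objects are ignored for brevity. Inputs and outputs should be bound to fixed GPU addresses (for example via Ort::IoBinding) so that early runs can allocate and capture and later runs can replay the captured graph.

#include <onnxruntime_cxx_api.h>
#include <memory>

int main() {
  const OrtApi& api = Ort::GetApi();
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_cuda_graph_demo");
  Ort::SessionOptions session_options;

  // Create TRT provider options and turn on CUDA graph capture.
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  if (api.CreateTensorRTProviderOptions(&trt_options) != nullptr) return 1;
  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
      rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);

  const char* keys[] = {"trt_cuda_graph_enable"};
  const char* values[] = {"1"};
  if (api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys, values, 1) != nullptr) return 1;

  if (api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
          static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get()) != nullptr) return 1;

  Ort::Session session(env, "model.onnx", session_options);  // placeholder model path
  // Bind inputs/outputs to fixed device buffers and call session.Run(...) repeatedly:
  // the first run(s) execute normally and capture the CUDA graph, later runs replay it.
  return 0;
}

Equivalently, trt_cuda_graph_enable can be set directly on an OrtTensorRTProviderOptionsV2 struct (as in the perf-test changes above), or passed as a Python provider option, e.g. ("TensorrtExecutionProvider", {"trt_cuda_graph_enable": True}), as in the updated onnxruntime_test_python_cudagraph.py test.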