
Commit f3402de

[TensorRT EP] Enhance EP context configs in session options and provider options (#19154)
Several changes:

1. To align with other EPs' handling of EP context configs in session options, for example [QNN EP](#18877), the EP context configs for TRT EP can be configured through:
   1. Session options: `ep.context_enable`, `ep.context_file_path` and `ep.context_embed_mode`
   2. Provider options: `trt_dump_ep_context_model`, `trt_ep_context_file_path` and `trt_dump_ep_context_embed_mode`
   3. The settings above have a 1:1 mapping, and provider options take priority over session options.

   ```
   Please note that there are rules for using the following context-model-related provider options:

   1. In the case of dumping the context model and loading the context model, for security reasons, TRT EP doesn't
      allow the "ep_cache_context" node attribute of the EP context node to be an absolute path, or a relative path
      that points outside of the context model directory. This means the engine cache needs to be in the same
      directory as the context model or in a sub-directory of it.

   2. In the case of dumping the context model, the engine cache path will be changed to a path relative to the
      context model directory. For example, if "trt_dump_ep_context_model" and "trt_engine_cache_enable" are both
      enabled and "trt_ep_context_file_path" is "./context_model_dir":
      - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
      - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
   ```

2. The user can decide the name of the dumped "EP context" model via `trt_ep_context_file_path`; see GetCtxModelPath() for more details.

3. Added suggested comments from #18217
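As a rough illustration of the two configuration paths described above (not part of this commit, using placeholder model names and paths), the option keys can be set through the standard ONNX Runtime C/C++ API roughly as follows; the key strings come from the description, while the surrounding calls assume the usual `OrtTensorRTProviderOptionsV2` workflow:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;

  // Option A: generic EP context keys in session options.
  so.AddConfigEntry("ep.context_enable", "1");
  so.AddConfigEntry("ep.context_file_path", "./context_model_dir/model_ctx.onnx");
  so.AddConfigEntry("ep.context_embed_mode", "0");

  // Option B: TRT provider options (1:1 mapping; these take priority over the session options above).
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));
  const char* keys[] = {"trt_engine_cache_enable", "trt_dump_ep_context_model",
                        "trt_ep_context_file_path", "trt_dump_ep_context_embed_mode"};
  const char* values[] = {"1", "1", "./context_model_dir", "0"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 4));
  so.AppendExecutionProvider_TensorRT_V2(*trt_options);

  // Building this session dumps the context model (and the engine cache) under "./context_model_dir".
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);

  api.ReleaseTensorRTProviderOptions(trt_options);
  return 0;
}
```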
1 parent c8ce839 commit f3402de

12 files changed: +624 −162 lines

include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h

+24-4
@@ -11,6 +11,8 @@
 /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
 /// </summary>
 struct OrtTensorRTProviderOptionsV2 {
+  OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other);  // copy assignment operator
+
   int device_id{0};                                      // cuda device id.
   int has_user_compute_stream{0};                        // indicator of user specified CUDA compute stream.
   void* user_compute_stream{nullptr};                    // user specified CUDA compute stream.
@@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 {
   const char* trt_profile_max_shapes{nullptr};           // Specify the range of the input shapes to build the engine with
   const char* trt_profile_opt_shapes{nullptr};           // Specify the range of the input shapes to build the engine with
   int trt_cuda_graph_enable{0};                          // Enable CUDA graph in ORT TRT
-  int trt_dump_ep_context_model{0};                      // Dump EP context node model
-  int trt_ep_context_embed_mode{0};                      // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
-  int trt_ep_context_compute_capability_enable{1};       // Add GPU compute capability as an EP context node's attribute
-  const char* trt_engine_cache_prefix{nullptr};          // specify engine cache prefix
+
+  /*
+   * Please note that there are rules for using following context model related provider options:
+   *
+   * 1. In the case of dumping the context model and loading the context model,
+   *    for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be
+   *    the absolute path or relative path that is outside of context model directory.
+   *    It means engine cache needs to be in the same directory or sub-directory of context model.
+   *
+   * 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory.
+   *    For example:
+   *    If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled,
+   *       if "trt_ep_context_file_path" is "./context_model_dir",
+   *       - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
+   *       - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
+   *
+   */
+  int trt_dump_ep_context_model{0};                      // Dump EP context node model
+  const char* trt_ep_context_file_path{nullptr};         // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
+  int trt_ep_context_embed_mode{0};                      // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
+
+  const char* trt_engine_cache_prefix{nullptr};          // specify engine cache prefix
 };
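To make the cache-placement rule in the comment above concrete, here is a small standalone sketch; `ResolveEngineCacheDir` is a hypothetical helper for illustration only (it is not part of the EP, and it approximates the file-vs-directory check), mirroring the documented behavior when both `trt_dump_ep_context_model` and `trt_engine_cache_enable` are on:

```cpp
#include <filesystem>
#include <iostream>
#include <string>

// Hypothetical helper mirroring the documented rule: the engine cache goes into the directory
// of the dumped context model, or into a sub-directory of it named by trt_engine_cache_path.
std::filesystem::path ResolveEngineCacheDir(const std::string& ep_context_file_path,
                                            const std::string& engine_cache_path) {
  namespace fs = std::filesystem;
  // Approximate the "file vs. directory" decision by checking for an extension;
  // the EP itself checks the filesystem instead.
  fs::path ctx(ep_context_file_path);
  fs::path ctx_dir = ep_context_file_path.empty() ? fs::path(".")
                                                  : (ctx.has_extension() ? ctx.parent_path() : ctx);
  return engine_cache_path.empty() ? ctx_dir : ctx_dir / engine_cache_path;
}

int main() {
  // Matches the two examples in the comment above.
  std::cout << ResolveEngineCacheDir("./context_model_dir", "") << "\n";            // "./context_model_dir"
  std::cout << ResolveEngineCacheDir("./context_model_dir", "engine_dir") << "\n";  // "./context_model_dir/engine_dir"
  return 0;
}
```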

onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc

+155-56
@@ -38,13 +38,6 @@ const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) {
   return main_graph.ModelPath();
 }
 
-std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path) {
-  std::filesystem::path base_path(path.ToPathString());
-  std::filesystem::path parent_path = base_path.parent_path();
-  std::filesystem::path engine_path = parent_path.append(engine_cache_path);
-  return engine_path;
-}
-
 /*
  * Update ep_cache_context attribute of the EP context node with the given engine binary data
  */
@@ -69,14 +62,13 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
 /*
  * Create "EP context node" model where engine information is embedded
  */
-ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
-                                               const std::string engine_cache_path,
-                                               char* engine_data,
-                                               size_t size,
-                                               const int64_t embed_mode,
-                                               bool compute_capability_enable,
-                                               std::string compute_capability,
-                                               const logging::Logger* logger) {
+ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
+                                           const std::string engine_cache_path,
+                                           char* engine_data,
+                                           size_t size,
+                                           const int64_t embed_mode,
+                                           std::string compute_capability,
+                                           const logging::Logger* logger) {
   auto model_build = graph_viewer.CreateModel(*logger);
   auto& graph_build = model_build->MainGraph();
 
@@ -107,21 +99,20 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
       engine_data_str.assign(engine_data, size);
     }
     attr_1->set_s(engine_data_str);
+    LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
   } else {
     attr_1->set_s(engine_cache_path);
   }
+  attr_2->set_name(COMPUTE_CAPABILITY);
+  attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
+  attr_2->set_s(compute_capability);
+
   auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
-  int num_attributes = compute_capability_enable ? 3 : 2;
+  int num_attributes = 3;
   node_attributes->reserve(num_attributes);
   node_attributes->emplace(EMBED_MODE, *attr_0);
   node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);
-
-  if (compute_capability_enable) {
-    attr_2->set_name(COMPUTE_CAPABILITY);
-    attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
-    attr_2->set_s(compute_capability);
-    node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
-  }
+  node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
 
   // Create EP context node
   graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
@@ -138,14 +129,111 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
 }
 
 /*
- * Dump "EP context node" model
+ * Return the directory where the ep context model locates
+ */
+std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) {
+  if (ep_context_file_path.empty()) {
+    return std::filesystem::path();
+  }
+  std::filesystem::path ctx_path(ep_context_file_path);
+  if (std::filesystem::is_directory(ep_context_file_path)) {
+    return ctx_path;
+  } else {
+    return ctx_path.parent_path();
+  }
+}
+
+/*
+ * Get "EP context" model path.
+ *
+ * Function logic:
+ * If ep_context_file_path is provided,
+ *   - If ep_context_file_path is a file, return "ep_context_file_path".
+ *   - If ep_context_file_path is a directory, return "ep_context_file_path/original_model_name_ctx.onnx".
+ * If ep_context_file_path is not provided,
+ *   - Return "original_model_name_ctx.onnx".
+ *
+ * TRT EP has rules about context model path and engine cache path (see tensorrt_execution_provider.cc):
+ * - If dump_ep_context_model_ and engine_cache_enabled_ is enabled, TRT EP will dump context model and save engine cache
+ *   to the same directory provided by ep_context_file_path_. (i.e. engine_cache_path_ = ep_context_file_path_)
+ *
+ * Example 1:
+ * ep_context_file_path = "/home/user/ep_context_model_directory"
+ * original_model_path = "model.onnx"
+ * => return "/home/user/ep_context_model_folder/model_ctx.onnx"
+ *
+ * Example 2:
+ * ep_context_file_path = "my_ctx_model.onnx"
+ * original_model_path = "model.onnx"
+ * => return "my_ctx_model.onnx"
+ *
+ * Example 3:
+ * ep_context_file_path = "/home/user2/ep_context_model_directory/my_ctx_model.onnx"
+ * original_model_path = "model.onnx"
+ * => return "/home/user2/ep_context_model_directory/my_ctx_model.onnx"
+ *
+ */
+std::string GetCtxModelPath(const std::string& ep_context_file_path,
+                            const std::string& original_model_path) {
+  std::string ctx_model_path;
+
+  if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) {
+    ctx_model_path = ep_context_file_path;
+  } else {
+    std::filesystem::path model_path = original_model_path;
+    std::filesystem::path model_name_stem = model_path.stem();  // model_name.onnx -> model_name
+    std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx";
+
+    if (std::filesystem::is_directory(ep_context_file_path)) {
+      std::filesystem::path model_directory = ep_context_file_path;
+      ctx_model_path = model_directory.append(ctx_model_name).string();
+    } else {
+      ctx_model_path = ctx_model_name;
+    }
+  }
+  return ctx_model_path;
+}
+
+/*
+ * Dump "EP context" model
  *
  */
-void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
-                      const std::string engine_cache_path) {
-  std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
+void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
+                  const std::string& ctx_model_path) {
+  std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
   model_proto->SerializeToOstream(dump);
-  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
+  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
+}
+
+bool IsAbsolutePath(std::string& path_string) {
+#ifdef _WIN32
+  onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
+  auto path = std::filesystem::path(ort_path_string.c_str());
+  return path.is_absolute();
+#else
+  if (!path_string.empty() && path_string[0] == '/') {
+    return true;
+  }
+  return false;
+#endif
+}
+
+// Like "../file_path"
+bool IsRelativePathToParentPath(std::string& path_string) {
+#ifdef _WIN32
+  onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
+  auto path = std::filesystem::path(ort_path_string.c_str());
+  auto relative_path = path.lexically_normal().make_preferred().wstring();
+  if (relative_path.find(L"..", 0) != std::string::npos) {
+    return true;
+  }
+  return false;
+#else
+  if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) {
+    return true;
+  }
+  return false;
+#endif
 }
 
 Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
@@ -157,7 +245,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
 
   const int64_t embed_mode = attrs.at(EMBED_MODE).i();
   if (embed_mode) {
-    // Get engine from byte stream
+    // Get engine from byte stream.
     const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
     *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
                                                                                                  static_cast<size_t>(context_binary.length())));
@@ -167,19 +255,41 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
                              "TensorRT EP could not deserialize engine from binary data");
     }
   } else {
-    // Get engine from cache file
-    std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in);
+    // Get engine from cache file.
+    std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();
+
+    // For security purpose, in the case of running context model, TRT EP won't allow
+    // engine cache path to be the relative path like "../file_path" or the absolute path.
+    // It only allows the engine cache to be in the same directory or sub directory of the context model.
+    if (IsAbsolutePath(cache_path)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path);
+    }
+    if (IsRelativePathToParentPath(cache_path)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory.");
+    }
+
+    // The engine cache and context model (current model) should be in the same directory
+    std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
+    auto engine_cache_path = ctx_model_dir.append(cache_path);
+
+    if (!std::filesystem::exists(engine_cache_path)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                             "TensorRT EP can't find engine cache: " + engine_cache_path.string() +
+                             ". Please make sure engine cache is in the same directory or sub-directory of context model.");
+    }
+
+    std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in);
     engine_file.seekg(0, std::ios::end);
     size_t engine_size = engine_file.tellg();
     engine_file.seekg(0, std::ios::beg);
     std::unique_ptr<char[]> engine_buf{new char[engine_size]};
     engine_file.read((char*)engine_buf.get(), engine_size);
     *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
-    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path_.string();
     if (!(*trt_engine_)) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                             "TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string());
+                             "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string());
     }
+    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();
   }
   return Status::OK();
 }
@@ -193,37 +303,26 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe
   auto node = graph_viewer.GetNode(0);
   auto& attrs = node->GetAttributes();
 
-  // Check hardware_architecture(compute_capability) if it's present as an attribute
+  // Show the warning if compute capability is not matched
   if (attrs.count(COMPUTE_CAPABILITY) > 0) {
     std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
     if (model_compute_capability != compute_capability_) {
-      LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability";
-      LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability;
-      LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_;
-      return false;
+      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal";
+      LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability;
+      LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_;
     }
   }
 
   // "embed_mode" attr and "ep_cache_context" attr should be present
-  if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) {
-    // ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0
-    const int64_t embed_mode = attrs.at(EMBED_MODE).i();
-
-    // engine cache path
-    if (embed_mode == 0) {
-      // First assume engine cache path is relatvie to model path,
-      // If not, then assume the engine cache path is an absolute path.
-      engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer));
-      auto default_engine_cache_path_ = engine_cache_path_;
-      if (!std::filesystem::exists(engine_cache_path_)) {
-        engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s());
-        if (!std::filesystem::exists(engine_cache_path_)) {
-          LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine";
-          return false;
-        }
-      }
-    }
+  assert(attrs.count(EMBED_MODE) > 0);
+  assert(attrs.count(EP_CACHE_CONTEXT) > 0);
+
+  const int64_t embed_mode = attrs.at(EMBED_MODE).i();
+  if (embed_mode == 1) {
+    // engine binary data
+    LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
   }
+
   return true;
 }
 }  // namespace onnxruntime
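For completeness, a minimal sketch (not from this commit, with placeholder file names) of what loading a previously dumped context model could look like; per the checks above, the engine cache referenced by `ep_cache_context` must live in the same directory as the context model or in a sub-directory of it:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;

  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));
  so.AppendExecutionProvider_TensorRT_V2(*trt_options);

  // Create the session directly from the dumped "EP context" model; TRT EP deserializes the
  // engine cache referenced by the ep_cache_context attribute instead of rebuilding the engine.
  Ort::Session session(env, ORT_TSTR("./context_model_dir/model_ctx.onnx"), so);

  api.ReleaseTensorRTProviderOptions(trt_options);
  return 0;
}
```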
