
Commit 930321e

fix: Workspace defaults for other apis and centralize cuda api use

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent 832b1c7 · commit 930321e

File tree: 1 file changed (+21, −11)

Diff for: core/compiler.cpp (+21, −11)
@@ -341,6 +341,14 @@ void MapInputsAndDetermineDTypes(
   }
 }
 
+uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) {
+  if (device.major < 6) {
+    return 256 * (1 << 20);
+  } else {
+    return 1 << 30;
+  }
+}
+
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   // Go through Lowering to simplify graph and extract weight parameters
   auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
@@ -354,6 +362,16 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
+  // GPU default WS size : 1 GB
+  // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
+  auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+  if (workspace_size == 0) {
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
+  }
+
+
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
@@ -364,19 +382,13 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
-  auto device_spec = cfg.convert_info.engine_settings.device;
-
   // GPU default WS size : 1 GB
   // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
   auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  cudaDeviceProp device_prop;
-  cudaGetDeviceProperties(&device_prop, device_spec.gpu_id);
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
   if (workspace_size == 0) {
-    if (device_prop.major < 6) {
-      cfg.convert_info.engine_settings.workspace_size = 256 * (1 << 20);
-    } else {
-      cfg.convert_info.engine_settings.workspace_size = 1 << 30;
-    }
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
   }
 
   for (const torch::jit::Method& method : mod.get_methods()) {
@@ -420,8 +432,6 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
           conversion::VerifyConverterSupportForBlock(g->block()),
           "Not all operations in graph are supported by the compiler");
       auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
-      auto device_spec = cfg.convert_info.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
       AddEngineToGraph(new_mod, new_g, engine, cuda_device);
     }
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
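
Note on the constants in the new helper: 256 * (1 << 20) is 256 MiB (the default for compute capability 5.x platforms such as Jetson nano/TX1) and 1 << 30 is 1 GiB (the default everywhere else). A minimal standalone sketch of the heuristic, with runtime::CudaDevice stubbed down to the single field the helper reads; everything outside the diff is illustrative:

// Standalone sketch; assumption: CudaDevice is reduced to the one field used.
#include <cassert>
#include <cstdint>

struct CudaDevice {
  int major; // CUDA compute capability major version, e.g. 5 on Jetson TX1
};

uint64_t GetRecommendedWorkspaceSize(const CudaDevice& device) {
  if (device.major < 6) {
    return 256 * (1 << 20); // 256 MiB for compute capability 5.x and below
  } else {
    return 1 << 30; // 1 GiB for Pascal (6.x) and newer
  }
}

int main() {
  assert(GetRecommendedWorkspaceSize({5}) == 256ull * 1024 * 1024);
  assert(GetRecommendedWorkspaceSize({7}) == 1024ull * 1024 * 1024);
  return 0;
}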
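The diff does not show where workspace_size is ultimately consumed. In TensorRT releases of this era the setting maps onto nvinfer1::IBuilderConfig::setMaxWorkspaceSize, which bounds the scratch memory the builder may use while selecting kernel tactics. A hedged sketch of that hand-off (ApplyWorkspaceSize is a hypothetical helper, not the repo's code; the TensorRT call itself is real):

// Illustrative only: function name is hypothetical; setMaxWorkspaceSize is
// the standard TensorRT API the workspace setting feeds into.
#include <NvInfer.h>
#include <cstddef>
#include <cstdint>

void ApplyWorkspaceSize(nvinfer1::IBuilderConfig& config, uint64_t workspace_size) {
  // A larger workspace lets the builder consider more (often faster) tactics;
  // too small a value can silently exclude them.
  config.setMaxWorkspaceSize(static_cast<std::size_t>(workspace_size));
}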
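The "centralize cuda api use" half of the commit replaces the ad-hoc cudaDeviceProp / cudaGetDeviceProperties query in CompileGraph with runtime::CudaDevice, constructed once per compile and shared by both ConvertGraphToTRTEngine and CompileGraph (the per-method copy inside the loop is deleted in the last hunk). A sketch of what such a wrapper plausibly looks like; the constructor body is an assumption, and only the fields the diff uses are shown:

// Assumption: a minimal CudaDevice-like wrapper; the repo's real class
// carries more fields and error handling.
#include <cuda_runtime_api.h>

struct CudaDevice {
  int gpu_id;
  int major = 0; // compute capability, cached so callers never touch
  int minor = 0; // the CUDA runtime API directly

  CudaDevice(int gpu_id, int /*device_type, elided*/) : gpu_id(gpu_id) {
    cudaDeviceProp prop{};
    // The single, centralized call site for device property queries.
    if (cudaGetDeviceProperties(&prop, gpu_id) == cudaSuccess) {
      major = prop.major;
      minor = prop.minor;
    }
  }
};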
