@@ -341,6 +341,14 @@ void MapInputsAndDetermineDTypes(
   }
 }
 
+uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) {
+  if (device.major < 6) {
+    return 256 * (1 << 20);
+  } else {
+    return 1 << 30;
+  }
+}
+
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   // Go through Lowering to simplify graph and extract weight parameters
   auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
@@ -354,6 +362,16 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
+  // GPU default WS size : 1 GB
+  // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
+  auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+  if (workspace_size == 0) {
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
+  }
+
+
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
@@ -364,19 +382,13 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
-  auto device_spec = cfg.convert_info.engine_settings.device;
-
   // GPU default WS size : 1 GB
   // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
   auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  cudaDeviceProp device_prop;
-  cudaGetDeviceProperties(&device_prop, device_spec.gpu_id);
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
   if (workspace_size == 0) {
-    if (device_prop.major < 6) {
-      cfg.convert_info.engine_settings.workspace_size = 256 * (1 << 20);
-    } else {
-      cfg.convert_info.engine_settings.workspace_size = 1 << 30;
-    }
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
   }
 
   for (const torch::jit::Method& method : mod.get_methods()) {
@@ -420,8 +432,6 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
         conversion::VerifyConverterSupportForBlock(g->block()),
         "Not all operations in graph are supported by the compiler");
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
-    auto device_spec = cfg.convert_info.engine_settings.device;
-    auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
     AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   }
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
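
For reference, the heuristic factored out into `GetRecommendedWorkspaceSize` picks 256 MiB of TensorRT workspace for devices whose compute capability is below 6.0 (Jetson Nano/TX1-class SM 5.x parts) and 1 GiB for everything newer, and it only applies when the user leaves `workspace_size` at 0. Below is a minimal standalone sketch of the same logic, using the plain CUDA runtime API (`cudaGetDeviceProperties`, as in the removed lines) in place of `runtime::CudaDevice`, whose `major` field the new helper reads; it is an illustration under those assumptions, not the project's implementation.

```cpp
// Standalone sketch of the workspace-size heuristic. runtime::CudaDevice is
// part of the surrounding codebase; here the compute capability is read
// straight from cudaDeviceProp instead.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime_api.h>

uint64_t GetRecommendedWorkspaceSize(int compute_capability_major) {
  if (compute_capability_major < 6) {
    return 256 * (1 << 20); // 256 MiB for SM 5.x parts (Jetson Nano/TX1 class)
  }
  return uint64_t(1) << 30; // 1 GiB for compute capability 6.0 and newer
}

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device visible\n");
    return 1;
  }
  std::printf("SM %d.%d -> recommended workspace: %llu bytes\n", prop.major, prop.minor,
              static_cast<unsigned long long>(GetRecommendedWorkspaceSize(prop.major)));
  return 0;
}
```

A side effect of the refactor worth noting: `cuda_device` is now constructed once near the top of `CompileGraph` and reused both for the workspace heuristic and for `AddEngineToGraph`, instead of being rebuilt inside the per-method loop.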