Commit 391a4c0

refactor!: Update default workspace size based on platform

BREAKING CHANGE: This commit sets the default workspace size to 1 GB for GPU platforms and to 256 MB for Jetson Nano/TX1-class platforms whose compute capability is < 6.

Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a1180ce commit 391a4c0

File tree

2 files changed: +16 −2 lines

core/compiler.cpp
core/conversion/conversionctx/ConversionCtx.cpp
Diff for: core/compiler.cpp (+15 −1)

@@ -347,6 +347,21 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   if (cfg.partition_info.enabled) {
     return CompileGraphWithFallback(mod, cfg);
   }
+  auto device_spec = cfg.convert_info.engine_settings.device;
+
+  // GPU default WS size : 1 GB
+  // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
+  auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
+  cudaDeviceProp device_prop;
+  cudaGetDeviceProperties(&device_prop, device_spec.gpu_id);
+  if (workspace_size == 0) {
+    if (device_prop.major < 6) {
+      cfg.convert_info.engine_settings.workspace_size = 256 * (1 << 20);
+    } else {
+      cfg.convert_info.engine_settings.workspace_size = 1 << 30;
+    }
+  }
+
   // TODO: Should be doing a functional transform but need PR #31978
   // [jit] More robust mangling
   // torch::jit::script::Module new_mod = mod.clone();
@@ -357,7 +372,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
     if (method.name().compare("forward") == 0) {
       auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = cfg.convert_info.engine_settings.device;
       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
       AddEngineToGraph(new_mod, new_g, engine, cuda_device);
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
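For illustration, here is a minimal, self-contained sketch of the selection rule introduced above. The helper name pick_default_workspace_size, the error fallback, and the standalone main() are assumptions made for this example and are not part of the library; only the rule itself (compute capability < 6 → 256 MB, otherwise 1 GB, and an explicit request always wins) and the use of cudaGetDeviceProperties mirror the diff.

// Minimal sketch (assumed helper, not library code) of the default-workspace
// selection added in core/compiler.cpp above.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

uint64_t pick_default_workspace_size(int gpu_id, uint64_t requested) {
  if (requested != 0) {
    return requested;  // workspace_size was set explicitly; keep it
  }
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, gpu_id) != cudaSuccess) {
    return 1ULL << 30;  // assumed fallback to the 1 GB default if the query fails
  }
  // Jetson Nano/TX1-class devices report compute capability 5.x
  return (prop.major < 6) ? 256ULL * (1 << 20) : 1ULL << 30;
}

int main() {
  uint64_t ws = pick_default_workspace_size(/*gpu_id=*/0, /*requested=*/0);
  std::printf("default workspace size: %llu bytes\n",
              static_cast<unsigned long long>(ws));
  return 0;
}

Note that the compiler only fills in a default when workspace_size is 0, so a value set explicitly in the compile settings is left untouched; the breaking change only affects callers that relied on the previous implicit default.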

Diff for: core/conversion/conversionctx/ConversionCtx.cpp (+1 −1)

@@ -58,7 +58,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
   net = make_trt(
       builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));

-  LOG_DEBUG(build_settings);
+  LOG_INFO(settings);
   cfg = make_trt(builder->createBuilderConfig());

   for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
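The ConversionCtx.cpp hunk only raises the build-settings log from debug to info level. For context, a hedged sketch of how a workspace_size value is typically handed to TensorRT's builder config is shown below; this function is illustrative and not part of this diff, and it assumes the pre-deprecation IBuilderConfig::setMaxWorkspaceSize API (TensorRT 7/8; newer releases replace it with setMemoryPoolLimit).

// Hedged sketch (not part of this commit): applying a workspace size to a
// TensorRT builder config, which is where the platform-dependent default
// chosen in core/compiler.cpp would ultimately take effect.
#include <NvInfer.h>
#include <cstdint>

void apply_workspace_size(nvinfer1::IBuilderConfig* cfg, uint64_t workspace_size) {
  // workspace_size would come from the engine settings, after the
  // platform-dependent default above has been filled in.
  cfg->setMaxWorkspaceSize(workspace_size);
}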

0 commit comments
