Skip to content

Commit 062e4d6

Browse files
authored
feat: add args for profiling engine caching (#3329)
1 parent 695ffd9 commit 062e4d6

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

tools/perf/perf_run.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,13 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
255255
min_block_size=params.get("min_block_size", 1),
256256
debug=False,
257257
truncate_long_and_double=params.get("truncate", False),
258+
immutable_weights=params.get("immutable_weights", True),
259+
strip_engine_weights=params.get("strip_engine_weights", False),
260+
refit_identical_engine_weights=params.get(
261+
"refit_identical_engine_weights", False
262+
),
263+
cache_built_engines=params.get("cache_built_engines", False),
264+
reuse_cached_engines=params.get("reuse_cached_engines", False),
258265
)
259266
end_compile = timeit.default_timer()
260267
compile_time_s = end_compile - start_compile
@@ -585,6 +592,31 @@ def run(
585592
type=str,
586593
help="Path of the output file where performance summary is written.",
587594
)
595+
arg_parser.add_argument(
596+
"--immutable_weights",
597+
action="store_true",
598+
help="Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.",
599+
)
600+
arg_parser.add_argument(
601+
"--strip_engine_weights",
602+
action="store_true",
603+
help="Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.",
604+
)
605+
arg_parser.add_argument(
606+
"--refit_identical_engine_weights",
607+
action="store_true",
608+
help="Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs.",
609+
)
610+
arg_parser.add_argument(
611+
"--cache_built_engines",
612+
action="store_true",
613+
help="Whether to save the compiled TRT engines to storage.",
614+
)
615+
arg_parser.add_argument(
616+
"--reuse_cached_engines",
617+
action="store_true",
618+
help="Whether to load the compiled TRT engines from storage.",
619+
)
588620
args = arg_parser.parse_args()
589621

590622
# Create random input tensor of certain size
@@ -605,9 +637,9 @@ def run(
605637
# Load PyTorch Model, if provided
606638
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
607639
print("Loading user provided torch model: ", model_name_torch)
608-
model_torch = torch.load(model_name_torch).eval()
640+
model_torch = torch.load(model_name_torch).cuda().eval()
609641
elif model_name_torch in BENCHMARK_MODELS:
610-
model_torch = BENCHMARK_MODELS[model_name_torch]["model"].eval()
642+
model_torch = BENCHMARK_MODELS[model_name_torch]["model"].cuda().eval()
611643

612644
# If neither model type was provided
613645
if (model is None) and (model_torch is None):

0 commit comments

Comments
 (0)