@@ -202,10 +202,13 @@ def __init__(
             torch_tensorrt.runtime.get_cudagraphs_mode()
         )
 
-        self.engine_is_dds = engine_is_dds
+        self.cudagraphs_enabled = False
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False
+
+        self.engine_is_dds = engine_is_dds
         self.output_allocator: Optional[DynamicOutputAllocator] = None
+        self.use_output_allocator_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -401,6 +404,9 @@ def create_output_tensors(self) -> List[torch.Tensor]:
     def set_pre_allocated_outputs(self, enable: bool) -> None:
         self.use_pre_allocated_outputs = enable
 
+    def set_output_allocator_outputs(self, enable: bool) -> None:
+        self.use_output_allocator_outputs = enable
+
     def create_output_allocator(self) -> None:
         if self.output_allocator is None:
             output_dtypes_dict = {}
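Note: the new set_output_allocator_outputs toggle mirrors set_pre_allocated_outputs directly above it. A minimal usage sketch, assuming trt_module is a PythonTorchTensorRTModule produced by compilation (the variable names here are illustrative, not part of the patch):

    # Route forward() through run_output_allocator() instead of standard execution.
    trt_module.set_output_allocator_outputs(True)
    outputs = trt_module(*example_inputs)  # outputs backed by DynamicOutputAllocator
    trt_module.set_output_allocator_outputs(False)  # restore standard execution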
@@ -410,15 +416,14 @@ def create_output_allocator(self) -> None:
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
-        def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
-            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
+        def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             shape_changed = self.validate_input_shapes(inputs)
             (
                 need_cudagraphs_record,
                 can_use_pre_allocated_outputs,
                 need_cudagraphs_reset,
             ) = self.runtime_states.set_runtime_states(
-                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
+                self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
             )
 
             if need_cudagraphs_reset and self.cudagraph:
@@ -441,7 +446,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
 
             self.setup_input_tensors(
-                contiguous_inputs, cudagraphs_enabled, need_cudagraphs_record
+                contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
             )
 
             if shape_changed:
@@ -477,7 +482,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()
 
-                    if cudagraphs_enabled:
+                    if self.cudagraphs_enabled:
                         self.context.set_tensor_address(
                             output_name, self._output_buffers[o].data_ptr()
                         )
@@ -503,7 +508,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
-                    if cudagraphs_enabled:
+                    if self.cudagraphs_enabled:
                         if need_cudagraphs_record:
                             self.cudagraph = torch.cuda.CUDAGraph()
 
@@ -535,7 +540,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.use_pre_allocated_outputs:
                 self.pre_allocated_outputs = self.create_output_tensors()
 
-            if cudagraphs_enabled:
+            if self.cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
                     o.copy_(self._output_buffers[idx])
 
@@ -545,7 +550,9 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             return outputs
 
         def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
-            torch_tensorrt.runtime.set_cudagraphs_mode(False)
+            assert (
+                not torch_tensorrt.runtime.get_cudagraphs_mode()
+            ), "CUDA Graphs are not compatible with OutputAllocator."
             with (
                 torch.autograd.profiler.record_function(
                     "PythonTorchTensorRTModule:ProcessInputs"
@@ -625,6 +632,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             return outputs
 
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
+
         # Run forward function
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
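Note: forward() now snapshots the global CUDA Graphs flag once per call, so the DDS/output-allocator dispatch below and run_standard_execution() observe one consistent value instead of re-querying mid-run. A sketch of the global toggle this snapshot reads, using only the runtime API already referenced in this diff:

    import torch_tensorrt

    torch_tensorrt.runtime.set_cudagraphs_mode(True)     # subsequent forward() calls capture/replay graphs
    assert torch_tensorrt.runtime.get_cudagraphs_mode()  # the value self.cudagraphs_enabled snapshots
    torch_tensorrt.runtime.set_cudagraphs_mode(False)    # required before output-allocator paths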
@@ -670,9 +679,26 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
         if self.engine_is_dds:
+            if self.cudagraphs_enabled:
+                raise RuntimeError(
+                    "The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."
+                )
+            logger.debug(
+                "The module is Data-Dependent Shape (DDS). Using output allocator."
+            )
             return run_output_allocator()
         else:
-            return run_cuda_graph()
+            if self.cudagraphs_enabled and self.use_output_allocator_outputs:
+                raise RuntimeError(
+                    "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one."
+                )
+            if self.use_output_allocator_outputs:
+                logger.debug("Using output allocator.")
+                return run_output_allocator()
+            logger.debug(
+                f"Using standard execution with cudagraphs={self.cudagraphs_enabled}."
+            )
+            return run_standard_execution()
 
     def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         """