Commit ad04cf9
implement with ctx manager
1 parent af62bf1

10 files changed: +586 -14

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+1 -1)

@@ -3584,7 +3584,7 @@ def aten_ops_full(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
+@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default, supports_dynamic_shapes=True)
 def aten_ops_nonzero(
     ctx: ConversionContext,
     target: Target,

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (+4)

@@ -13,6 +13,8 @@
 from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
+from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
+from .remove_sym_size_and_constrain_nodes import remove_sym_size_and_constrain_nodes
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .view_to_reshape import view_to_reshape
@@ -29,6 +31,8 @@
         view_to_reshape,
         remove_assert_nodes,
         accumulate_fp32_matmul,
+        remove_sym_size_and_constrain_nodes,
+        remove_num_users_is_0_nodes,
     ]
 )
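The two new passes are appended to the ordered post-lowering pass list above. For context, a minimal sketch of how such a list is folded over a graph module (illustrative only, not the repo's actual pass-manager code; `Any` stands in for `CompilationSettings`):

    # Illustrative sketch: registered lowering passes are applied in order,
    # each taking and returning a GraphModule, so they compose left to right.
    from typing import Any, Callable, List

    import torch

    def apply_lowering_passes(
        gm: torch.fx.GraphModule,
        settings: Any,  # stands in for CompilationSettings
        passes: List[Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule]],
    ) -> torch.fx.GraphModule:
        # e.g. [..., remove_sym_size_and_constrain_nodes, remove_num_users_is_0_nodes]
        for lowering_pass in passes:
            gm = lowering_pass(gm, settings)
        return gm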

py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py (new file, +27)

@@ -0,0 +1,27 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def remove_num_users_is_0_nodes(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Remove nodes with [num_users=0] from the graph"""
+    output_node = list(gm.graph.nodes)[-1]
+
+    for node in gm.graph.nodes:
+        if node != output_node and len(node.users) == 0:
+            node_input = node.all_input_nodes[0]
+            node.replace_all_uses_with(node_input)
+            gm.graph.erase_node(node)
+            gm = clean_up_graph_after_modifications(gm)
+
+    logger.debug(f"Removed [num_users=0] nodes:\n{gm.graph}")
+
+    return gm
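For intuition, a small standalone demo (not part of the commit) of the dead-node removal this pass performs, on a toy FX graph:

    import torch
    import torch.fx

    def f(x):
        unused = x + 1  # produces a node with num_users == 0
        return x * 2

    gm = torch.fx.symbolic_trace(f)
    # Collect dead call_function nodes first, then erase them (mirrors the pass).
    dead = [n for n in gm.graph.nodes if n.op == "call_function" and len(n.users) == 0]
    for n in dead:
        gm.graph.erase_node(n)
    gm.recompile()
    print(gm.graph)  # the unused `add` node is gone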

py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py (new file, +34)

@@ -0,0 +1,34 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def remove_sym_size_and_constrain_nodes(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Remove aten.sym_size.int and aten.sym_constrain_range_for_size.default ops in the graph"""
+    count = 0
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and (
+            node.target == torch.ops.aten.sym_size.int
+            or node.target == torch.ops.aten.sym_constrain_range_for_size.default
+        ):
+            node_input = node.all_input_nodes[0]
+            node.replace_all_uses_with(node_input)
+            gm.graph.erase_node(node)
+            count += 1
+
+    if count > 0:
+        gm = clean_up_graph_after_modifications(gm)
+
+    logger.debug(
+        f"Removed {count} aten.sym_size.int or aten.sym_constrain_range_for_size.default nodes:\n{gm.graph}"
+    )
+
+    return gm
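Such nodes typically enter the graph when a model is exported with dynamic shapes. A hedged illustration of where they come from (uses the public torch.export API; not code from this commit):

    import torch
    from torch.export import Dim, export

    class Flatten(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.reshape(x.shape[0], -1)

    batch = Dim("batch", min=2, max=32)
    ep = export(Flatten(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})
    print(ep.graph)  # includes torch.ops.aten.sym_size.int nodes for the dynamic dim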

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py (+9)

@@ -32,6 +32,7 @@ def __init__(
         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
+        self.use_output_allocator_outputs = False
         self.shape_key: Optional[str] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
@@ -73,8 +74,16 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
+    def set_output_allocator_outputs(self, enable: bool) -> None:
+        self.use_output_allocator_outputs = enable
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
+        if cudagraphs_enabled and self.use_output_allocator_outputs:
+            raise RuntimeError(
+                "There are non-TRT submodules in the module. OutputAllocator is not compatible with modules with non-TRT submodules."
+            )
+
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
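In practice this guard means whole-graph CUDA Graphs mode and output-allocator outputs cannot be combined on a module with non-TRT submodules. A hedged usage sketch (`trt_mod` and `inputs` are placeholders for a compiled module with graph breaks and its inputs):

    # Hypothetical: enabling both features at once now fails fast at forward time.
    import torch_tensorrt

    with torch_tensorrt.runtime.enable_cudagraphs(trt_mod) as cg_mod:
        with torch_tensorrt.runtime.enable_output_allocator(cg_mod):
            cg_mod(*inputs)  # raises RuntimeError per the check above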

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py (+36 -10)

@@ -202,10 +202,13 @@ def __init__(
             torch_tensorrt.runtime.get_cudagraphs_mode()
         )
 
-        self.engine_is_dds = engine_is_dds
+        self.cudagraphs_enabled = False
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False
+
+        self.engine_is_dds = engine_is_dds
         self.output_allocator: Optional[DynamicOutputAllocator] = None
+        self.use_output_allocator_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -401,6 +404,9 @@ def create_output_tensors(self) -> List[torch.Tensor]:
     def set_pre_allocated_outputs(self, enable: bool) -> None:
         self.use_pre_allocated_outputs = enable
 
+    def set_output_allocator_outputs(self, enable: bool) -> None:
+        self.use_output_allocator_outputs = enable
+
     def create_output_allocator(self) -> None:
         if self.output_allocator is None:
             output_dtypes_dict = {}
@@ -410,15 +416,14 @@ def create_output_allocator(self) -> None:
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
-        def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
-            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
+        def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             shape_changed = self.validate_input_shapes(inputs)
             (
                 need_cudagraphs_record,
                 can_use_pre_allocated_outputs,
                 need_cudagraphs_reset,
             ) = self.runtime_states.set_runtime_states(
-                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
+                self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
             )
 
             if need_cudagraphs_reset and self.cudagraph:
@@ -441,7 +446,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
 
             self.setup_input_tensors(
-                contiguous_inputs, cudagraphs_enabled, need_cudagraphs_record
+                contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
             )
 
             if shape_changed:
@@ -477,7 +482,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if need_cudagraphs_record:
                     self._output_buffers[o] = outputs[o].clone()
 
-                if cudagraphs_enabled:
+                if self.cudagraphs_enabled:
                     self.context.set_tensor_address(
                         output_name, self._output_buffers[o].data_ptr()
                     )
@@ -503,7 +508,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
-                    if cudagraphs_enabled:
+                    if self.cudagraphs_enabled:
                         if need_cudagraphs_record:
                             self.cudagraph = torch.cuda.CUDAGraph()
 
@@ -535,7 +540,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.use_pre_allocated_outputs:
                 self.pre_allocated_outputs = self.create_output_tensors()
 
-            if cudagraphs_enabled:
+            if self.cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
                     o.copy_(self._output_buffers[idx])
 
@@ -545,7 +550,9 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             return outputs
 
         def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
-            torch_tensorrt.runtime.set_cudagraphs_mode(False)
+            assert (
+                not torch_tensorrt.runtime.get_cudagraphs_mode()
+            ), "CUDA Graphs are not compatible with OutputAllocator."
             with (
                 torch.autograd.profiler.record_function(
                     "PythonTorchTensorRTModule:ProcessInputs"
@@ -625,6 +632,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             return outputs
 
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
+
         # Run forward function
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
@@ -670,9 +679,26 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
         if self.engine_is_dds:
+            if self.cudagraphs_enabled:
+                raise RuntimeError(
+                    "The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."
+                )
+            logger.debug(
+                "The module is Data-Dependent Shape (DDS). Using output allocator."
+            )
             return run_output_allocator()
         else:
-            return run_cuda_graph()
+            if self.cudagraphs_enabled and self.use_output_allocator_outputs:
+                raise RuntimeError(
+                    "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one."
+                )
+            if self.use_output_allocator_outputs:
+                logger.debug("Using output allocator.")
+                return run_output_allocator()
+            logger.debug(
+                f"Using standard execution with cudagraphs={self.cudagraphs_enabled}."
+            )
+            return run_standard_execution()
 
     def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         """

py/torch_tensorrt/runtime/__init__.py (+1)

@@ -9,5 +9,6 @@
     set_cudagraphs_mode,
 )
 from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode
+from torch_tensorrt.runtime._output_allocator import enable_output_allocator
 from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs
 from torch_tensorrt.runtime._weight_streaming import weight_streaming
py/torch_tensorrt/runtime/_output_allocator.py (new file, +51)

@@ -0,0 +1,51 @@
+import logging
+from typing import Any, Union
+
+import torch
+from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
+from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
+    CudaGraphsTorchTensorRTModule,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class _OutputAllocatorContextManager(object):
+    """
+    Helper class to set up output_allocator
+    """
+
+    def __init__(
+        self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
+    ) -> None:
+        if isinstance(module, CudaGraphsTorchTensorRTModule):
+            rt_mods = [module]
+        else:
+            rt_mods = []
+
+            for name, rt_mod in module.named_children():
+                if "_run_on_acc" in name and isinstance(
+                    rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
+                ):
+                    rt_mods.append(rt_mod)
+
+        self.rt_mods = rt_mods
+
+    def set_output_allocator_output(self, enable: bool) -> None:
+        for mod in self.rt_mods:
+            mod.set_output_allocator_outputs(enable)
+
+    def __enter__(self) -> "_OutputAllocatorContextManager":
+        # Enable output_allocator for TRT submodules
+        self.set_output_allocator_output(True)
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        # Disable output_allocator
+        self.set_output_allocator_output(False)
+
+
+def enable_output_allocator(
+    module: torch.fx.GraphModule,
+) -> _OutputAllocatorContextManager:
+    return _OutputAllocatorContextManager(module)
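Typical usage of the new context manager, the feature this commit implements (a sketch; `model` and the input shape are placeholders):

    # Hypothetical end-to-end usage.
    import torch
    import torch_tensorrt

    inputs = [torch.randn(2, 3, 4, 5).cuda()]
    trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

    with torch_tensorrt.runtime.enable_output_allocator(trt_mod):
        out = trt_mod(*inputs)  # TRT submodules take the run_output_allocator() path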

tests/py/dynamo/conversion/test_nonzero_aten.py (+48 -3)

@@ -17,7 +17,7 @@ class TestNonZeroConverter(DispatchTestCase):
             ((2, 3, 4, 5), torch.float),
         ]
     )
-    def test_non_zero(self, input_shape, dtype):
+    def test_nonzero_dds(self, input_shape, dtype):
         class NonZero(nn.Module):
             # This is a DDS network
             def forward(self, input):
@@ -39,7 +39,7 @@ def forward(self, input):
             ((2, 3, 4, 5), torch.float),
         ]
     )
-    def test_non_zero(self, input_shape, dtype):
+    def test_nonzero_non_dds(self, input_shape, dtype):
         class NonZero(nn.Module):
             # This is a static network
             def forward(self, input):
@@ -78,7 +78,7 @@ def forward(self, input):
             ),
         ]
     )
-    def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype):
+    def test_nonzero_dynamic_shape_dds(self, _, min_shape, opt_shape, max_shape, dtype):
         class NonZero(nn.Module):
             def forward(self, input):
                 return torch.ops.aten.nonzero.default(input)
@@ -94,6 +94,51 @@ def forward(self, input):
 
         self.run_test_with_dynamic_shape(NonZero(), input_specs)
 
+    @parameterized.expand(
+        [
+            (
+                "1d",
+                (1,),
+                (10,),
+                (100,),
+                torch.int32,
+            ),
+            (
+                "2d",
+                (1, 2),
+                (5, 10),
+                (20, 40),
+                torch.float16,
+            ),
+            (
+                "3d",
+                (1, 2, 3),
+                (5, 10, 20),
+                (30, 40, 50),
+                torch.float,
+            ),
+        ]
+    )
+    def test_nonzero_dynamic_shape_non_dds(
+        self, _, min_shape, opt_shape, max_shape, dtype
+    ):
+        class NonZero(nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.nonzero.default(input)
+                out = torch.ops.aten.sum.dim_IntList(out, 0)
+                return out
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=dtype,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(NonZero(), input_specs)
+
 
 if __name__ == "__main__":
     run_tests()
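A note on the naming in the new test: aten.nonzero produces an output whose first dimension depends on the data, but summing over that dimension restores a static shape, which is why the appended network counts as non-DDS. A small illustration (reasoning added for clarity, not part of the commit):

    import torch

    x = torch.tensor([[1, 0], [0, 2]])
    nz = torch.ops.aten.nonzero.default(x)     # shape (N, 2); N depends on the data
    s = torch.ops.aten.sum.dim_IntList(nz, 0)  # shape (2,); static given the rank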
