
Commit eed420a

resolve comments
1 parent 0987146 commit eed420a

8 files changed: +39 -29 lines changed

py/torch_tensorrt/dynamo/_engine_cache.py

+1 -1

@@ -118,7 +118,7 @@ def pack(
 input_specs (Sequence[Input]): input specs of TRT engine
 compilation_settings (CompilationSettings): compilation settings of TRT engine
 weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
-requires_output_allocator (bool): whether the engine requires output allocator
+requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
 Returns:
 bytes: packed blob
 """

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

+1 -1

@@ -11,7 +11,7 @@ class ConversionContext:
 Args:
 net: TensorRT Network being built
 compilation_settings: Settings selected by the user for compilation
-requires_output_allocator: Whether the network requires output allocator
+requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
 """

 net: TRTNetwork

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

+12 -6

@@ -80,7 +80,7 @@ class ConverterSupport:
 whether that node can be supported by its companion converter. Note that
 this function must not modify the node or its graph
 supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
-requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator.
+requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators).
 """

 converter_implementation: ConverterImplSignature

@@ -215,7 +215,7 @@ def dynamo_tensorrt_converter(
 priority: Converter's level of priority relative to other converters with the
 same target
 supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic shapes.
-requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator.
+requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators).
 Returns:
 The converter being decorated
 """

@@ -410,7 +410,7 @@ def __getitem_without_validation__(
 def __getitem__(
     self, node: Node
 ) -> Tuple[
-    Any, CallingConvention, bool
+    Any, CallingConvention, Dict[str, bool]
 ]:  # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
 """Get the first-found validated converter in any registry

@@ -468,7 +468,10 @@ def __getitem__(
 return (
     candidate.converter_implementation,
     calling_convention,
-    candidate.requires_output_allocator,
+    {
+        "supports_dynamic_shapes": candidate.supports_dynamic_shapes,
+        "requires_output_allocator": candidate.requires_output_allocator,
+    },
 )
 else:
 logger.debug(

@@ -481,7 +484,10 @@ def __getitem__(
 return (
     converters,
     calling_convention,
-    False,
+    {
+        "supports_dynamic_shapes": False,
+        "requires_output_allocator": False,
+    },
 )

 raise KeyError(

@@ -506,7 +512,7 @@ def get_unvalidated(
 def get(
     self, node: Node, value: Optional[ConverterImplSignature] = None
 ) -> Union[
-    Any, Tuple[Any, CallingConvention, bool]
+    Any, Tuple[Any, CallingConvention, Dict[str, bool]]
 ]:  # TODO: Narrow to ConverterImplSignature this when we can remove FX converters
 """Get validated converter for input node with a default return"""
 try:

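Note on the _ConverterRegistry.py change above: the third element of the converter packet is now a dict of capability flags rather than a bare bool, so a lookup reports both supports_dynamic_shapes and requires_output_allocator. A minimal sketch of how this looks from the caller side; the aten.nonzero target, the converter body, and the DYNAMO_CONVERTERS name are illustrative assumptions, not part of this commit:

# Sketch only: assumes dynamo_tensorrt_converter and a DYNAMO_CONVERTERS registry
# instance are exported by _ConverterRegistry.py; the target and impl are illustrative.
import torch
from torch.fx.node import Node
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS,
    dynamo_tensorrt_converter,
)


@dynamo_tensorrt_converter(
    torch.ops.aten.nonzero.default,
    supports_dynamic_shapes=True,
    requires_output_allocator=True,  # output shape is data dependent
)
def example_nonzero_converter(ctx, target, args, kwargs, name):
    ...  # build the TRT layer(s) and return the output tensor(s)


def describe_converter(node: Node) -> None:
    # The registry lookup now returns (impl, calling_convention, capability dict).
    converter, calling_convention, converter_info = DYNAMO_CONVERTERS[node]
    print(converter_info["supports_dynamic_shapes"])
    print(converter_info["requires_output_allocator"])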
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+4 -4

@@ -835,7 +835,7 @@ def call_module(
 f"Conversion of module of type {submod_type} not currently supported!"
 )

-converter, calling_convention, requires_output_allocator = converter_packet
+converter, calling_convention, _ = converter_packet

 assert self._cur_node_name is not None

@@ -852,8 +852,8 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
 f"Conversion of function {torch.typename(target)} not currently supported!"
 )

-converter, calling_convention, requires_output_allocator = converter_packet
-if requires_output_allocator:
+converter, calling_convention, converter_info = converter_packet
+if converter_info.get("requires_output_allocator", False):
     self.ctx.requires_output_allocator = True
     _LOGGER.debug(f"{target} requires output allocator")

@@ -885,7 +885,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
 raise UnsupportedOperatorException(
 f"Conversion of method {target} not currently supported!"
 )
-converter, calling_convention, requires_output_allocator = converter_packet
+converter, calling_convention, _ = converter_packet

 if calling_convention is CallingConvention.LEGACY:
 return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)

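Note on the _TRTInterpreter.py change above: only call_function inspects the new capability dict, while call_module and call_method discard it, and the flag is read with dict.get so a missing key falls back to False (the registry's fallback packet also carries both keys set to False). A tiny self-contained sketch of that unpacking pattern, with stub values:

# Illustrative only: mirrors how the interpreter unpacks the new 3-tuple packet.
converter_packet = (
    lambda ctx, target, args, kwargs, name: None,  # converter implementation (stub)
    "CTX",                                         # calling convention (stub)
    {"supports_dynamic_shapes": True, "requires_output_allocator": True},
)

converter, calling_convention, converter_info = converter_packet
if converter_info.get("requires_output_allocator", False):
    # In the real interpreter this sets self.ctx.requires_output_allocator = True.
    print("This op will be executed through an Output Allocator at runtime")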
py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py

+3 -4

@@ -13,16 +13,15 @@ def remove_num_users_is_0_nodes(
 gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Remove ops that [num_users=0] in the graph"""
-    output_node = list(gm.graph.nodes)[-1]
+    nodes = list(gm.graph.nodes)
+    output_node = nodes[-1]

-    for node in gm.graph.nodes:
+    for node in nodes[::-1]:
         if (
             node != output_node
             and len(node.users) == 0
             and len(node.all_input_nodes) > 0
         ):
-            node_input = node.all_input_nodes[0]
-            node.replace_all_uses_with(node_input)
             gm.graph.erase_node(node)
     gm = clean_up_graph_after_modifications(gm)

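Note on the remove_num_users_is_0_nodes.py change above: erase_node refuses to delete a node that still has users, and a dead node keeps its own inputs alive, so walking the node list in reverse lets a whole dead chain disappear in one pass; the replace_all_uses_with call was unnecessary because a node with zero users has no uses to redirect. A standalone sketch under those assumptions (the toy module is made up, and the real pass additionally runs clean_up_graph_after_modifications):

# Standalone sketch of the revised pass behaviour on a toy graph.
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        dead = torch.relu(x)     # never used by the output
        dead2 = torch.sin(dead)  # also dead, and keeps `dead` alive via its input
        return x + 1


gm = torch.fx.symbolic_trace(M())
nodes = list(gm.graph.nodes)
output_node = nodes[-1]

# Reverse order erases `dead2` first, which drops `dead` to zero users so it can
# be erased in the same pass; erase_node would raise if any users remained.
for node in nodes[::-1]:
    if (
        node != output_node
        and len(node.users) == 0
        and len(node.all_input_nodes) > 0
    ):
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm.code)  # only `x + 1` and the output remain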
py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+1 -1

@@ -141,7 +141,7 @@ def __init__(
 name (str): Name for module
 settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
 weight_name_map (dict): Mapping of engine weight name to state_dict weight name
-requires_output_allocator (bool): Whether the engine requires an output allocator
+requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)

 Example:

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

+1 -1

@@ -98,7 +98,7 @@ def __init__(
 name (str): Name for module
 settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
 weight_name_map (dict): Mapping of engine weight name to state_dict weight name
-requires_output_allocator (bool): Whether the engine requires an output allocator
+requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)

 Example:

py/torch_tensorrt/runtime/_cudagraphs.py

+16 -11

@@ -87,17 +87,22 @@ def __enter__(self) -> torch.nn.Module:
 elif "_run_on_gpu" in name:
     num_torch_module += 1

-if num_torch_module > 0 and not disable_cudagraphs:
-    # Set whole cudagraphs mode and returns wrapped module
-    _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
-    # Set new mode for C++
-    if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
-        torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
-
-    logger.debug(
-        "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
-    )
-    return CudaGraphsTorchTensorRTModule(self.compiled_module)
+if num_torch_module > 0:
+    if disable_cudagraphs:
+        raise RuntimeError(
+            "There are converters that require Output Allocator. Please disable CUDA Graphs."
+        )
+    else:
+        # Set whole cudagraphs mode and returns wrapped module
+        _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
+        # Set new mode for C++
+        if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
+
+        logger.debug(
+            "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
+        )
+        return CudaGraphsTorchTensorRTModule(self.compiled_module)
 else:
     if num_trt_module > 0:
         logger.debug("No graph breaks detected, using runtime cudagraphs mode")

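Note on the _cudagraphs.py change above: when the compiled module contains Torch submodules (graph breaks) and disable_cudagraphs is set (presumably earlier in __enter__, when Output Allocator converters are detected), entering the CUDA Graphs context now fails loudly instead of silently wrapping the module. A hedged usage sketch; the toy ReLU model will not actually trip the guard, it only shows where the new RuntimeError would surface, and it assumes torch_tensorrt.runtime.enable_cudagraphs is the context manager whose __enter__ is patched here:

# Usage sketch under the assumptions stated above.
import torch
import torch_tensorrt

model = torch.nn.ReLU().eval().cuda()
inputs = [torch.randn(2, 3).cuda()]
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

try:
    with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
        out = cudagraphs_module(*inputs)
except RuntimeError as err:
    # Raised by the new branch when Output Allocator converters make CUDA Graphs unusable.
    print(f"CUDA Graphs unavailable for this module: {err}")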