Skip to content

feat: Add validators for dynamic shapes in converter registration #2796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 16, 2024
Merged
20 changes: 19 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ConverterSupport:

converter_implementation: ConverterImplSignature
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
dynamic: bool = False


# Dictionary representing Dynamo aten-only converters
Expand All @@ -88,9 +89,11 @@ class ConverterSupport:

def dynamo_tensorrt_converter(
key: Target,
*,
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
dynamic: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
"""Decorator for Dynamo TensorRT Converter

Expand All @@ -116,14 +119,17 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat

# If no capability_validator function is specified, use the default function - always return true
if capability_validator is None:
converter_support = ConverterSupport(converter_implementation=converter)
converter_support = ConverterSupport(
converter_implementation=converter, dynamic=dynamic
)
else:
assert callable(
capability_validator
), "Argument checking function must be callable"
converter_support = ConverterSupport(
converter_implementation=converter,
capability_validator=capability_validator,
dynamic=dynamic,
)

# OpOverloadPackets are only valid if they have a single overload, or
Expand Down Expand Up @@ -323,6 +329,18 @@ def __getitem__(

if isinstance(converters, (list, tuple)):
for candidate in converters:
# TODO: Importing this here avoids circular import issue. One potential fix is moving this function into _ConverterRegistry file.
from torch_tensorrt.dynamo.conversion.converter_utils import (
dynamic_unsupported,
)

has_static_inputs = dynamic_unsupported(node)
# If there are dynamic inputs but the converter doesn't support it explicitly, throw a warning.
if not has_static_inputs and not candidate.dynamic:
logger.warning(
f"The converter for node {node.target} received dynamic shaped inputs but the static version of the converter is being used. Please report this issue at https://github.com/pytorch/TensorRT/issues"
)

if candidate.capability_validator(node):
return (
candidate.converter_implementation,
Expand Down
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def aten_ops_grid(
)


@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default, dynamic=True)
def aten_ops_relu(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -2080,7 +2080,9 @@ def conv_param_validator(conv_node: Node) -> bool:


@dynamo_tensorrt_converter(
torch.ops.aten.convolution.default, capability_validator=conv_param_validator
torch.ops.aten.convolution.default,
capability_validator=conv_param_validator,
dynamic=True,
)
@enforce_tensor_types(
{
Expand Down
13 changes: 11 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,18 @@ def _dynamic_unsupported(

def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
"""Checks if a node itself has Dynamic properties"""
return getattr(
_has_symbolic_sizes_strides = getattr(
subnode.meta["val"], "_has_symbolic_sizes_strides", False
) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))
)

is_shape_dynamic = False
if "val" in subnode.meta:
shape = subnode.meta["val"].size()
is_shape_dynamic = any(
isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
)

return _has_symbolic_sizes_strides or is_shape_dynamic

# Check node value itself
if arg_positions_to_check is None and _is_subnode_dynamic(node):
Expand Down
Loading