Skip to content

feat: support amax dynamo converter #2241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,28 @@ def aten_ops_clone(
name,
args[0],
)


@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on this schema, it seems that the dimension (dim) can also be not present. For instance, the following model creates the subsequent graph:

class argmax(torch.nn.Module):
   def forward(self, x):
        return torch.argmax(x)

"""
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1,), kwargs = {})
    return (argmax,)
"""

If we cannot support this case, you can add a capability_validator function to this decorator, which will note that case as unsupported. Roughly, that could be something like:

def amax_param_validator(amax_node: Node) -> bool:
    return len(amax_node.args) >= 2

@dynamo_tensorrt_converter(
    torch.ops.aten.amax.default, capability_validator=amax_param_validator
)  # type: ignore[misc]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gs-olive Thank you so much George! That's a good catch. I fixed and learned a lot from these details recently 😃

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem - glad to hear it!

def aten_ops_amax(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = args[0]
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(network, input_val, trt.float32, name)

return impl.reduce.amax(
network,
target,
SourceIR.ATEN,
name,
input_val,
args[1],
args_bounds_check(args, 2, replacement=False),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gs-olive can this check be done in a validator?

Copy link
Collaborator

@gs-olive gs-olive Aug 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check can be done in a validator, but in this context it would make the most sense for it to be done here, since we can support cases where this argument is both present and absent.

)
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import re
from typing import List, Optional
Expand All @@ -7,6 +8,7 @@
from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import (
Frameworks,
get_axes_for_reduce_op,
unified_dtype_converter,
)
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
Expand Down Expand Up @@ -157,3 +159,8 @@ def broadcastable(
if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
return False
return True


get_axes_for_reduce_op = functools.partial(
get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
matmul,
normalization,
permutation,
reduce,
select,
shape,
slice,
Expand Down
36 changes: 36 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Optional, Tuple, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def amax(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Union[int, Tuple[int]],
keep_dims: Optional[bool] = False,
out: Optional[Any] = None,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"amax received input {input} that is not part of the TensorRT region!"
)

if dim is None:
raise ValueError("amax requires specifying dimension(s) (dim).")

layer = network.add_reduce(
input,
trt.ReduceOperation.MAX,
axes=get_axes_for_reduce_op(dim),
keep_dims=keep_dims,
)
set_layer_name(layer, target, name)
return layer.get_output(0)
93 changes: 93 additions & 0 deletions tests/py/dynamo/converters/test_amax_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
from harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests


class TestAmaxConverter(DispatchTestCase):
@parameterized.expand(
[
((3, 2, 4), 1, True),
((2, 3, 4, 5), 3, True),
((2, 3, 4, 5), 2, False),
((6, 7, 5, 4, 5), 4, False),
]
)
def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype):
class Amax(nn.Module):
def forward(self, x):
return torch.amax(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randn(*input_shape, dtype=dtype)]
self.run_test(
Amax(),
inputs,
expected_ops={torch.ops.aten.amax.default},
)

@parameterized.expand(
[
((3, 2, 4), [1], True),
((2, 1, 4, 5), [0, 3], True),
((2, 3, 4, 5), [0, 1, 2, 3], False),
((6, 7, 5, 4, 5), [1, 3, 4], False),
]
)
def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype):
class Amax(nn.Module):
def forward(self, x):
return torch.amax(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randn(*input_shape, dtype=dtype)]
self.run_test(
Amax(),
inputs,
expected_ops={torch.ops.aten.amax.default},
)

@parameterized.expand(
[
((3, 2, 4), 1, True, torch.int, 0, 5),
((2, 3, 4, 5), 3, True, torch.int, -10, 10),
((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
]
)
def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
class Amax(nn.Module):
def forward(self, x):
return torch.amax(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
Amax(),
inputs,
expected_ops={torch.ops.aten.amax.default},
check_dtype=False,
)

@parameterized.expand(
[
((3, 2, 4), [1], True, torch.int, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
]
)
def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
class Amax(nn.Module):
def forward(self, x):
return torch.amax(x, dim=dim, keepdim=keep_dims)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
Amax(),
inputs,
expected_ops={torch.ops.aten.amax.default},
check_dtype=False,
)


if __name__ == "__main__":
run_tests()