Skip to content

Commit cd734eb

Browse files
committed
add conv validator
1 parent 17003c4 commit cd734eb

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -814,21 +814,20 @@ def aten_ops_isinf(
814814
)
815815

816816

817-
@dynamo_tensorrt_converter(torch.ops.aten.convolution.default)
817+
def conv_param_validator(conv_node: Node) -> bool:
818+
return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
819+
820+
821+
@dynamo_tensorrt_converter(
822+
torch.ops.aten.convolution.default, capability_validator=conv_param_validator
823+
)
818824
def aten_ops_convolution(
819825
network: TRTNetwork,
820826
target: Target,
821827
args: Tuple[Argument, ...],
822828
kwargs: Dict[str, Argument],
823829
name: str,
824830
) -> Union[TRTTensor, Sequence[TRTTensor]]:
825-
# we do not handle transposed.
826-
if args[6] is True:
827-
raise RuntimeError(f"Target {target} does not support `transposed=True` ")
828-
# we do not handle output_padding.
829-
if args[7] not in ([0], [0, 0], [0, 0, 0]):
830-
raise RuntimeError(f"Target {target} has non-0 output_padding")
831-
832831
return impl.conv.convNd(
833832
network,
834833
target,

0 commit comments

Comments
 (0)