Skip to content

Commit 9846f23

Browse files
committed
use enforce_tensor_types decorator to cast type
1 parent c0cfeb8 commit 9846f23

File tree

1 file changed

+2
-6
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+2
-6
lines changed

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from typing import Optional, Sequence, Union
22

3-
import numpy as np
4-
import torch
53
from torch.fx.node import Target
64
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7-
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
5+
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
86
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
97
from torch_tensorrt.fx.types import TRTTensor
108

@@ -14,11 +12,9 @@ def reshape(
1412
target: Union[Target, str],
1513
source_ir: Optional[SourceIR],
1614
name: str,
17-
input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
15+
input: TRTTensor,
1816
shape: Sequence[int],
1917
) -> TRTTensor:
20-
if not isinstance(input, TRTTensor):
21-
input = get_trt_tensor(ctx, input, f"{name}_input")
2218
layer = ctx.net.add_shuffle(input)
2319
layer.reshape_dims = tuple(shape)
2420
set_layer_name(layer, target, name, source_ir)

0 commit comments

Comments
 (0)