1
1
from typing import Optional , cast
2
2
import math
3
+ import numpy as np
3
4
4
5
from torch .fx .node import Target
5
6
@@ -25,12 +26,6 @@ def slice_op(
25
26
stop : int ,
26
27
step : int ,
27
28
) -> TRTTensor :
28
- if not isinstance (input , TRTTensor ):
29
- raise RuntimeError (
30
- f"slice_tensor received input { input } that is not part "
31
- "of the TensorRT region!"
32
- )
33
-
34
29
ranks = len (input .shape ) + (1 if network .has_implicit_batch_dimension else 0 )
35
30
dim = get_positive_dim (cast (int , dim ), ranks )
36
31
dynamic_shape = has_dynamic_shape (input .shape )
@@ -49,6 +44,22 @@ def slice_op(
49
44
if stop_int == 2 ** 63 - 1 :
50
45
stop_int = input .shape [dim ]
51
46
step_int = cast (int , step )
47
+
48
+ if isinstance (input , np .ndarray ):
49
+ tensor_to_freeze = np .take (
50
+ input , np .arange (start_int , stop_int , step_int ), axis = dim
51
+ )
52
+ # TODO: Fix naming for constant tensors
53
+ frozen_trt_tensor = get_trt_tensor (network , tensor_to_freeze , name )
54
+ return frozen_trt_tensor
55
+
56
+ if not isinstance (input , TRTTensor ):
57
+ raise RuntimeError (
58
+ f"slice_tensor received input { input } that is not part "
59
+ "of the TensorRT region!"
60
+ )
61
+
62
+ # TRT Input Formatting
52
63
start = [0 ] * len (input .shape )
53
64
start [dim ] = start_int
54
65
stride = [1 ] * len (start )
0 commit comments