|
1 |
| -from typing import Optional, cast |
| 1 | +from typing import Optional, cast, Union, Sequence |
| 2 | +import tensorrt as trt |
2 | 3 |
|
3 | 4 | import numpy as np
|
4 | 5 | from torch.fx.node import Target
|
|
10 | 11 | to_numpy,
|
11 | 12 | )
|
12 | 13 | from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
|
| 14 | +from torch_tensorrt.fx.converters.converter_utils import set_layer_name |
13 | 15 |
|
14 | 16 |
|
15 | 17 | def select(
|
@@ -61,3 +63,43 @@ def select(
|
61 | 63 | if len(out.shape) != 1:
|
62 | 64 | layer = network.add_shuffle(out)
|
63 | 65 | return layer.get_output(0)
|
| 66 | + |
| 67 | + |
| 68 | +def select( |
| 69 | + network: TRTNetwork, |
| 70 | + target: Target, |
| 71 | + source_ir: Optional[SourceIR], |
| 72 | + name: str, |
| 73 | + input: TRTTensor, |
| 74 | + index: Union[TRTTensor, Sequence[TRTTensor]] |
| 75 | +) -> TRTTensor: |
| 76 | + adv_indx_indices = [] |
| 77 | + tensor_indices = [] |
| 78 | + |
| 79 | + for i in len(index): |
| 80 | + ind = index[i] |
| 81 | + #FIXME: check if the datatype for the indices needs to be casted to INT32 |
| 82 | + #TRTInterpretor should take care |
| 83 | + adv_indx_indices.append(i) |
| 84 | + tensor_indices.append(ind) |
| 85 | + |
| 86 | + if not tensor_indices: |
| 87 | + identity_layer = network.add_identity(input) |
| 88 | + identity_layer.set_output_type(0, trt.int32) |
| 89 | + set_layer_name(identity_layer, target, name + "_index_identity", source_ir) |
| 90 | + return identity_layer.get_output(0) |
| 91 | + elif (len(tensor_indices) == 1): |
| 92 | + indices_tensor = tensor_indices[0] |
| 93 | + gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0]) |
| 94 | + set_layer_name(gather_layer, target, name + "_index_gather", source_ir) |
| 95 | + return gather_layer.get_output(0) |
| 96 | + else: |
| 97 | + input_shape = input.shape |
| 98 | + rank = len(input_shape) |
| 99 | + adv_indx_count = len(adv_indx_indices) |
| 100 | + input_shape_layer = network.add_shape(input) |
| 101 | + set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir) |
| 102 | + input_shape_tensor = input_shape_layer.get_output(0) |
| 103 | + return input_shape_tensor.get_output(0) |
| 104 | + |
| 105 | + |
0 commit comments