|
1 | 1 | import functools
|
2 | 2 | import logging
|
3 | 3 | import re
|
4 |
| -from typing import Any, List, Optional, Tuple |
| 4 | +from typing import Any, List, Optional, Tuple, Union |
5 | 5 |
|
| 6 | +import numpy as np |
6 | 7 | import tensorrt as trt
|
7 | 8 | import torch
|
8 | 9 | from torch.fx.node import Target
|
9 | 10 | from torch_tensorrt.fx.converters.converter_utils import (
|
10 | 11 | Frameworks,
|
11 | 12 | get_axes_for_reduce_op,
|
| 13 | + to_numpy, |
12 | 14 | unified_dtype_converter,
|
13 | 15 | )
|
14 | 16 | from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
|
@@ -185,11 +187,85 @@ def extend_attr_to_tuple(
|
185 | 187 |
|
186 | 188 | if isinstance(val, list):
|
187 | 189 | val = tuple(val)
|
188 |
| - return val |
| 190 | + |
| 191 | + if isinstance(val, tuple): |
| 192 | + return val |
| 193 | + else: |
| 194 | + raise AssertionError(f"Could not extend attribute {val}") |
189 | 195 |
|
190 | 196 |
|
191 |
| -def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor): |
| 197 | +def cast_int_or_float_to_bool( |
| 198 | + network: TRTNetwork, name: str, tensor: TRTTensor |
| 199 | +) -> TRTTensor: |
192 | 200 | if tensor.dtype != trt.bool:
|
193 | 201 | return cast_trt_tensor(network, tensor, trt.bool, name)
|
194 | 202 |
|
195 | 203 | return tensor
|
| 204 | + |
| 205 | + |
| 206 | +def create_constant( |
| 207 | + network: TRTNetwork, |
| 208 | + value: Union[int, float, np.ndarray, torch.Tensor], |
| 209 | + name: str, |
| 210 | + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]], |
| 211 | +) -> TRTTensor: |
| 212 | + """ |
| 213 | + Add a TensorRT constant layer whose value is `value` to `network`. |
| 214 | + Args: |
| 215 | + network (TRTNetwork): A TensorRT network to which we want to add |
| 216 | + a constant layer. |
| 217 | + value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array, |
| 218 | + or a PyTorch tensor that will be used as value of the added TensorRT Constant layer. |
| 219 | + name (str): Name of the added TensorRT Constant layer. |
| 220 | + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): |
| 221 | + If a dtype is given, we will convert the type of the given `value` to this dtype. |
| 222 | + Returns: |
| 223 | + A TensorRT ITensor that represents the given value. |
| 224 | + """ |
| 225 | + constant = network.add_constant( |
| 226 | + (1,) if isinstance(value, (int, float)) else value.shape, |
| 227 | + to_numpy(value, dtype).copy(), |
| 228 | + ) |
| 229 | + constant.name = name |
| 230 | + return constant.get_output(0) |
| 231 | + |
| 232 | + |
| 233 | +def get_trt_tensor( |
| 234 | + network: TRTNetwork, |
| 235 | + input_val: Any, |
| 236 | + name: str, |
| 237 | + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, |
| 238 | +) -> TRTTensor: |
| 239 | + """ |
| 240 | + Given a value of random type, we try to convert it to a TensorRT ITensor. |
| 241 | + An runtime error is raised if we're not able to do that. |
| 242 | + Args: |
| 243 | + network (TRTNetwork): A TensorRT network. If we want to |
| 244 | + add a TensorRT Constant layer, we will add it to this network. |
| 245 | + input_val (Any): An value that we want to convert to a TensorRT ITensor. |
| 246 | + name (str): The name of the created TensorRT Constant layer if there's |
| 247 | + one. |
| 248 | + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): |
| 249 | + If dtype is provided, the given value will be converted to this dtype. |
| 250 | + Returns: |
| 251 | + A TensorRT ITensor that represents the given value. |
| 252 | + """ |
| 253 | + # TRT can not add constant for bool type. We do a work around to 1) cast it to int and 2)cast to bool later |
| 254 | + # This is useful for logical operations which require input to be bool type |
| 255 | + if isinstance(input_val, bool): |
| 256 | + input_val = int(input_val) |
| 257 | + elif isinstance(input_val, torch.Tensor) and ( |
| 258 | + input_val.dtype == torch.bool or input_val.dtype == torch.int64 |
| 259 | + ): |
| 260 | + input_val = input_val.to(torch.int32) |
| 261 | + elif isinstance(input_val, np.ndarray) and ( |
| 262 | + input_val.dtype == np.bool_ or input_val.dtype == np.int64 |
| 263 | + ): |
| 264 | + input_val = input_val.astype(np.int32) |
| 265 | + |
| 266 | + if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)): |
| 267 | + return create_constant(network, input_val, name, dtype) |
| 268 | + elif isinstance(input_val, TRTTensor): |
| 269 | + return input_val |
| 270 | + else: |
| 271 | + raise AssertionError(f"Cannot convert {input_val} to TRT constant") |
0 commit comments