1
1
from typing import Optional
2
2
3
+ import numpy as np
3
4
import tensorrt as trt
4
5
import torch
5
6
from torch .fx .node import Target
@@ -23,16 +24,6 @@ def where(
23
24
other : TRTTensor ,
24
25
condition : TRTTensor ,
25
26
) -> TRTTensor :
26
- input_dim = len (tuple (input .shape ))
27
- other_dim = len (tuple (other .shape ))
28
- condition_dim = len (tuple (condition .shape ))
29
-
30
- if type (input ) != TRTTensor :
31
- assert type (input ) is torch .Tensor , f"value { input } is not torch.Tensor!"
32
-
33
- if type (other ) != TRTTensor :
34
- assert type (other ) is torch .Tensor , f"value { other } is not torch.Tensor!"
35
-
36
27
if not (broadcastable (input , other )):
37
28
assert "The two torch tensors should be broadcastable"
38
29
@@ -49,33 +40,37 @@ def where(
49
40
x_shape = list (input .shape )
50
41
y_shape = list (other .shape )
51
42
condition_shape = list (condition .shape )
43
+
52
44
output_shape = list (torch .broadcast_shapes (condition_shape , x_shape , y_shape ))
53
45
54
46
# expand shape
55
- if type (condition ) != TRTTensor :
56
- assert condition .dtype == torch .bool , "condition dtype is not bool"
47
+ if not isinstance (condition , TRTTensor ) :
48
+ assert condition .dtype in ( torch .bool , np . bool_ ) , "condition dtype is not bool"
57
49
if condition_shape != output_shape :
58
- condition .expand (output_shape )
59
- condition = condition .to (torch .int32 )
60
- condition_const = get_trt_tensor (ctx , condition , f"{ name } _condition" )
61
- condition_layer = ctx .net .add_identity (condition_const )
62
- condition_layer .set_output_type (0 , trt .bool )
63
- set_layer_name (condition_layer , target , f"{ name } _condition" )
64
- condition_val = condition_layer .get_output (0 )
50
+ condition = (
51
+ condition .expand (output_shape )
52
+ if isinstance (condition , torch .Tensor )
53
+ else np .broadcast_to (condition , output_shape )
54
+ )
55
+ condition_val = get_trt_tensor (ctx , condition , f"{ name } _condition" )
65
56
else :
66
57
assert condition .dtype == trt .bool , "mask dtype is not bool!"
67
- if len ( condition_shape ) != condition_dim :
58
+ if condition_shape != output_shape :
68
59
condition_val = expand (
69
60
ctx , target , source_ir , f"{ name } _expand" , condition , output_shape
70
61
)
71
62
else :
72
63
condition_val = condition
73
64
74
- if type (input ) != TRTTensor :
65
+ if not isinstance (input , TRTTensor ) :
75
66
if x_shape != output_shape :
76
67
# special case where 1 element in input
77
68
if len (input .shape ) == 0 :
78
- input = input .unsqueeze (0 )
69
+ input = (
70
+ input .unsqueeze (0 )
71
+ if isinstance (input , torch .Tensor )
72
+ else np .expand_dims (input , axis = 0 )
73
+ )
79
74
input = input .expand (output_shape )
80
75
x_val = get_trt_tensor (ctx , input , f"{ name } _x" )
81
76
else :
@@ -85,11 +80,15 @@ def where(
85
80
ctx , target , source_ir , f"{ name } _x_expand" , input , output_shape
86
81
)
87
82
88
- if type (other ) != TRTTensor :
83
+ if not isinstance (other , TRTTensor ) :
89
84
if y_shape != output_shape :
90
85
# special case where 1 element in other
91
86
if len (other .shape ) == 0 :
92
- other = other .unsqueeze (0 )
87
+ other = (
88
+ other .unsqueeze (0 )
89
+ if isinstance (other , torch .Tensor )
90
+ else np .expand_dims (other , axis = 0 )
91
+ )
93
92
other = other .expand (output_shape )
94
93
y_val = get_trt_tensor (ctx , other , f"{ name } _y" )
95
94
else :
0 commit comments