18
18
from torch .fx .immutable_collections import immutable_list
19
19
from torch .fx .node import Argument , Target
20
20
21
- from ..utils import get_dynamic_dims , torch_dtype_from_trt , torch_dtype_to_trt
21
+ from ..utils import get_dynamic_dims , unified_dtype_converter , Frameworks
22
22
23
23
from .converter_utils import * # noqa: F403
24
24
from torch_tensorrt .fx .passes .lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
400
400
)
401
401
402
402
# cast value to TRTensor
403
- dt = torch_dtype_from_trt (input_val .dtype )
403
+ dt = unified_dtype_converter (input_val .dtype , Frameworks . TORCH )
404
404
value = 0 if value == None else value
405
405
value_const = get_trt_tensor (
406
406
network , torch .tensor ([value ], dtype = dt ), f"{ name } _value"
@@ -1561,7 +1561,7 @@ def acc_ops_to_dtype(
1561
1561
input_t = get_trt_tensor (network , input_val , f"{ name } _input_t" )
1562
1562
if input_dtype :
1563
1563
if isinstance (input_dtype , torch .dtype ):
1564
- input_dtype = torch_dtype_to_trt (input_dtype )
1564
+ input_dtype = unified_dtype_converter (input_dtype , Frameworks . TRT )
1565
1565
input_t = type_cast (network , target , f"{ name } _input" , input_t , input_dtype )
1566
1566
return input_t
1567
1567
@@ -1822,7 +1822,7 @@ def acc_ops_logical_xor(
1822
1822
# f"isinf received input {input_t} that is not part "
1823
1823
# "of the TensorRT region!"
1824
1824
# )
1825
- # tdtype = torch_dtype_from_trt (input_t.dtype)
1825
+ # tdtype = unified_dtype_converter (input_t.dtype, Frameworks.TORCH )
1826
1826
1827
1827
# inf_t = torch.ones(tuple(input_t.shape))
1828
1828
# inf_t = inf_t * float("inf")
@@ -1860,7 +1860,7 @@ def acc_ops_any(
1860
1860
1861
1861
if input_t .dtype in (trt .float32 , trt .float16 , trt .int32 ):
1862
1862
comp_t = torch .zeros (tuple ([* input_t .shape ])).to (
1863
- torch_dtype_from_trt (input_t .dtype )
1863
+ unified_dtype_converter (input_t .dtype , Frameworks . TORCH )
1864
1864
)
1865
1865
comp_t = get_trt_tensor (network , comp_t , f"{ name } _comp_t" )
1866
1866
kwargs_new = {"input" : input_t , "other" : comp_t }
@@ -2749,7 +2749,7 @@ def acc_ops_masked_fill_tensor(
2749
2749
if type (value_t ) is torch .Tensor :
2750
2750
value_t = value_t .cpu ().numpy ()
2751
2751
# cast to input type
2752
- input_dtype = torch_dtype_from_trt (input_t .dtype )
2752
+ input_dtype = unified_dtype_converter (input_t .dtype , Frameworks . TORCH )
2753
2753
value_t = (torch .ones (shape ) * value_t ).to (input_dtype )
2754
2754
input_val = get_trt_tensor (network , input_t , f"{ name } _input" )
2755
2755
value_val = get_trt_tensor (network , value_t , f"{ name } _input" )
@@ -2883,7 +2883,11 @@ def add_clamp(network, input, val, op, name):
2883
2883
# clamping scalar
2884
2884
acc_ops_clamp_trt = get_trt_tensor (
2885
2885
network ,
2886
- squeeze_left (torch .tensor ([val ], dtype = torch_dtype_from_trt (input .dtype ))),
2886
+ squeeze_left (
2887
+ torch .tensor (
2888
+ [val ], dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH )
2889
+ )
2890
+ ),
2887
2891
f"{ name } _clamp_{ val } " ,
2888
2892
)
2889
2893
else :
@@ -2892,7 +2896,8 @@ def add_clamp(network, input, val, op, name):
2892
2896
(
2893
2897
val
2894
2898
* torch .ones (
2895
- acc_ops_clamp_shape , dtype = torch_dtype_from_trt (input .dtype )
2899
+ acc_ops_clamp_shape ,
2900
+ dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH ),
2896
2901
)
2897
2902
)
2898
2903
.cpu ()
@@ -3538,7 +3543,9 @@ def acc_ops_cumsum(
3538
3543
iterator = loop .add_iterator (input_val , dim , False )
3539
3544
data = iterator .get_output (0 )
3540
3545
new_dims = tuple (data .shape )
3541
- zero_tensor = torch .zeros (new_dims , dtype = trt_dtype_to_torch_dtype (input_val .dtype ))
3546
+ zero_tensor = torch .zeros (
3547
+ new_dims , dtype = unified_dtype_converter (input_val .dtype , Frameworks .TORCH )
3548
+ )
3542
3549
zero_tensor = network .add_constant (
3543
3550
zero_tensor .shape , to_numpy (zero_tensor )
3544
3551
).get_output (0 )
@@ -3689,7 +3696,7 @@ def acc_ops_new_ones(
3689
3696
dtype_val = kwargs .get ("dtype" )
3690
3697
if dtype_val is None :
3691
3698
dtype_val = input_val .dtype
3692
- dtype_val = torch_dtype_from_trt (dtype_val )
3699
+ dtype_val = unified_dtype_converter (dtype_val , Frameworks . TORCH )
3693
3700
3694
3701
device_val = kwargs .get ("device" )
3695
3702
assert (
@@ -3713,7 +3720,7 @@ def acc_ops_new_empty(
3713
3720
dtype_val = kwargs .get ("dtype" )
3714
3721
if dtype_val is None :
3715
3722
dtype_val = input_val .dtype
3716
- dtype_val = torch_dtype_from_trt (dtype_val )
3723
+ dtype_val = unified_dtype_converter (dtype_val , Frameworks . TORCH )
3717
3724
3718
3725
device_val = kwargs .get ("device" )
3719
3726
assert (
0 commit comments