Skip to content

Commit 2228031

Browse files
committed
fix: Move aten.neg test case
1 parent 7a4288e commit 2228031

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,11 @@ def neg(
393393
name: str,
394394
input_val: TRTTensor,
395395
) -> TRTTensor:
396+
if (isinstance(input_val, TRTTensor)) and (
397+
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
398+
):
399+
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
400+
396401
return convert_unary(
397402
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
398403
)

tests/py/dynamo/converters/test_neg_aten.py renamed to tests/py/dynamo/conversion/test_neg_aten.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
55
from torch_tensorrt import Input
6-
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
6+
7+
from .harness import DispatchTestCase
78

89

910
class TestNegConverter(DispatchTestCase):

0 commit comments

Comments
 (0)