@@ -46,6 +46,30 @@ def test_compile_script(self):
46
46
self .assertTrue (same < 2e-3 )
47
47
48
48
49
+ class TestPTtoTRTtoPT (ModelTestCase ):
50
+
51
+ def setUp (self ):
52
+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
53
+ self .ts_model = torch .jit .script (self .model )
54
+
55
+ def test_pt_to_trt_to_pt (self ):
56
+ compile_spec = {
57
+ "input_shapes" : [self .input .shape ],
58
+ "device" : {
59
+ "device_type" : trtorch .DeviceType .GPU ,
60
+ "gpu_id" : 0 ,
61
+ "dla_core" : 0 ,
62
+ "allow_gpu_fallback" : False ,
63
+ "disable_tf32" : False
64
+ }
65
+ }
66
+
67
+ trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , compile_spec )
68
+ trt_mod = trtorch .embed_engine_in_new_module (trt_engine )
69
+ same = (trt_mod (self .input ) - self .ts_model (self .input )).abs ().max ()
70
+ self .assertTrue (same < 2e-3 )
71
+
72
+
49
73
class TestCheckMethodOpSupport (unittest .TestCase ):
50
74
51
75
def setUp (self ):
@@ -59,13 +83,13 @@ def test_check_support(self):
59
83
class TestLoggingAPIs (unittest .TestCase ):
60
84
61
85
def test_logging_prefix (self ):
62
- new_prefix = "TEST "
86
+ new_prefix = "Python API Test: "
63
87
trtorch .logging .set_logging_prefix (new_prefix )
64
88
logging_prefix = trtorch .logging .get_logging_prefix ()
65
89
self .assertEqual (new_prefix , logging_prefix )
66
90
67
91
def test_reportable_log_level (self ):
68
- new_level = trtorch .logging .Level .Warning
92
+ new_level = trtorch .logging .Level .Error
69
93
trtorch .logging .set_reportable_log_level (new_level )
70
94
level = trtorch .logging .get_reportable_log_level ()
71
95
self .assertEqual (new_level , level )
@@ -78,10 +102,11 @@ def test_is_colored_output_on(self):
78
102
79
103
def test_suite ():
80
104
suite = unittest .TestSuite ()
105
+ suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
81
106
suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
82
107
suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
108
+ suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
83
109
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
84
- suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
85
110
86
111
return suite
87
112
0 commit comments