5
5
6
6
# Use system installed Python packages
7
7
PYT_PATH = '/opt/conda/lib/python3.8/site-packages' if not 'PYT_PATH' in os .environ else os .environ ["PYT_PATH" ]
8
+ print (f"Using python path { PYT_PATH } " )
8
9
9
10
# Set the root directory to the directory of the noxfile unless the user wants to
10
11
# TOP_DIR
11
12
TOP_DIR = os .path .dirname (os .path .realpath (__file__ )) if not 'TOP_DIR' in os .environ else os .environ ["TOP_DIR" ]
13
+ print (f"Test root directory { TOP_DIR } " )
12
14
13
15
# Set the USE_CXX11=1 to use cxx11_abi
14
16
USE_CXX11 = 0 if not 'USE_CXX11' in os .environ else os .environ ["USE_CXX11" ]
17
+ if USE_CXX11 :
18
+ print ("Using cxx11 abi" )
15
19
16
20
# Set the USE_HOST_DEPS=1 to use host dependencies for tests
17
21
USE_HOST_DEPS = 0 if not 'USE_HOST_DEPS' in os .environ else os .environ ["USE_HOST_DEPS" ]
22
+ if USE_HOST_DEPS :
23
+ print ("Using dependencies from host python" )
18
24
19
25
SUPPORTED_PYTHON_VERSIONS = ["3.7" , "3.8" , "3.9" , "3.10" ]
20
26
@@ -58,6 +64,12 @@ def download_datasets(session):
58
64
59
65
def train_model (session ):
60
66
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
67
+ session .install ("-r" , "requirements.txt" )
68
+ if os .path .exists ('vgg16_ckpts/ckpt_epoch25.pth' ):
69
+ session .run_always ('python' ,
70
+ 'export_ckpt.py' ,
71
+ 'vgg16_ckpts/ckpt_epoch25.pth' )
72
+ return
61
73
if USE_HOST_DEPS :
62
74
session .run_always ('python' ,
63
75
'main.py' ,
@@ -140,14 +152,14 @@ def run_base_tests(session):
140
152
print ("Running basic tests" )
141
153
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
142
154
tests = [
143
- "test_api.py " ,
144
- "test_to_backend_api.py" ,
155
+ "api " ,
156
+ "integrations/ test_to_backend_api.py" ,
145
157
]
146
158
for test in tests :
147
159
if USE_HOST_DEPS :
148
- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
160
+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
149
161
else :
150
- session .run_always ("python " , test )
162
+ session .run_always ("pytest " , test )
151
163
152
164
def run_accuracy_tests (session ):
153
165
print ("Running accuracy tests" )
@@ -169,23 +181,23 @@ def copy_model(session):
169
181
session .run_always ('cp' ,
170
182
'-rpf' ,
171
183
os .path .join (TOP_DIR , src_file ),
172
- os .path .join (TOP_DIR , str ('tests/py /' ) + file_name ),
184
+ os .path .join (TOP_DIR , str ('tests/modules /' ) + file_name ),
173
185
external = True )
174
186
175
187
def run_int8_accuracy_tests (session ):
176
188
print ("Running accuracy tests" )
177
189
copy_model (session )
178
190
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
179
191
tests = [
180
- "test_ptq_dataloader_calibrator .py" ,
181
- "test_ptq_to_backend .py" ,
182
- "test_qat_trt_accuracy.py " ,
192
+ "ptq/test_ptq_to_backend .py" ,
193
+ "ptq/test_ptq_dataloader_calibrator .py" ,
194
+ "qat/ " ,
183
195
]
184
196
for test in tests :
185
197
if USE_HOST_DEPS :
186
- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
198
+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
187
199
else :
188
- session .run_always ("python " , test )
200
+ session .run_always ("pytest " , test )
189
201
190
202
def run_trt_compatibility_tests (session ):
191
203
print ("Running TensorRT compatibility tests" )
@@ -197,9 +209,9 @@ def run_trt_compatibility_tests(session):
197
209
]
198
210
for test in tests :
199
211
if USE_HOST_DEPS :
200
- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
212
+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
201
213
else :
202
- session .run_always ("python " , test )
214
+ session .run_always ("pytest " , test )
203
215
204
216
def run_dla_tests (session ):
205
217
print ("Running DLA tests" )
@@ -209,9 +221,9 @@ def run_dla_tests(session):
209
221
]
210
222
for test in tests :
211
223
if USE_HOST_DEPS :
212
- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
224
+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
213
225
else :
214
- session .run_always ("python " , test )
226
+ session .run_always ("pytest " , test )
215
227
216
228
def run_multi_gpu_tests (session ):
217
229
print ("Running multi GPU tests" )
@@ -221,9 +233,9 @@ def run_multi_gpu_tests(session):
221
233
]
222
234
for test in tests :
223
235
if USE_HOST_DEPS :
224
- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
236
+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
225
237
else :
226
- session .run_always ("python " , test )
238
+ session .run_always ("pytest " , test )
227
239
228
240
def run_l0_api_tests (session ):
229
241
if not USE_HOST_DEPS :
@@ -245,7 +257,6 @@ def run_l1_accuracy_tests(session):
245
257
if not USE_HOST_DEPS :
246
258
install_deps (session )
247
259
install_torch_trt (session )
248
- download_models (session )
249
260
download_datasets (session )
250
261
train_model (session )
251
262
run_accuracy_tests (session )
@@ -255,7 +266,6 @@ def run_l1_int8_accuracy_tests(session):
255
266
if not USE_HOST_DEPS :
256
267
install_deps (session )
257
268
install_torch_trt (session )
258
- download_models (session )
259
269
download_datasets (session )
260
270
train_model (session )
261
271
finetune_model (session )
@@ -313,4 +323,8 @@ def l2_multi_gpu_tests(session):
313
323
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
314
324
def download_test_models (session ):
315
325
"""Grab all the models needed for testing"""
326
+ try :
327
+ import torch
328
+ except ModuleNotFoundError :
329
+ install_deps (session )
316
330
download_models (session )
0 commit comments