@@ -88,6 +88,7 @@ def __init__(self):
88
88
def forward (self , x ):
89
89
return F .adaptive_avg_pool2d (x , (5 , 5 ))
90
90
91
+
91
92
# Sample Nested Module (for module-level fallback testing)
92
93
class ModuleFallbackSub (nn .Module ):
93
94
@@ -99,6 +100,7 @@ def __init__(self):
99
100
def forward (self , x ):
100
101
return self .relu (self .conv (x ))
101
102
103
+
102
104
class ModuleFallbackMain (nn .Module ):
103
105
104
106
def __init__ (self ):
@@ -110,6 +112,7 @@ def __init__(self):
110
112
def forward (self , x ):
111
113
return self .relu (self .conv (self .layer1 (x )))
112
114
115
+
113
116
# Sample Looping Modules (for loop fallback testing)
114
117
class LoopFallbackEval (nn .Module ):
115
118
@@ -122,6 +125,7 @@ def forward(self, x):
122
125
add_list = torch .cat ((add_list , torch .tensor ([x .shape [1 ]]).to (x .device )), 0 )
123
126
return x + add_list
124
127
128
+
125
129
class LoopFallbackNoEval (nn .Module ):
126
130
127
131
def __init__ (self ):
@@ -132,6 +136,7 @@ def forward(self, x):
132
136
x = x + torch .ones_like (x )
133
137
return x
134
138
139
+
135
140
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
136
141
class FallbackIf (torch .nn .Module ):
137
142
@@ -156,21 +161,23 @@ def forward(self, x):
156
161
x = self .conv1 (x )
157
162
return x
158
163
164
+
159
165
class ModelManifest :
166
+
160
167
def __init__ (self ):
161
168
self .version_matches = False
162
169
if not os .path .exists (MANIFEST_FILE ) or os .stat (MANIFEST_FILE ).st_size == 0 :
163
170
self .manifest = {}
164
- self .manifest .update ({'version' : torch_version })
171
+ self .manifest .update ({'version' : torch_version })
165
172
else :
166
173
with open (MANIFEST_FILE , 'r' ) as f :
167
174
self .manifest = json .load (f )
168
175
if self .manifest ['version' ] == torch_version :
169
176
self .version_matches = True
170
177
else :
171
- print ("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models" .format (torch_version , self .manifest ['version' ]))
178
+ print ("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models" .format (
179
+ torch_version , self .manifest ['version' ]))
172
180
self .manifest ["version" ] = torch_version
173
-
174
181
175
182
def download (self , models ):
176
183
if self .version_matches :
@@ -194,13 +201,13 @@ def write(self, manifest_record):
194
201
record = json .dumps (manifest_record )
195
202
f .write (record )
196
203
f .truncate ()
197
-
204
+
198
205
def get_manifest (self ):
199
206
return self .manifest
200
-
207
+
201
208
def if_version_matches (self ):
202
209
return self .version_matches
203
-
210
+
204
211
def get (self , n , m ):
205
212
print ("Downloading {}" .format (n ))
206
213
m ["model" ] = m ["model" ].eval ().cuda ()
@@ -214,10 +221,11 @@ def get(self, n, m):
214
221
if m ["path" ] == "both" or m ["path" ] == "script" :
215
222
script_model = torch .jit .script (m ["model" ])
216
223
torch .jit .save (script_model , script_filename )
217
-
218
- self .manifest .update ({n : [traced_filename , script_filename ]})
219
224
220
- def generate_custom_models (manifest , version_matches = False ):
225
+ self .manifest .update ({n : [traced_filename , script_filename ]})
226
+
227
+
228
+ def generate_custom_models (manifest , version_matches = False ):
221
229
# Pool
222
230
traced_pool_name = "pooling_traced.jit.pt"
223
231
if not (version_matches and os .path .exists (traced_pool_name )):
@@ -248,7 +256,8 @@ def generate_custom_models(manifest, version_matches = False):
248
256
loop_fallback_no_eval_model = LoopFallbackNoEval ().eval ().cuda ()
249
257
loop_fallback_no_eval_script_model = torch .jit .script (loop_fallback_no_eval_model )
250
258
torch .jit .save (loop_fallback_no_eval_script_model , scripted_loop_fallback_no_eval_name )
251
- manifest .update ({"torchtrt_loop_fallback_no_eval" : [scripted_loop_fallback_name , scripted_loop_fallback_no_eval_name ]})
259
+ manifest .update (
260
+ {"torchtrt_loop_fallback_no_eval" : [scripted_loop_fallback_name , scripted_loop_fallback_no_eval_name ]})
252
261
253
262
# Conditional
254
263
scripted_conditional_name = "conditional_scripted.jit.pt"
@@ -287,7 +296,7 @@ def generate_custom_models(manifest, version_matches = False):
287
296
288
297
traced_model = torch .jit .trace (model , [tokens_tensor , segments_tensors ])
289
298
torch .jit .save (traced_model , traced_bert_uncased_name )
290
- manifest .update ({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name ]})
299
+ manifest .update ({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name ]})
291
300
292
301
293
302
manifest = ModelManifest ()
@@ -302,4 +311,4 @@ def generate_custom_models(manifest, version_matches = False):
302
311
generate_custom_models (manifest_record , manifest .if_version_matches ())
303
312
304
313
# Update the manifest file
305
- manifest .write (manifest_record )
314
+ manifest .write (manifest_record )
0 commit comments