@@ -217,78 +217,76 @@ def get(self, n, m):
 
         self.manifest.update({n: [traced_filename, script_filename]})
 
-def export_model(model, model_name, version_matches):
-    if version_matches and os.path.exists(model_name):
-        print("Skipping model {}".format(model_name))
-    else:
-        print("Saving model {}".format(model_name))
-        torch.jit.save(model, model_name)
-
-
-def generate_custom_models(manifest, matches=False):
+def generate_custom_models(manifest, version_matches=False):
 
     # Pool
-    model = Pool().eval().cuda()
-    x = torch.ones([1, 3, 10, 10]).cuda()
-
-    trace_model = torch.jit.trace(model, x)
     traced_pool_name = "pooling_traced.jit.pt"
-    export_model(trace_model, traced_pool_name, matches)
+    if not (version_matches and os.path.exists(traced_pool_name)):
+        model = Pool().eval().cuda()
+        x = torch.ones([1, 3, 10, 10]).cuda()
+
+        trace_model = torch.jit.trace(model, x)
+        torch.jit.save(trace_model, traced_pool_name)
 
     manifest.update({"torchtrt_pooling": [traced_pool_name]})
 
     # Module fallback
-    module_fallback_model = ModuleFallbackMain().eval().cuda()
-    module_fallback_script_model = torch.jit.script(module_fallback_model)
     scripted_module_fallback_name = "module_fallback_scripted.jit.pt"
-    export_model(module_fallback_script_model, scripted_module_fallback_name, matches)
+    if not (version_matches and os.path.exists(scripted_module_fallback_name)):
+        module_fallback_model = ModuleFallbackMain().eval().cuda()
+        module_fallback_script_model = torch.jit.script(module_fallback_model)
+        torch.jit.save(module_fallback_script_model, scripted_module_fallback_name)
 
     manifest.update({"torchtrt_module_fallback": [scripted_module_fallback_name]})
 
     # Loop Fallback
-    loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
-    loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
     scripted_loop_fallback_name = "loop_fallback_eval_scripted.jit.pt"
-    export_model(loop_fallback_eval_script_model, scripted_loop_fallback_name, matches)
+    if not (version_matches and os.path.exists(scripted_loop_fallback_name)):
+        loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
+        loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
+        torch.jit.save(loop_fallback_eval_script_model, scripted_loop_fallback_name)
 
-    loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
-    loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
     scripted_loop_fallback_no_eval_name = "loop_fallback_no_eval_scripted.jit.pt"
-    export_model(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name, matches)
+    if not (version_matches and os.path.exists(scripted_loop_fallback_no_eval_name)):
+        loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
+        loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
+        torch.jit.save(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name)
 
     manifest.update({"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
 
     # Conditional
-    conditional_model = FallbackIf().eval().cuda()
-    conditional_script_model = torch.jit.script(conditional_model)
     scripted_conditional_name = "conditional_scripted.jit.pt"
-    export_model(conditional_script_model, scripted_conditional_name, matches)
+    if not (version_matches and os.path.exists(scripted_conditional_name)):
+        conditional_model = FallbackIf().eval().cuda()
+        conditional_script_model = torch.jit.script(conditional_model)
+        torch.jit.save(conditional_script_model, scripted_conditional_name)
 
     manifest.update({"torchtrt_conditional": [scripted_conditional_name]})
 
     # BERT model
-    enc = BertTokenizer.from_pretrained("bert-base-uncased")
-    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
-    tokenized_text = enc.tokenize(text)
-    masked_index = 8
-    tokenized_text[masked_index] = "[MASK]"
-    indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
-    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
-    tokens_tensor = torch.tensor([indexed_tokens])
-    segments_tensors = torch.tensor([segments_ids])
-    dummy_input = [tokens_tensor, segments_tensors]
-
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
-    )
-
-    model = BertModel(config)
-    model.eval()
-    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
-
     traced_bert_uncased_name = "bert_case_uncased_traced.jit.pt"
-    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
-    export_model(traced_model, traced_bert_uncased_name, matches)
+    if not (version_matches and os.path.exists(traced_bert_uncased_name)):
+        enc = BertTokenizer.from_pretrained("bert-base-uncased")
+        text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        tokenized_text = enc.tokenize(text)
+        masked_index = 8
+        tokenized_text[masked_index] = "[MASK]"
+        indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
+        segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
+        tokens_tensor = torch.tensor([indexed_tokens])
+        segments_tensors = torch.tensor([segments_ids])
+        dummy_input = [tokens_tensor, segments_tensors]
+
+        config = BertConfig(
+            vocab_size_or_config_json_file=32000,
+            hidden_size=768,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            intermediate_size=3072,
+            torchscript=True,
+        )
+
+        model = BertModel(config)
+        model.eval()
+        model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
+
+        traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
+        torch.jit.save(traced_model, traced_bert_uncased_name)
 
     manifest.update({"torchtrt_bert_case_uncased": [traced_bert_uncased_name]})
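Note on the refactor: the old export_model helper built every model before deciding whether to save it, so a cache hit still paid for model construction (and, in the BERT case, the weight download). Inlining the guard moves construction behind the os.path.exists check. If the repeated guards grow tiresome, the same laziness could be kept with a helper that takes a zero-argument builder; a minimal sketch, where export_if_needed is a hypothetical name and not part of this commit:

    def export_if_needed(build_fn, model_name, version_matches):
        # build_fn is only invoked on a cache miss, so no model is
        # instantiated (and no weights downloaded) when the file is reused.
        if version_matches and os.path.exists(model_name):
            print("Skipping model {}".format(model_name))
            return
        print("Saving model {}".format(model_name))
        torch.jit.save(build_fn(), model_name)

    # The pooling case would then read:
    export_if_needed(
        lambda: torch.jit.trace(Pool().eval().cuda(), torch.ones([1, 3, 10, 10]).cuda()),
        "pooling_traced.jit.pt",
        version_matches,
    )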
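A quick way to exercise both branches of the new guards is to run the generator twice and confirm that every file recorded in the manifest deserializes; a minimal smoke-test sketch, assuming a CUDA-capable machine and the imports already present in this file:

    import os
    import torch

    manifest = {}
    generate_custom_models(manifest, version_matches=False)  # force regeneration
    generate_custom_models(manifest, version_matches=True)   # should hit the cache

    for name, filenames in manifest.items():
        for f in filenames:
            assert os.path.exists(f), "missing export: {}".format(f)
            torch.jit.load(f)  # raises if the archive is corrupt
            print("OK: {} -> {}".format(name, f))

Because each guard is version_matches and os.path.exists(...), the second call regenerates nothing as long as every file from the first call is still on disk.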