
Commit 08b9853

refactor: Run trace or script of model only if version mismatches or file doesn't exist
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 2362b7f
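
The refactor drops the export_model() helper and inlines its check at every call site: a model is re-traced or re-scripted only when the saved TorchScript file is missing or the recorded version no longer matches. A minimal, CPU-only sketch of that rule, using a hypothetical needs_regeneration() helper and a toy module in place of the hub models:

import os

import torch

# Hypothetical helper mirroring the guard added in this commit: reuse the cached
# TorchScript file only when the version matches AND the file already exists.
def needs_regeneration(filename, version_matches):
    return not (version_matches and os.path.exists(filename))

# Toy stand-in for the hub models (the real script builds them on CUDA).
if needs_regeneration("toy_traced.jit.pt", version_matches=False):
    model = torch.nn.ReLU().eval()
    traced = torch.jit.trace(model, torch.ones(1, 3))
    torch.jit.save(traced, "toy_traced.jit.pt")

The diff below applies the same condition inline for the pooling, module-fallback, loop-fallback, conditional, and BERT models.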

1 file changed: +50 -52 lines changed

tests/modules/hub.py (+50 -52)
@@ -217,78 +217,76 @@ def get(self, n, m):
         self.manifest.update({n : [traced_filename, script_filename]})
 
-def export_model(model, model_name, version_matches):
-    if version_matches and os.path.exists(model_name):
-        print("Skipping model {}".format(model_name))
-    else:
-        print("Saving model {}".format(model_name))
-        torch.jit.save(model, model_name)
-
-
-def generate_custom_models(manifest, matches = False):
+def generate_custom_models(manifest, version_matches = False):
     # Pool
-    model = Pool().eval().cuda()
-    x = torch.ones([1, 3, 10, 10]).cuda()
-
-    trace_model = torch.jit.trace(model, x)
     traced_pool_name = "pooling_traced.jit.pt"
-    export_model(trace_model, traced_pool_name, matches)
+    if not (version_matches and os.path.exists(traced_pool_name)):
+        model = Pool().eval().cuda()
+        x = torch.ones([1, 3, 10, 10]).cuda()
+
+        trace_model = torch.jit.trace(model, x)
+        torch.jit.save(trace_model, traced_pool_name)
     manifest.update({"torchtrt_pooling": [traced_pool_name]})
 
     # Module fallback
-    module_fallback_model = ModuleFallbackMain().eval().cuda()
-    module_fallback_script_model = torch.jit.script(module_fallback_model)
     scripted_module_fallback_name = "module_fallback_scripted.jit.pt"
-    export_model(module_fallback_script_model, scripted_module_fallback_name, matches)
+    if not (version_matches and os.path.exists(scripted_module_fallback_name)):
+        module_fallback_model = ModuleFallbackMain().eval().cuda()
+        module_fallback_script_model = torch.jit.script(module_fallback_model)
+        torch.jit.save(module_fallback_script_model, scripted_module_fallback_name)
     manifest.update({"torchtrt_module_fallback": [scripted_module_fallback_name]})
 
     # Loop Fallback
-    loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
-    loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
     scripted_loop_fallback_name = "loop_fallback_eval_scripted.jit.pt"
-    export_model(loop_fallback_eval_script_model, scripted_loop_fallback_name, matches)
+    if not (version_matches and os.path.exists(scripted_loop_fallback_name)):
+        loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
+        loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
+        torch.jit.save(loop_fallback_eval_script_model, scripted_loop_fallback_name)
 
-    loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
-    loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
     scripted_loop_fallback_no_eval_name = "loop_fallback_no_eval_scripted.jit.pt"
-    export_model(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name, matches)
+    if not (version_matches and os.path.exists(scripted_loop_fallback_no_eval_name)):
+        loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
+        loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
+        torch.jit.save(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name)
     manifest.update({"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
 
     # Conditional
-    conditional_model = FallbackIf().eval().cuda()
-    conditional_script_model = torch.jit.script(conditional_model)
     scripted_conditional_name = "conditional_scripted.jit.pt"
-    export_model(conditional_script_model, scripted_conditional_name, matches)
+    if not (version_matches and os.path.exists(scripted_conditional_name)):
+        conditional_model = FallbackIf().eval().cuda()
+        conditional_script_model = torch.jit.script(conditional_model)
+        torch.jit.save(conditional_script_model, scripted_conditional_name)
     manifest.update({"torchtrt_conditional": [scripted_conditional_name]})
 
     # BERT model
-    enc = BertTokenizer.from_pretrained("bert-base-uncased")
-    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
-    tokenized_text = enc.tokenize(text)
-    masked_index = 8
-    tokenized_text[masked_index] = "[MASK]"
-    indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
-    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
-    tokens_tensor = torch.tensor([indexed_tokens])
-    segments_tensors = torch.tensor([segments_ids])
-    dummy_input = [tokens_tensor, segments_tensors]
-
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
-    )
-
-    model = BertModel(config)
-    model.eval()
-    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
-
     traced_bert_uncased_name = "bert_case_uncased_traced.jit.pt"
-    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
-    export_model(traced_model, traced_bert_uncased_name, matches)
+    if not (version_matches and os.path.exists(traced_bert_uncased_name)):
+        enc = BertTokenizer.from_pretrained("bert-base-uncased")
+        text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        tokenized_text = enc.tokenize(text)
+        masked_index = 8
+        tokenized_text[masked_index] = "[MASK]"
+        indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
+        segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
+        tokens_tensor = torch.tensor([indexed_tokens])
+        segments_tensors = torch.tensor([segments_ids])
+        dummy_input = [tokens_tensor, segments_tensors]
+
+        config = BertConfig(
+            vocab_size_or_config_json_file=32000,
+            hidden_size=768,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            intermediate_size=3072,
+            torchscript=True,
+        )
+
+        model = BertModel(config)
+        model.eval()
+        model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
+
+        traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
+        torch.jit.save(traced_model, traced_bert_uncased_name)
     manifest.update({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name]})
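
For context, a hedged sketch of how the refactored generate_custom_models() could be driven; the manifest.json file name, the torch_version key, and the import path are assumptions for illustration, not the actual bookkeeping in hub.py:

import json
import os

import torch

from hub import generate_custom_models  # assumes the script is run from tests/modules/

# Assumed flow: decide version_matches by comparing a recorded torch version with
# the running one, then let generate_custom_models() skip every model whose
# serialized file is already on disk.
manifest_file = "manifest.json"  # hypothetical name
manifest = {}
version_matches = False

if os.path.exists(manifest_file):
    with open(manifest_file, "r") as f:
        manifest = json.load(f)
    version_matches = manifest.get("torch_version") == torch.__version__

generate_custom_models(manifest, version_matches)

manifest["torch_version"] = torch.__version__
with open(manifest_file, "w") as f:
    json.dump(manifest, f)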