Skip to content

Commit 45d044e

Browse files
committed
feat: Added manifest to track the downloaded files
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 2f419a8 commit 45d044e

File tree

2 files changed

+154
-87
lines changed

2 files changed

+154
-87
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,4 @@ examples/int8/qat/qat
5757
examples/int8/training/vgg16/data/*
5858
examples/int8/datasets/data/*
5959
env/**/*
60-
model_snapshot.txt
60+
model_manifest.json

tests/modules/hub.py

Lines changed: 153 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,14 @@
55
import timm
66
from transformers import BertModel, BertTokenizer, BertConfig
77
import os
8-
import sys
8+
import json
99

1010
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
1111

1212
torch_version = torch.__version__
13-
snapshot_file = 'model_snapshot.txt'
14-
skip_download = False
15-
16-
# If model repository already setup
17-
if os.path.exists(snapshot_file):
18-
with open(snapshot_file, 'r') as f:
19-
model_version = f.read()
20-
if model_version == torch_version:
21-
skip_download = True
22-
23-
# In case of existing model repository, skip the download
24-
if skip_download:
25-
print('Skipping re-download of model repository')
26-
sys.exit()
27-
else:
28-
with open(snapshot_file, 'w') as f:
29-
f.write(torch_version)
13+
14+
# Downloads all model files again if manifest file is not present
15+
MANIFEST_FILE = 'model_manifest.json'
3016

3117
models = {
3218
"alexnet": {
@@ -92,18 +78,6 @@
9278
}
9379
}
9480

95-
# Download sample models
96-
for n, m in models.items():
97-
print("Downloading {}".format(n))
98-
m["model"] = m["model"].eval().cuda()
99-
x = torch.ones((1, 3, 300, 300)).cuda()
100-
if m["path"] == "both" or m["path"] == "trace":
101-
trace_model = torch.jit.trace(m["model"], [x])
102-
torch.jit.save(trace_model, n + '_traced.jit.pt')
103-
if m["path"] == "both" or m["path"] == "script":
104-
script_model = torch.jit.script(m["model"])
105-
torch.jit.save(script_model, n + '_scripted.jit.pt')
106-
10781

10882
# Sample Pool Model (for testing plugin serialization)
10983
class Pool(nn.Module):
@@ -114,14 +88,6 @@ def __init__(self):
11488
def forward(self, x):
11589
return F.adaptive_avg_pool2d(x, (5, 5))
11690

117-
118-
model = Pool().eval().cuda()
119-
x = torch.ones([1, 3, 10, 10]).cuda()
120-
121-
trace_model = torch.jit.trace(model, x)
122-
torch.jit.save(trace_model, "pooling_traced.jit.pt")
123-
124-
12591
# Sample Nested Module (for module-level fallback testing)
12692
class ModuleFallbackSub(nn.Module):
12793

@@ -133,7 +99,6 @@ def __init__(self):
13399
def forward(self, x):
134100
return self.relu(self.conv(x))
135101

136-
137102
class ModuleFallbackMain(nn.Module):
138103

139104
def __init__(self):
@@ -145,12 +110,6 @@ def __init__(self):
145110
def forward(self, x):
146111
return self.relu(self.conv(self.layer1(x)))
147112

148-
149-
module_fallback_model = ModuleFallbackMain().eval().cuda()
150-
module_fallback_script_model = torch.jit.script(module_fallback_model)
151-
torch.jit.save(module_fallback_script_model, "module_fallback_scripted.jit.pt")
152-
153-
154113
# Sample Looping Modules (for loop fallback testing)
155114
class LoopFallbackEval(nn.Module):
156115

@@ -163,7 +122,6 @@ def forward(self, x):
163122
add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0)
164123
return x + add_list
165124

166-
167125
class LoopFallbackNoEval(nn.Module):
168126

169127
def __init__(self):
@@ -174,15 +132,6 @@ def forward(self, x):
174132
x = x + torch.ones_like(x)
175133
return x
176134

177-
178-
loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
179-
loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
180-
torch.jit.save(loop_fallback_eval_script_model, "loop_fallback_eval_scripted.jit.pt")
181-
loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
182-
loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
183-
torch.jit.save(loop_fallback_no_eval_script_model, "loop_fallback_no_eval_scripted.jit.pt")
184-
185-
186135
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
187136
class FallbackIf(torch.nn.Module):
188137

@@ -207,34 +156,152 @@ def forward(self, x):
207156
x = self.conv1(x)
208157
return x
209158

210-
211-
conditional_model = FallbackIf().eval().cuda()
212-
conditional_script_model = torch.jit.script(conditional_model)
213-
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
214-
215-
enc = BertTokenizer.from_pretrained("bert-base-uncased")
216-
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
217-
tokenized_text = enc.tokenize(text)
218-
masked_index = 8
219-
tokenized_text[masked_index] = "[MASK]"
220-
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
221-
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
222-
tokens_tensor = torch.tensor([indexed_tokens])
223-
segments_tensors = torch.tensor([segments_ids])
224-
dummy_input = [tokens_tensor, segments_tensors]
225-
226-
config = BertConfig(
227-
vocab_size_or_config_json_file=32000,
228-
hidden_size=768,
229-
num_hidden_layers=12,
230-
num_attention_heads=12,
231-
intermediate_size=3072,
232-
torchscript=True,
233-
)
234-
235-
model = BertModel(config)
236-
model.eval()
237-
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
238-
239-
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
240-
torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt")
159+
class ModelManifest:
160+
def __init__(self):
161+
self.version_matches = False
162+
if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:
163+
self.manifest = {}
164+
self.manifest.update({'version' : torch_version})
165+
else:
166+
with open(MANIFEST_FILE, 'r') as f:
167+
self.manifest = json.load(f)
168+
if self.manifest['version'] == torch_version:
169+
self.version_matches = True
170+
else:
171+
print("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models".format(torch_version, self.manifest['version']))
172+
self.manifest["version"] = torch_version
173+
174+
175+
def download(self, models):
176+
if self.version_matches:
177+
for n, m in models.items():
178+
scripted_filename = n + "_scripted.jit.pt"
179+
traced_filename = n + "_traced.jit.pt"
180+
if (m["path"] == "both" and os.path.exists(scripted_filename) and os.path.exists(traced_filename)) or \
181+
(m["path"] == "script" and os.path.exists(scripted_filename)) or \
182+
(m["path"] == "trace" and os.path.exists(traced_filename)):
183+
print("Skipping {} ".format(n))
184+
continue
185+
self.get(n, m)
186+
else:
187+
for n, m in models.items():
188+
self.get(n, m)
189+
190+
def write(self, manifest_record):
191+
with open(MANIFEST_FILE, 'r+') as f:
192+
data = f.read()
193+
f.seek(0)
194+
record = json.dumps(manifest_record)
195+
f.write(record)
196+
f.truncate()
197+
198+
def get_manifest(self):
199+
return self.manifest
200+
201+
def if_version_matches(self):
202+
return self.version_matches
203+
204+
def get(self, n, m):
205+
print("Downloading {}".format(n))
206+
m["model"] = m["model"].eval().cuda()
207+
traced_filename = n + '_traced.jit.pt'
208+
script_filename = n + '_scripted.jit.pt'
209+
210+
x = torch.ones((1, 3, 300, 300)).cuda()
211+
if m["path"] == "both" or m["path"] == "trace":
212+
trace_model = torch.jit.trace(m["model"], [x])
213+
torch.jit.save(trace_model, traced_filename)
214+
if m["path"] == "both" or m["path"] == "script":
215+
script_model = torch.jit.script(m["model"])
216+
torch.jit.save(script_model, script_filename)
217+
218+
self.manifest.update({n : [traced_filename, script_filename]})
219+
220+
def export_model(model, model_name, version_matches):
221+
if version_matches and os.path.exists(model_name):
222+
print("Skipping model {}".format(model_name))
223+
else:
224+
print("Saving model {}".format(model_name))
225+
torch.jit.save(model, model_name)
226+
227+
228+
def generate_custom_models(manifest, matches = False):
229+
# Pool
230+
model = Pool().eval().cuda()
231+
x = torch.ones([1, 3, 10, 10]).cuda()
232+
233+
trace_model = torch.jit.trace(model, x)
234+
traced_pool_name = "pooling_traced.jit.pt"
235+
export_model(trace_model, traced_pool_name, matches)
236+
manifest.update({"torchtrt_pooling": [traced_pool_name]})
237+
238+
# Module fallback
239+
module_fallback_model = ModuleFallbackMain().eval().cuda()
240+
module_fallback_script_model = torch.jit.script(module_fallback_model)
241+
scripted_module_fallback_name = "module_fallback_scripted.jit.pt"
242+
export_model(module_fallback_script_model, scripted_module_fallback_name, matches)
243+
manifest.update({"torchtrt_module_fallback": [scripted_module_fallback_name]})
244+
245+
# Loop Fallback
246+
loop_fallback_eval_model = LoopFallbackEval().eval().cuda()
247+
loop_fallback_eval_script_model = torch.jit.script(loop_fallback_eval_model)
248+
scripted_loop_fallback_name = "loop_fallback_eval_scripted.jit.pt"
249+
export_model(loop_fallback_eval_script_model, scripted_loop_fallback_name, matches)
250+
251+
loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
252+
loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
253+
scripted_loop_fallback_no_eval_name = "loop_fallback_no_eval_scripted.jit.pt"
254+
export_model(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name, matches)
255+
manifest.update({"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
256+
257+
# Conditional
258+
conditional_model = FallbackIf().eval().cuda()
259+
conditional_script_model = torch.jit.script(conditional_model)
260+
scripted_conditional_name = "conditional_scripted.jit.pt"
261+
export_model(conditional_script_model, scripted_conditional_name, matches)
262+
manifest.update({"torchtrt_conditional": [scripted_conditional_name]})
263+
264+
# BERT model
265+
enc = BertTokenizer.from_pretrained("bert-base-uncased")
266+
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
267+
tokenized_text = enc.tokenize(text)
268+
masked_index = 8
269+
tokenized_text[masked_index] = "[MASK]"
270+
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
271+
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
272+
tokens_tensor = torch.tensor([indexed_tokens])
273+
segments_tensors = torch.tensor([segments_ids])
274+
dummy_input = [tokens_tensor, segments_tensors]
275+
276+
config = BertConfig(
277+
vocab_size_or_config_json_file=32000,
278+
hidden_size=768,
279+
num_hidden_layers=12,
280+
num_attention_heads=12,
281+
intermediate_size=3072,
282+
torchscript=True,
283+
)
284+
285+
model = BertModel(config)
286+
model.eval()
287+
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
288+
289+
traced_bert_uncased_name = "bert_case_uncased_traced.jit.pt"
290+
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
291+
export_model(traced_model, traced_bert_uncased_name, matches)
292+
manifest.update({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name]})
293+
294+
295+
manifest = ModelManifest()
296+
297+
# Download the models
298+
manifest.download(models)
299+
300+
# Manifest generated from the model repository
301+
manifest_record = manifest.get_manifest()
302+
303+
# Save model
304+
generate_custom_models(manifest_record, manifest.if_version_matches())
305+
306+
# Update the manifest file
307+
manifest.write(manifest_record)

0 commit comments

Comments
 (0)