Skip to content

Commit 2f419a8

Browse files
committed
feat: Optimize hub.py download
Signed-off-by: Anurag Dixit <[email protected]>
1 parent de33b04 commit 2f419a8

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,4 @@ examples/int8/qat/qat
5757
examples/int8/training/vgg16/data/*
5858
examples/int8/datasets/data/*
5959
env/**/*
60-
bazel-Torch-TensorRT-Preview
60+
model_snapshot.txt

tests/modules/hub.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,30 @@
44
import torchvision.models as models
55
import timm
66
from transformers import BertModel, BertTokenizer, BertConfig
7+
import os
8+
import sys
79

810
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
911

12+
torch_version = torch.__version__
13+
snapshot_file = 'model_snapshot.txt'
14+
skip_download = False
15+
16+
# If model repository already setup
17+
if os.path.exists(snapshot_file):
18+
with open(snapshot_file, 'r') as f:
19+
model_version = f.read()
20+
if model_version == torch_version:
21+
skip_download = True
22+
23+
# In case of existing model repository, skip the download
24+
if skip_download:
25+
print('Skipping re-download of model repository')
26+
sys.exit()
27+
else:
28+
with open(snapshot_file, 'w') as f:
29+
f.write(torch_version)
30+
1031
models = {
1132
"alexnet": {
1233
"model": models.alexnet(pretrained=True),

0 commit comments

Comments
 (0)