Skip to content

Commit dcf19cc

Browse files
authored
Merge pull request #1022 from pytorch/anuragd/optimize_model_hub
feat: Optimize hub.py download
2 parents 160fe4f + 176b907 commit dcf19cc

File tree

5 files changed

+233
-150
lines changed

5 files changed

+233
-150
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ examples/int8/qat/qat
5757
examples/int8/training/vgg16/data/*
5858
examples/int8/datasets/data/*
5959
env/**/*
60+
model_manifest.json
6061
bazel-Torch-TensorRT-Preview
6162
docsrc/src/
6263
bazel-TensorRT
63-
bazel-tensorrt
64+
bazel-tensorrt

py/setup.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,7 @@ def run(self):
242242
dir_path + "/../bazel-TRTorch/external/tensorrt/include",
243243
dir_path + "/../bazel-Torch-TensorRT/external/tensorrt/include",
244244
dir_path + "/../bazel-TensorRT/external/tensorrt/include",
245-
dir_path + "/../bazel-tensorrt/external/tensorrt/include",
246-
dir_path + "/../"
245+
dir_path + "/../bazel-tensorrt/external/tensorrt/include", dir_path + "/../"
247246
],
248247
extra_compile_args=[
249248
"-Wno-deprecated",

tests/core/lowering/test_module_fallback_passes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ TEST(Lowering, NotateModuleForFallbackWorksCorrectly) {
1818
}
1919

2020
std::unordered_set<std::string> mods_to_mark;
21-
mods_to_mark.insert("ModuleFallbackSub");
21+
mods_to_mark.insert("custom_models.ModuleFallbackSub");
2222

2323
torch_tensorrt::core::lowering::passes::NotateModuleForFallback(mod, "", "forward", mods_to_mark);
2424

@@ -58,7 +58,7 @@ TEST(Lowering, MarkNodesForFallbackWorksCorrectly) {
5858
}
5959

6060
std::unordered_set<std::string> mods_to_mark;
61-
mods_to_mark.insert("ModuleFallbackSub");
61+
mods_to_mark.insert("custom_models.ModuleFallbackSub");
6262

6363
torch_tensorrt::core::lowering::passes::NotateModuleForFallback(mod, "", "forward", mods_to_mark);
6464
auto mod_ = torch::jit::freeze_module(mod);

tests/modules/custom_models.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
import torch.nn as nn
3+
from transformers import BertModel, BertTokenizer, BertConfig
4+
import torch.nn.functional as F
5+
6+
7+
# Sample Pool Model (for testing plugin serialization)
8+
class Pool(nn.Module):
    """Sample pooling model (used for testing plugin serialization).

    The module is stateless: it has no parameters or buffers, so the
    default ``nn.Module`` constructor is sufficient.
    """

    def forward(self, x):
        """Adaptively average-pool ``x`` to a fixed 5x5 spatial output."""
        return F.adaptive_avg_pool2d(x, (5, 5))
15+
16+
17+
# Sample Nested Module (for module-level fallback testing)
18+
class ModuleFallbackSub(nn.Module):
    """Nested sub-module used for module-level fallback testing.

    NOTE: the class name is part of the test contract. The C++ lowering
    tests mark the module by its qualified name
    ``custom_models.ModuleFallbackSub``, so do not rename it.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the 1->3 channel 3x3 convolution followed by ReLU."""
        return self.relu(self.conv(x))
27+
28+
29+
class ModuleFallbackMain(nn.Module):
    """Top-level module that wraps ``ModuleFallbackSub`` for fallback tests.

    ``layer1`` is the nested sub-module; its output (3 channels) feeds a
    3->6 channel 3x3 convolution followed by ReLU.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = ModuleFallbackSub()
        self.conv = nn.Conv2d(3, 6, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the nested sub-module, then this module's conv and ReLU."""
        return self.relu(self.conv(self.layer1(x)))
39+
40+
41+
# Sample Looping Modules (for loop fallback testing)
42+
class LoopFallbackEval(nn.Module):
    """Looping module used for loop fallback testing.

    The explicit Python loop is the construct under test; it is kept as a
    loop on purpose rather than being replaced by a vectorized expression.
    """

    def forward(self, x):
        """Add a vector built inside a loop to ``x``.

        The loop runs ``x.shape[1]`` times and appends the value
        ``x.shape[1]`` on each pass, so ``add_list`` ends up with
        ``x.shape[1]`` entries, which is then added to ``x``.
        """
        add_list = torch.empty(0).to(x.device)
        # The loop index itself is not used; only the iteration count matters.
        for _ in range(x.shape[1]):
            add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0)
        return x + add_list
52+
53+
54+
class LoopFallbackNoEval(nn.Module):
    """Looping module used for loop fallback testing.

    The module is stateless, so the default ``nn.Module`` constructor is
    sufficient. The explicit loop is the construct under test.
    """

    def forward(self, x):
        """Add a tensor of ones to ``x`` once per iteration, ``x.shape[1]`` times."""
        for _ in range(x.shape[1]):
            x = x + torch.ones_like(x)
        return x
63+
64+
65+
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
66+
class FallbackIf(torch.nn.Module):
    """Conditional model for testing partitioning and fallback in conditionals.

    The branch taken depends on a scalar read from the input at runtime
    (via ``.item()``), so the control flow is data-dependent.
    """

    def __init__(self):
        super().__init__()
        self.relu1 = torch.nn.ReLU()
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.log_sig = torch.nn.LogSigmoid()
        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1)

    def forward(self, x):
        """Apply ReLU, branch on the first element, then apply ``conv1``."""
        x = self.relu1(x)
        # Data-dependent condition: the first element after ReLU.
        x_first = x[0][0][0][0].item()
        if x_first > 0:
            x = self.conv1(x)
            x1 = self.log_sig(x)
            x2 = self.conv2(x)
            x = self.conv3(x1 + x2)
        else:
            x = self.log_sig(x)
        # NOTE(review): this statement is placed after the if/else so that
        # both branches end with 32 channels; confirm against the upstream
        # file, since the indentation was reconstructed.
        x = self.conv1(x)
        return x
88+
89+
90+
def BertModule():
    """Return a ``torch.jit.trace`` of the pretrained ``bert-base-uncased`` model.

    A fixed sample sentence is tokenized, one token is replaced with
    ``[MASK]``, and the resulting token-id and segment-id tensors are used
    as example inputs for tracing.

    Returns:
        The traced model produced by ``torch.jit.trace``.
    """
    model_name = "bert-base-uncased"
    enc = BertTokenizer.from_pretrained(model_name)

    # Build the example inputs used for tracing.
    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
    tokenized_text = enc.tokenize(text)
    masked_index = 8
    tokenized_text[masked_index] = "[MASK]"
    indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Load the pretrained weights directly. An earlier version first built
    # a randomly initialized BertModel from a hand-written BertConfig and
    # called .eval() on it, then discarded it by rebinding `model`; that
    # throwaway model is no longer constructed.
    model = BertModel.from_pretrained(model_name, torchscript=True)
    # Put the model that is actually traced into inference mode.
    model.eval()

    # NOTE(review): confirm which BertModel.forward parameter the second
    # positional input (segments_tensors) binds to in the installed
    # transformers version.
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    return traced_model

0 commit comments

Comments
 (0)