5
5
import timm
6
6
from transformers import BertModel , BertTokenizer , BertConfig
7
7
import os
8
- import sys
8
+ import json
9
9
10
10
torch .hub ._validate_not_a_forked_repo = lambda a , b , c : True
11
11
12
12
torch_version = torch .__version__
13
- snapshot_file = 'model_snapshot.txt'
14
- skip_download = False
15
-
16
- # If model repository already setup
17
- if os .path .exists (snapshot_file ):
18
- with open (snapshot_file , 'r' ) as f :
19
- model_version = f .read ()
20
- if model_version == torch_version :
21
- skip_download = True
22
-
23
- # In case of existing model repository, skip the download
24
- if skip_download :
25
- print ('Skipping re-download of model repository' )
26
- sys .exit ()
27
- else :
28
- with open (snapshot_file , 'w' ) as f :
29
- f .write (torch_version )
13
+
14
+ # Downloads all model files again if manifest file is not present
15
+ MANIFEST_FILE = 'model_manifest.json'
30
16
31
17
models = {
32
18
"alexnet" : {
92
78
}
93
79
}
94
80
95
- # Download sample models
96
- for n , m in models .items ():
97
- print ("Downloading {}" .format (n ))
98
- m ["model" ] = m ["model" ].eval ().cuda ()
99
- x = torch .ones ((1 , 3 , 300 , 300 )).cuda ()
100
- if m ["path" ] == "both" or m ["path" ] == "trace" :
101
- trace_model = torch .jit .trace (m ["model" ], [x ])
102
- torch .jit .save (trace_model , n + '_traced.jit.pt' )
103
- if m ["path" ] == "both" or m ["path" ] == "script" :
104
- script_model = torch .jit .script (m ["model" ])
105
- torch .jit .save (script_model , n + '_scripted.jit.pt' )
106
-
107
81
108
82
# Sample Pool Model (for testing plugin serialization)
109
83
class Pool (nn .Module ):
@@ -114,14 +88,6 @@ def __init__(self):
114
88
def forward (self , x ):
115
89
return F .adaptive_avg_pool2d (x , (5 , 5 ))
116
90
117
-
118
- model = Pool ().eval ().cuda ()
119
- x = torch .ones ([1 , 3 , 10 , 10 ]).cuda ()
120
-
121
- trace_model = torch .jit .trace (model , x )
122
- torch .jit .save (trace_model , "pooling_traced.jit.pt" )
123
-
124
-
125
91
# Sample Nested Module (for module-level fallback testing)
126
92
class ModuleFallbackSub (nn .Module ):
127
93
@@ -133,7 +99,6 @@ def __init__(self):
133
99
def forward (self , x ):
134
100
return self .relu (self .conv (x ))
135
101
136
-
137
102
class ModuleFallbackMain (nn .Module ):
138
103
139
104
def __init__ (self ):
@@ -145,12 +110,6 @@ def __init__(self):
145
110
def forward (self , x ):
146
111
return self .relu (self .conv (self .layer1 (x )))
147
112
148
-
149
- module_fallback_model = ModuleFallbackMain ().eval ().cuda ()
150
- module_fallback_script_model = torch .jit .script (module_fallback_model )
151
- torch .jit .save (module_fallback_script_model , "module_fallback_scripted.jit.pt" )
152
-
153
-
154
113
# Sample Looping Modules (for loop fallback testing)
155
114
class LoopFallbackEval (nn .Module ):
156
115
@@ -163,7 +122,6 @@ def forward(self, x):
163
122
add_list = torch .cat ((add_list , torch .tensor ([x .shape [1 ]]).to (x .device )), 0 )
164
123
return x + add_list
165
124
166
-
167
125
class LoopFallbackNoEval (nn .Module ):
168
126
169
127
def __init__ (self ):
@@ -174,15 +132,6 @@ def forward(self, x):
174
132
x = x + torch .ones_like (x )
175
133
return x
176
134
177
-
178
- loop_fallback_eval_model = LoopFallbackEval ().eval ().cuda ()
179
- loop_fallback_eval_script_model = torch .jit .script (loop_fallback_eval_model )
180
- torch .jit .save (loop_fallback_eval_script_model , "loop_fallback_eval_scripted.jit.pt" )
181
- loop_fallback_no_eval_model = LoopFallbackNoEval ().eval ().cuda ()
182
- loop_fallback_no_eval_script_model = torch .jit .script (loop_fallback_no_eval_model )
183
- torch .jit .save (loop_fallback_no_eval_script_model , "loop_fallback_no_eval_scripted.jit.pt" )
184
-
185
-
186
135
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
187
136
class FallbackIf (torch .nn .Module ):
188
137
@@ -207,34 +156,152 @@ def forward(self, x):
207
156
x = self .conv1 (x )
208
157
return x
209
158
210
-
211
- conditional_model = FallbackIf ().eval ().cuda ()
212
- conditional_script_model = torch .jit .script (conditional_model )
213
- torch .jit .save (conditional_script_model , "conditional_scripted.jit.pt" )
214
-
215
- enc = BertTokenizer .from_pretrained ("bert-base-uncased" )
216
- text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
217
- tokenized_text = enc .tokenize (text )
218
- masked_index = 8
219
- tokenized_text [masked_index ] = "[MASK]"
220
- indexed_tokens = enc .convert_tokens_to_ids (tokenized_text )
221
- segments_ids = [0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
222
- tokens_tensor = torch .tensor ([indexed_tokens ])
223
- segments_tensors = torch .tensor ([segments_ids ])
224
- dummy_input = [tokens_tensor , segments_tensors ]
225
-
226
- config = BertConfig (
227
- vocab_size_or_config_json_file = 32000 ,
228
- hidden_size = 768 ,
229
- num_hidden_layers = 12 ,
230
- num_attention_heads = 12 ,
231
- intermediate_size = 3072 ,
232
- torchscript = True ,
233
- )
234
-
235
- model = BertModel (config )
236
- model .eval ()
237
- model = BertModel .from_pretrained ("bert-base-uncased" , torchscript = True )
238
-
239
- traced_model = torch .jit .trace (model , [tokens_tensor , segments_tensors ])
240
- torch .jit .save (traced_model , "bert_base_uncased_traced.jit.pt" )
159
+ class ModelManifest :
160
+ def __init__ (self ):
161
+ self .version_matches = False
162
+ if not os .path .exists (MANIFEST_FILE ) or os .stat (MANIFEST_FILE ).st_size == 0 :
163
+ self .manifest = {}
164
+ self .manifest .update ({'version' : torch_version })
165
+ else :
166
+ with open (MANIFEST_FILE , 'r' ) as f :
167
+ self .manifest = json .load (f )
168
+ if self .manifest ['version' ] == torch_version :
169
+ self .version_matches = True
170
+ else :
171
+ print ("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models" .format (torch_version , self .manifest ['version' ]))
172
+ self .manifest ["version" ] = torch_version
173
+
174
+
175
+ def download (self , models ):
176
+ if self .version_matches :
177
+ for n , m in models .items ():
178
+ scripted_filename = n + "_scripted.jit.pt"
179
+ traced_filename = n + "_traced.jit.pt"
180
+ if (m ["path" ] == "both" and os .path .exists (scripted_filename ) and os .path .exists (traced_filename )) or \
181
+ (m ["path" ] == "script" and os .path .exists (scripted_filename )) or \
182
+ (m ["path" ] == "trace" and os .path .exists (traced_filename )):
183
+ print ("Skipping {} " .format (n ))
184
+ continue
185
+ self .get (n , m )
186
+ else :
187
+ for n , m in models .items ():
188
+ self .get (n , m )
189
+
190
+ def write (self , manifest_record ):
191
+ with open (MANIFEST_FILE , 'r+' ) as f :
192
+ data = f .read ()
193
+ f .seek (0 )
194
+ record = json .dumps (manifest_record )
195
+ f .write (record )
196
+ f .truncate ()
197
+
198
+ def get_manifest (self ):
199
+ return self .manifest
200
+
201
+ def if_version_matches (self ):
202
+ return self .version_matches
203
+
204
+ def get (self , n , m ):
205
+ print ("Downloading {}" .format (n ))
206
+ m ["model" ] = m ["model" ].eval ().cuda ()
207
+ traced_filename = n + '_traced.jit.pt'
208
+ script_filename = n + '_scripted.jit.pt'
209
+
210
+ x = torch .ones ((1 , 3 , 300 , 300 )).cuda ()
211
+ if m ["path" ] == "both" or m ["path" ] == "trace" :
212
+ trace_model = torch .jit .trace (m ["model" ], [x ])
213
+ torch .jit .save (trace_model , traced_filename )
214
+ if m ["path" ] == "both" or m ["path" ] == "script" :
215
+ script_model = torch .jit .script (m ["model" ])
216
+ torch .jit .save (script_model , script_filename )
217
+
218
+ self .manifest .update ({n : [traced_filename , script_filename ]})
219
+
220
+ def export_model (model , model_name , version_matches ):
221
+ if version_matches and os .path .exists (model_name ):
222
+ print ("Skipping model {}" .format (model_name ))
223
+ else :
224
+ print ("Saving model {}" .format (model_name ))
225
+ torch .jit .save (model , model_name )
226
+
227
+
228
+ def generate_custom_models (manifest , matches = False ):
229
+ # Pool
230
+ model = Pool ().eval ().cuda ()
231
+ x = torch .ones ([1 , 3 , 10 , 10 ]).cuda ()
232
+
233
+ trace_model = torch .jit .trace (model , x )
234
+ traced_pool_name = "pooling_traced.jit.pt"
235
+ export_model (trace_model , traced_pool_name , matches )
236
+ manifest .update ({"torchtrt_pooling" : [traced_pool_name ]})
237
+
238
+ # Module fallback
239
+ module_fallback_model = ModuleFallbackMain ().eval ().cuda ()
240
+ module_fallback_script_model = torch .jit .script (module_fallback_model )
241
+ scripted_module_fallback_name = "module_fallback_scripted.jit.pt"
242
+ export_model (module_fallback_script_model , scripted_module_fallback_name , matches )
243
+ manifest .update ({"torchtrt_module_fallback" : [scripted_module_fallback_name ]})
244
+
245
+ # Loop Fallback
246
+ loop_fallback_eval_model = LoopFallbackEval ().eval ().cuda ()
247
+ loop_fallback_eval_script_model = torch .jit .script (loop_fallback_eval_model )
248
+ scripted_loop_fallback_name = "loop_fallback_eval_scripted.jit.pt"
249
+ export_model (loop_fallback_eval_script_model , scripted_loop_fallback_name , matches )
250
+
251
+ loop_fallback_no_eval_model = LoopFallbackNoEval ().eval ().cuda ()
252
+ loop_fallback_no_eval_script_model = torch .jit .script (loop_fallback_no_eval_model )
253
+ scripted_loop_fallback_no_eval_name = "loop_fallback_no_eval_scripted.jit.pt"
254
+ export_model (loop_fallback_no_eval_script_model , scripted_loop_fallback_no_eval_name , matches )
255
+ manifest .update ({"torchtrt_loop_fallback_no_eval" : [scripted_loop_fallback_name , scripted_loop_fallback_no_eval_name ]})
256
+
257
+ # Conditional
258
+ conditional_model = FallbackIf ().eval ().cuda ()
259
+ conditional_script_model = torch .jit .script (conditional_model )
260
+ scripted_conditional_name = "conditional_scripted.jit.pt"
261
+ export_model (conditional_script_model , scripted_conditional_name , matches )
262
+ manifest .update ({"torchtrt_conditional" : [scripted_conditional_name ]})
263
+
264
+ # BERT model
265
+ enc = BertTokenizer .from_pretrained ("bert-base-uncased" )
266
+ text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
267
+ tokenized_text = enc .tokenize (text )
268
+ masked_index = 8
269
+ tokenized_text [masked_index ] = "[MASK]"
270
+ indexed_tokens = enc .convert_tokens_to_ids (tokenized_text )
271
+ segments_ids = [0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
272
+ tokens_tensor = torch .tensor ([indexed_tokens ])
273
+ segments_tensors = torch .tensor ([segments_ids ])
274
+ dummy_input = [tokens_tensor , segments_tensors ]
275
+
276
+ config = BertConfig (
277
+ vocab_size_or_config_json_file = 32000 ,
278
+ hidden_size = 768 ,
279
+ num_hidden_layers = 12 ,
280
+ num_attention_heads = 12 ,
281
+ intermediate_size = 3072 ,
282
+ torchscript = True ,
283
+ )
284
+
285
+ model = BertModel (config )
286
+ model .eval ()
287
+ model = BertModel .from_pretrained ("bert-base-uncased" , torchscript = True )
288
+
289
+ traced_bert_uncased_name = "bert_case_uncased_traced.jit.pt"
290
+ traced_model = torch .jit .trace (model , [tokens_tensor , segments_tensors ])
291
+ export_model (traced_model , traced_bert_uncased_name , matches )
292
+ manifest .update ({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name ]})
293
+
294
+
295
+ manifest = ModelManifest ()
296
+
297
+ # Download the models
298
+ manifest .download (models )
299
+
300
+ # Manifest generated from the model repository
301
+ manifest_record = manifest .get_manifest ()
302
+
303
+ # Save model
304
+ generate_custom_models (manifest_record , manifest .if_version_matches ())
305
+
306
+ # Update the manifest file
307
+ manifest .write (manifest_record )
0 commit comments