Skip to content

Commit 6d149bc

Browse files
committed
chore: Applying lint
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 20da2dc commit 6d149bc

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

tests/modules/hub.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(self):
8888
def forward(self, x):
8989
return F.adaptive_avg_pool2d(x, (5, 5))
9090

91+
9192
# Sample Nested Module (for module-level fallback testing)
9293
class ModuleFallbackSub(nn.Module):
9394

@@ -99,6 +100,7 @@ def __init__(self):
99100
def forward(self, x):
100101
return self.relu(self.conv(x))
101102

103+
102104
class ModuleFallbackMain(nn.Module):
103105

104106
def __init__(self):
@@ -110,6 +112,7 @@ def __init__(self):
110112
def forward(self, x):
111113
return self.relu(self.conv(self.layer1(x)))
112114

115+
113116
# Sample Looping Modules (for loop fallback testing)
114117
class LoopFallbackEval(nn.Module):
115118

@@ -122,6 +125,7 @@ def forward(self, x):
122125
add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0)
123126
return x + add_list
124127

128+
125129
class LoopFallbackNoEval(nn.Module):
126130

127131
def __init__(self):
@@ -132,6 +136,7 @@ def forward(self, x):
132136
x = x + torch.ones_like(x)
133137
return x
134138

139+
135140
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
136141
class FallbackIf(torch.nn.Module):
137142

@@ -156,21 +161,23 @@ def forward(self, x):
156161
x = self.conv1(x)
157162
return x
158163

164+
159165
class ModelManifest:
166+
160167
def __init__(self):
161168
self.version_matches = False
162169
if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:
163170
self.manifest = {}
164-
self.manifest.update({'version' : torch_version})
171+
self.manifest.update({'version': torch_version})
165172
else:
166173
with open(MANIFEST_FILE, 'r') as f:
167174
self.manifest = json.load(f)
168175
if self.manifest['version'] == torch_version:
169176
self.version_matches = True
170177
else:
171-
print("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models".format(torch_version, self.manifest['version']))
178+
print("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models".format(
179+
torch_version, self.manifest['version']))
172180
self.manifest["version"] = torch_version
173-
174181

175182
def download(self, models):
176183
if self.version_matches:
@@ -194,13 +201,13 @@ def write(self, manifest_record):
194201
record = json.dumps(manifest_record)
195202
f.write(record)
196203
f.truncate()
197-
204+
198205
def get_manifest(self):
199206
return self.manifest
200-
207+
201208
def if_version_matches(self):
202209
return self.version_matches
203-
210+
204211
def get(self, n, m):
205212
print("Downloading {}".format(n))
206213
m["model"] = m["model"].eval().cuda()
@@ -214,10 +221,11 @@ def get(self, n, m):
214221
if m["path"] == "both" or m["path"] == "script":
215222
script_model = torch.jit.script(m["model"])
216223
torch.jit.save(script_model, script_filename)
217-
218-
self.manifest.update({n : [traced_filename, script_filename]})
219224

220-
def generate_custom_models(manifest, version_matches = False):
225+
self.manifest.update({n: [traced_filename, script_filename]})
226+
227+
228+
def generate_custom_models(manifest, version_matches=False):
221229
# Pool
222230
traced_pool_name = "pooling_traced.jit.pt"
223231
if not (version_matches and os.path.exists(traced_pool_name)):
@@ -248,7 +256,8 @@ def generate_custom_models(manifest, version_matches = False):
248256
loop_fallback_no_eval_model = LoopFallbackNoEval().eval().cuda()
249257
loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model)
250258
torch.jit.save(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name)
251-
manifest.update({"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
259+
manifest.update(
260+
{"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
252261

253262
# Conditional
254263
scripted_conditional_name = "conditional_scripted.jit.pt"
@@ -287,7 +296,7 @@ def generate_custom_models(manifest, version_matches = False):
287296

288297
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
289298
torch.jit.save(traced_model, traced_bert_uncased_name)
290-
manifest.update({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name]})
299+
manifest.update({"torchtrt_bert_case_uncased": [traced_bert_uncased_name]})
291300

292301

293302
manifest = ModelManifest()
@@ -302,4 +311,4 @@ def generate_custom_models(manifest, version_matches = False):
302311
generate_custom_models(manifest_record, manifest.if_version_matches())
303312

304313
# Update the manifest file
305-
manifest.write(manifest_record)
314+
manifest.write(manifest_record)

0 commit comments

Comments
 (0)