Skip to content

Commit 2e1764a

Browse files
committed
refactor: Segregated custom model from pre-trained models
Signed-off-by: Anurag Dixit <[email protected]>
1 parent d9384fe commit 2e1764a

File tree

2 files changed

+211
-226
lines changed

2 files changed

+211
-226
lines changed

tests/modules/custom_models.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
import torch.nn as nn
3+
from transformers import BertModel, BertTokenizer, BertConfig
4+
import torch.nn.functional as F
5+
6+
# Sample Pool Model (for testing plugin serialization)
7+
class Pool(nn.Module):
8+
9+
def __init__(self):
10+
super(Pool, self).__init__()
11+
12+
def forward(self, x):
13+
return F.adaptive_avg_pool2d(x, (5, 5))
14+
15+
16+
# Sample Nested Module (for module-level fallback testing)
17+
class ModuleFallbackSub(nn.Module):
18+
19+
def __init__(self):
20+
super(ModuleFallbackSub, self).__init__()
21+
self.conv = nn.Conv2d(1, 3, 3)
22+
self.relu = nn.ReLU()
23+
24+
def forward(self, x):
25+
return self.relu(self.conv(x))
26+
27+
28+
class ModuleFallbackMain(nn.Module):
29+
30+
def __init__(self):
31+
super(ModuleFallbackMain, self).__init__()
32+
self.layer1 = ModuleFallbackSub()
33+
self.conv = nn.Conv2d(3, 6, 3)
34+
self.relu = nn.ReLU()
35+
36+
def forward(self, x):
37+
return self.relu(self.conv(self.layer1(x)))
38+
39+
40+
# Sample Looping Modules (for loop fallback testing)
41+
class LoopFallbackEval(nn.Module):
42+
43+
def __init__(self):
44+
super(LoopFallbackEval, self).__init__()
45+
46+
def forward(self, x):
47+
add_list = torch.empty(0).to(x.device)
48+
for i in range(x.shape[1]):
49+
add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0)
50+
return x + add_list
51+
52+
53+
class LoopFallbackNoEval(nn.Module):
54+
55+
def __init__(self):
56+
super(LoopFallbackNoEval, self).__init__()
57+
58+
def forward(self, x):
59+
for _ in range(x.shape[1]):
60+
x = x + torch.ones_like(x)
61+
return x
62+
63+
64+
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
65+
class FallbackIf(torch.nn.Module):
66+
67+
def __init__(self):
68+
super(FallbackIf, self).__init__()
69+
self.relu1 = torch.nn.ReLU()
70+
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
71+
self.log_sig = torch.nn.LogSigmoid()
72+
self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
73+
self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1)
74+
75+
def forward(self, x):
76+
x = self.relu1(x)
77+
x_first = x[0][0][0][0].item()
78+
if x_first > 0:
79+
x = self.conv1(x)
80+
x1 = self.log_sig(x)
81+
x2 = self.conv2(x)
82+
x = self.conv3(x1 + x2)
83+
else:
84+
x = self.log_sig(x)
85+
x = self.conv1(x)
86+
return x
87+
88+

0 commit comments

Comments
 (0)