import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertConfig
import torch.nn.functional as F


# Sample Pool Model (for testing plugin serialization)
class Pool(nn.Module):

    def __init__(self):
        super(Pool, self).__init__()

    def forward(self, x):
        return F.adaptive_avg_pool2d(x, (5, 5))
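
# NOTE: illustrative sketch, not part of the original test models. The pooling
# op is included here because it may be lowered to a serialized plugin during
# conversion; scripting the module is the usual first step. Uses stock PyTorch only.
def _pool_smoke_test():
    scripted = torch.jit.script(Pool())
    out = scripted(torch.randn(1, 3, 10, 10))
    assert out.shape == (1, 3, 5, 5)  # output resampled to the requested 5x5 grid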

# Sample Nested Module (for module-level fallback testing)
class ModuleFallbackSub(nn.Module):

    def __init__(self):
        super(ModuleFallbackSub, self).__init__()
        self.conv = nn.Conv2d(1, 3, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class ModuleFallbackMain(nn.Module):

    def __init__(self):
        super(ModuleFallbackMain, self).__init__()
        self.layer1 = ModuleFallbackSub()
        self.conv = nn.Conv2d(3, 6, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(self.layer1(x)))
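
# NOTE: illustrative sketch, not part of the original test models. In a
# module-level fallback test, ModuleFallbackSub would typically be forced to run
# in PyTorch while the rest of ModuleFallbackMain is converted; the compile spec
# is out of scope here, but the scripted module and expected shapes look like:
def _module_fallback_smoke_test():
    scripted = torch.jit.script(ModuleFallbackMain())
    out = scripted(torch.randn(1, 1, 16, 16))  # Conv2d(1->3, k=3) then Conv2d(3->6, k=3)
    assert out.shape == (1, 6, 12, 12)  # each unpadded 3x3 conv shrinks spatial dims by 2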

# Sample Looping Modules (for loop fallback testing)
class LoopFallbackEval(nn.Module):

    def __init__(self):
        super(LoopFallbackEval, self).__init__()

    def forward(self, x):
        add_list = torch.empty(0).to(x.device)
        for i in range(x.shape[1]):
            add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0)
        return x + add_list


class LoopFallbackNoEval(nn.Module):

    def __init__(self):
        super(LoopFallbackNoEval, self).__init__()

    def forward(self, x):
        for _ in range(x.shape[1]):
            x = x + torch.ones_like(x)
        return x
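
# NOTE: illustrative sketch, not part of the original test models. These loop
# bodies only survive as loops under torch.jit.script; torch.jit.trace would
# unroll them for one concrete input shape. LoopFallbackEval builds its addend
# from x.shape inside the loop (shape evaluation), while LoopFallbackNoEval
# touches tensors only.
def _loop_fallback_smoke_test():
    x = torch.randn(1, 4)
    eval_out = torch.jit.script(LoopFallbackEval())(x)      # x + tensor([4., 4., 4., 4.])
    noeval_out = torch.jit.script(LoopFallbackNoEval())(x)  # x + 4
    assert eval_out.shape == x.shape and noeval_out.shape == x.shape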

# Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module):

    def __init__(self):
        super(FallbackIf, self).__init__()
        self.relu1 = torch.nn.ReLU()
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.log_sig = torch.nn.LogSigmoid()
        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1)

    def forward(self, x):
        x = self.relu1(x)
        x_first = x[0][0][0][0].item()
        if x_first > 0:
            x = self.conv1(x)
            x1 = self.log_sig(x)
            x2 = self.conv2(x)
            x = self.conv3(x1 + x2)
        else:
            x = self.log_sig(x)
            x = self.conv1(x)
        return x
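
# NOTE: illustrative sketch, not part of the original test models. The .item()
# call makes the branch data dependent, so the conditional is only preserved
# under torch.jit.script; torch.jit.trace would bake in whichever branch the
# example input happened to take.
def _fallback_if_smoke_test():
    scripted = torch.jit.script(FallbackIf())
    out = scripted(torch.ones(1, 3, 8, 8))  # all-positive input takes the x_first > 0 branch
    assert out.shape == (1, 3, 8, 8)  # padded 3x3 convs keep the spatial size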