forked from huggingface/pytorch-image-models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_layers.py
121 lines (84 loc) · 2.91 KB
/
test_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
import importlib
import os
# Optional out-of-tree backend: import the module named by TORCH_BACKEND for its
# side effects (presumably registering a custom device/backend -- confirm with CI config).
torch_backend = os.getenv('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)

# Device the activation tests run on; 'cpu' unless TORCH_DEVICE overrides it.
torch_device = os.getenv('TORCH_DEVICE', 'cpu')
class MLP(nn.Module):
    """Tiny two-layer perceptron (1000 -> 100 -> 10) with a configurable activation.

    The activation is stored as ``self.act`` so callers can swap it for another
    instance after construction.

    Args:
        act_layer: activation spec forwarded to ``create_act_layer`` (default ``"relu"``).
        inplace: ``inplace`` flag forwarded to ``create_act_layer``.
    """

    def __init__(self, act_layer="relu", inplace=True):
        super().__init__()
        # Creation order (fc1, act, fc2) is kept so random init consumes the RNG identically.
        self.fc1 = nn.Linear(1000, 100)
        self.act = create_act_layer(act_layer, inplace=inplace)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
def _run_act_layer_grad(act_type, inplace=True):
    """Check that an activation gives matching results under different layer configs.

    Builds an ``MLP`` using ``act_type`` and pushes one random batch through it three
    times:

    1. with the activation created at construction time (default config);
    2. with the activation re-created under ``set_layer_config(scriptable=True)``;
    3. with the activation re-created under ``set_layer_config(no_jit=True)``.

    The scalar ``sum(out ** 2)`` of consecutive runs must agree per ``torch.isclose``.

    NOTE(review): despite the name, no backward pass is run -- only forward results
    are compared.

    Args:
        act_type: activation spec passed to ``MLP`` / ``create_act_layer``.
        inplace: ``inplace`` flag passed to ``create_act_layer``.
    """
    x = torch.rand(10, 1000) * 10
    m = MLP(act_layer=act_type, inplace=inplace)

    def _run(x, act_layer=''):
        if act_layer:
            # Replace the activation with one created under the currently active layer config.
            m.act = create_act_layer(act_layer, inplace=inplace)
            if m.act is not None:
                # m.to(torch_device) below ran before this layer existed, so move it
                # explicitly; otherwise a layer with parameters would stay on the
                # wrong device when torch_device is not the default.
                m.act.to(device=torch_device)
        out = m(x)
        return out.pow(2).sum()

    x = x.to(device=torch_device)
    m.to(device=torch_device)

    out_me = _run(x)

    with set_layer_config(scriptable=True):
        out_jit = _run(x, act_type)
    assert torch.isclose(out_jit, out_me)

    with set_layer_config(no_jit=True):
        out_basic = _run(x, act_type)
    assert torch.isclose(out_basic, out_jit)
def test_swish_grad():
    """Repeat the 'swish' config-consistency check 100 times (fresh random model/input each time)."""
    n_trials = 100
    for _trial in range(n_trials):
        _run_act_layer_grad('swish')
def test_mish_grad():
    """Repeat the 'mish' config-consistency check 100 times (fresh random model/input each time)."""
    n_trials = 100
    for _trial in range(n_trials):
        _run_act_layer_grad('mish')
def test_hard_sigmoid_grad():
    """Repeat the 'hard_sigmoid' config-consistency check 100 times.

    Passes ``inplace=None`` rather than the helper's default ``True``
    (presumably so no ``inplace`` kwarg reaches the layer -- confirm intent).
    """
    n_trials = 100
    for _trial in range(n_trials):
        _run_act_layer_grad('hard_sigmoid', inplace=None)
def test_hard_swish_grad():
    """Repeat the 'hard_swish' config-consistency check 100 times (fresh random model/input each time)."""
    n_trials = 100
    for _trial in range(n_trials):
        _run_act_layer_grad('hard_swish')
def test_hard_mish_grad():
    """Repeat the 'hard_mish' config-consistency check 100 times (fresh random model/input each time)."""
    n_trials = 100
    for _trial in range(n_trials):
        _run_act_layer_grad('hard_mish')
def test_get_act_layer_empty_string():
    """An empty-string activation name resolves to no layer (``None``)."""
    resolved = get_act_layer('')
    assert resolved is None
def test_create_act_layer_inplace_error():
    """``create_act_layer`` must still build a layer whose constructor rejects ``inplace``."""

    class NoInplaceAct(nn.Module):
        # Constructor takes no arguments, so calling it with inplace=... raises TypeError.
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

    # Expected to recover from the TypeError caused by the inplace kwarg.
    built = create_act_layer(NoInplaceAct, inplace=True)
    assert isinstance(built, NoInplaceAct)
def test_create_act_layer_edge_cases():
    """Edge cases for ``create_act_layer``: ``None`` input and a ``**kwargs`` layer class."""
    # A None activation spec yields no layer.
    assert create_act_layer(None) is None

    # CustomAct's constructor accepts arbitrary keyword arguments, so it can never
    # raise TypeError for ``inplace``. This case therefore checks that a kwargs-tolerant
    # class is built normally; it does NOT exercise TypeError recovery (see
    # test_create_act_layer_inplace_error for a constructor that rejects ``inplace``).
    class CustomAct(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            return x

    result = create_act_layer(CustomAct, inplace=True)
    assert isinstance(result, CustomAct)
def test_get_act_fn_callable():
    """A plain callable given to ``get_act_fn`` is returned as the very same object."""
    def identity(x):
        return x

    returned = get_act_fn(identity)
    assert returned is identity
def test_get_act_fn_none():
    """Both ``None`` and the empty string resolve to no activation function."""
    for empty_spec in (None, ''):
        assert get_act_fn(empty_spec) is None