Commit 56373c0

make value residual learned

1 parent 24196a3 · commit 56373c0

2 files changed: +14 −6 lines

setup.py (+2 −2)

@@ -6,10 +6,10 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.8',
+  version = '1.8.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
-  long_description=long_description,
+  long_description = long_description,
   long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = '[email protected]',

vit_pytorch/simple_vit_with_value_residual.py (+12 −4)

@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
     )
 
 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64):
+    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -50,14 +50,21 @@ def __init__(self, dim, heads = 8, dim_head = 64):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
+        self.to_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            nn.Sigmoid(),
+            Rearrange('b n h -> b h n 1')
+        ) if learned_value_residual_mix else (lambda _: 0.5)
+
     def forward(self, x, value_residual = None):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
         if exists(value_residual):
-            v = 0.5 * (v + value_residual)
+            mix = self.to_residual_mix(x)
+            v = v * mix + value_residual * (1. - mix)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -73,9 +80,10 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head),
+                Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
                 FeedForward(dim, mlp_dim)
             ]))
     def forward(self, x):
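
For readers skimming the diff, here is a minimal standalone sketch of what the learned mix does: a per-token, per-head gate in (0, 1) is computed from the normed layer input and used to blend the current layer's values with the value residual carried over from the first layer, replacing the previous fixed 0.5 average. The sketch is illustrative only; the tensor sizes (dim = 256, batch = 2, seq = 65, etc.) are arbitrary assumptions, not values taken from the repository.

# Illustrative sketch of the learned value-residual mix (assumed toy dimensions,
# not the library's public API).
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 256, 8, 64
batch, seq = 2, 65

# same construction as in the commit: one sigmoid gate per head for every token
to_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),
    nn.Sigmoid(),
    Rearrange('b n h -> b h n 1')  # trailing axis lets the gate broadcast over dim_head
)

x = torch.randn(batch, seq, dim)                            # normed layer input
v = torch.randn(batch, heads, seq, dim_head)                # this layer's values
value_residual = torch.randn(batch, heads, seq, dim_head)   # values cached from the first layer

mix = to_residual_mix(x)                   # shape (batch, heads, seq, 1), values in (0, 1)
v = v * mix + value_residual * (1. - mix)  # convex combination instead of the fixed 0.5 average

print(v.shape)  # torch.Size([2, 8, 65, 64])

Note that the first Attention layer is built with learned_value_residual_mix = not is_first, i.e. it keeps no mix module: it has nothing to mix with and instead supplies the values that later layers blend in.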
