Skip to content

Commit 141239c

Browse files
committed
fix value residual
1 parent 0b5c9b4 commit 141239c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.8.6',
9+
version = '1.8.7',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description=long_description,

vit_pytorch/simple_vit_with_value_residual.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def forward(self, x, value_residual = None):
5757
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
5858

5959
if exists(value_residual):
60-
v = v + value_residual
60+
v = 0.5 * (v + value_residual)
6161

6262
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
6363

0 commit comments

Comments
 (0)