@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
     )
 
 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64):
+    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -50,14 +50,21 @@ def __init__(self, dim, heads = 8, dim_head = 64):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
+        self.to_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            nn.Sigmoid(),
+            Rearrange('b n h -> b h n 1')
+        ) if learned_value_residual_mix else (lambda _: 0.5)
+
     def forward(self, x, value_residual = None):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
         if exists(value_residual):
-            v = 0.5 * (v + value_residual)
+            mix = self.to_residual_mix(x)
+            v = v * mix + value_residual * (1. - mix)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -73,9 +80,10 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head),
+                Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
                 FeedForward(dim, mlp_dim)
             ]))
     def forward(self, x):
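For reference, the change replaces the fixed 0.5 average of the current values and the incoming `value_residual` with a per-head, per-token sigmoid gate predicted from the normed tokens, and every layer except the first is constructed with the learned gate. Below is a minimal, self-contained sketch of that gating, not taken from the diff: the shapes and the names `dim`, `heads`, `dim_head`, `seq` are illustrative assumptions, and `value_residual` stands in for values carried over from an earlier layer (the Transformer forward that supplies it is outside these hunks).

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head, seq = 256, 8, 64, 16   # illustrative sizes, not from the repo

# per-head gate in [0, 1], predicted from the (normed) token features
to_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),
    nn.Sigmoid(),
    Rearrange('b n h -> b h n 1')  # align gate with (batch, heads, seq, dim_head) values
)

x = torch.randn(1, seq, dim)                           # normed token features
v = torch.randn(1, heads, seq, dim_head)               # this layer's values
value_residual = torch.randn(1, heads, seq, dim_head)  # values from an earlier layer

mix = to_residual_mix(x)                    # shape (1, heads, seq, 1)
v = v * mix + value_residual * (1. - mix)   # convex blend per head and token

Because the gate is a convex combination, setting it to a constant 0.5 recovers the old behavior, which is what the `lambda _: 0.5` fallback does when `learned_value_residual_mix = False`.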