@@ -32,11 +32,11 @@ def __init__(self, x_wti, y_wti):
32
32
33
33
def forward (self , xw , yi , mask ):
34
34
35
- x = self .embed (None , None , yi )
35
+ h = self .embed (None , None , yi )
36
36
37
37
if ATTN :
38
- x = torch .cat ([x , self .h ], 2 ) # input feeding
39
- h , self .H = self .rnn (x , self .H )
38
+ h = torch .cat ([h , self .h ], 2 ) # input feeding
39
+ h , self .H = self .rnn (h , self .H )
40
40
self .attn (self .M , h , mask )
41
41
self .h = self .Wc (torch .cat ([self .attn .V , h ], 2 )).tanh ()
42
42
h = self .Wo (self .h ).squeeze (1 ) # [B, V]
@@ -46,8 +46,8 @@ def forward(self, xw, yi, mask):
46
46
_M = self .M [:, :- 1 ] # remove EOS token [B, L' = L - 1]
47
47
self .attn (self .M , self .h , mask ) # attentive read
48
48
self .copy .attn (_M ) # selective read
49
- x = torch .cat ([x , self .attn .V , self .copy .R ], 2 )
50
- self .h , self .H = self .rnn (x , self .H )
49
+ h = torch .cat ([h , self .attn .V , self .copy .R ], 2 )
50
+ self .h , self .H = self .rnn (h , self .H )
51
51
g = self .Wo (self .h ).squeeze (1 ) # generation scores [B, V]
52
52
c = self .copy .score (_M , self .h , mask ) # copy scores [B, L']
53
53
yo = self .copy .mix (xw , g , c ) # [B, V']
0 commit comments