Skip to content

Commit 6150ff8

Browse files
Update by commit
1 parent 18e42e3 commit 6150ff8

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

rnn_decoder.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ def __init__(self, x_wti, y_wti):
3232

3333
def forward(self, xw, yi, mask):
3434

35-
x = self.embed(None, None, yi)
35+
h = self.embed(None, None, yi)
3636

3737
if ATTN:
38-
x = torch.cat([x, self.h], 2) # input feeding
39-
h, self.H = self.rnn(x, self.H)
38+
h = torch.cat([h, self.h], 2) # input feeding
39+
h, self.H = self.rnn(h, self.H)
4040
self.attn(self.M, h, mask)
4141
self.h = self.Wc(torch.cat([self.attn.V, h], 2)).tanh()
4242
h = self.Wo(self.h).squeeze(1) # [B, V]
@@ -46,8 +46,8 @@ def forward(self, xw, yi, mask):
4646
_M = self.M[:, :-1] # remove EOS token [B, L' = L - 1]
4747
self.attn(self.M, self.h, mask) # attentive read
4848
self.copy.attn(_M) # selective read
49-
x = torch.cat([x, self.attn.V, self.copy.R], 2)
50-
self.h, self.H = self.rnn(x, self.H)
49+
h = torch.cat([h, self.attn.V, self.copy.R], 2)
50+
self.h, self.H = self.rnn(h, self.H)
5151
g = self.Wo(self.h).squeeze(1) # generation scores [B, V]
5252
c = self.copy.score(_M, self.h, mask) # copy scores [B, L']
5353
yo = self.copy.mix(xw, g, c) # [B, V']

rnn_encoder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def forward(self, xc, xw, lens):
3434
b = len(lens)
3535
s = self.init_state(b)
3636

37-
x = self.embed(b, xc, xw)
38-
x = nn.utils.rnn.pack_padded_sequence(x, lens.cpu(), batch_first = True)
39-
h, s = self.rnn(x, s)
37+
h = self.embed(b, xc, xw)
38+
h = nn.utils.rnn.pack_padded_sequence(h, lens, batch_first = True)
39+
h, s = self.rnn(h, s)
4040
h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first = True)
4141

4242
return h, s

utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def save_loss(filename, epoch, loss_array):
8282
def maskset(x):
8383

8484
mask = x.eq(PAD_IDX)
85-
lens = x.size(1) - mask.sum(1) # x.ne(PAD_IDX).sum(1)
85+
lens = (x.size(1) - mask.sum(1)).tolist() # x.ne(PAD_IDX).sum(1)
8686

8787
return mask, lens
8888

0 commit comments

Comments
 (0)