Skip to content

Commit 917342c

Browse files
Bigrams++
1 parent 9a10d25 commit 917342c

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

Bigrams.lua

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
local Bigrams, parent = torch.class("nn.Bigrams", "nn.Module")
22

3-
--Function taken by torchx Aliasmultinomail.lua
3+
--Function taken by torchx Aliasmultinomial.lua
44
function Bigrams:setup(probs)
55
assert(probs:dim() == 1)
66
local K = probs:nElement()
@@ -91,34 +91,44 @@ function Bigrams:batchdraw(output, J, q)
9191
end
9292

9393

94-
function Bigrams:__init(bigrams, sample)
95-
self.Nsample = sample
96-
self.bigrams = bigrams
97-
self.q = {}
98-
self.J = {}
99-
for uniI, map in pairs(bigrams) do
100-
local J, q = self.setup(self, map['prob'])
101-
self.J[uniI] = J
102-
self.q[uniI] = q
103-
end
94+
function Bigrams:__init(bigrams, nsample)
95+
self.nsample = nsample
96+
self.bigrams = bigrams
97+
self.q = {}
98+
self.J = {}
99+
for uniI, map in pairs(bigrams) do
100+
local J, q = self.setup(self, map.prob)
101+
self.J[uniI] = J
102+
self.q[uniI] = q
103+
end
104104
end
105105

106106

107107
function Bigrams:updateOutput(input)
108108
assert(torch.type(input) == 'torch.LongTensor')
109109
local batchSize = input:size(1)
110-
self.output = torch.LongTensor(batchSize, self.Nsample)
110+
self.output = torch.type(self.output) == 'torch.LongTensor' and self.output or torch.LongTensor()
111+
self.output:resize(batchSize, self.nsample)
111112

112113
for i = 1, batchSize do
113114
self.batchdraw(self, self.output[i], self.J[input[i]], self.q[input[i]])
114-
--self.output[i]:apply(function(x) return self.bigrams[input[i]]['index'][x] end)
115115
end
116116

117117
return self.output
118118
end
119119

120120
function Bigrams:updateGradInput(input, gradOutput)
121-
return torch.LongTensor(input:size()):fill(0)
121+
self.gradInput = torch.type(self.gradInput) == 'torch.LongTensor' or torch.LongTensor()
122+
self.gradInput:resizeAs(input):fill(0)
123+
return self.gradInput
122124
end
123125

124-
126+
function Bigrams:statistics()
127+
local sum, count = 0, 0
128+
for uniI, map in pairs(self.bigrams) do
129+
sum = sum + map.prob:nElement()
130+
count = count + 1
131+
end
132+
local meansize = sum/count
133+
return meansize
134+
end

0 commit comments

Comments
 (0)