|
1 | 1 | local Bigrams, parent = torch.class("nn.Bigrams", "nn.Module")
|
2 | 2 |
|
3 |
| ---Function taken by torchx Aliasmultinomail.lua |
| 3 | +--Function taken by torchx Aliasmultinomial.lua |
4 | 4 | function Bigrams:setup(probs)
|
5 | 5 | assert(probs:dim() == 1)
|
6 | 6 | local K = probs:nElement()
|
@@ -91,34 +91,44 @@ function Bigrams:batchdraw(output, J, q)
|
91 | 91 | end
|
92 | 92 |
|
93 | 93 |
|
94 |
| -function Bigrams:__init(bigrams, sample) |
95 |
| - self.Nsample = sample |
96 |
| - self.bigrams = bigrams |
97 |
| - self.q = {} |
98 |
| - self.J = {} |
99 |
| - for uniI, map in pairs(bigrams) do |
100 |
| - local J, q = self.setup(self, map['prob']) |
101 |
| - self.J[uniI] = J |
102 |
| - self.q[uniI] = q |
103 |
| - end |
| 94 | +function Bigrams:__init(bigrams, nsample) |
| 95 | + self.nsample = nsample |
| 96 | + self.bigrams = bigrams |
| 97 | + self.q = {} |
| 98 | + self.J = {} |
| 99 | + for uniI, map in pairs(bigrams) do |
| 100 | + local J, q = self.setup(self, map.prob) |
| 101 | + self.J[uniI] = J |
| 102 | + self.q[uniI] = q |
| 103 | + end |
104 | 104 | end
|
105 | 105 |
|
106 | 106 |
|
107 | 107 | function Bigrams:updateOutput(input)
|
108 | 108 | assert(torch.type(input) == 'torch.LongTensor')
|
109 | 109 | local batchSize = input:size(1)
|
110 |
| - self.output = torch.LongTensor(batchSize, self.Nsample) |
| 110 | + self.output = torch.type(self.output) == 'torch.LongTensor' and self.output or torch.LongTensor() |
| 111 | + self.output:resize(batchSize, self.nsample) |
111 | 112 |
|
112 | 113 | for i = 1, batchSize do
|
113 | 114 | self.batchdraw(self, self.output[i], self.J[input[i]], self.q[input[i]])
|
114 |
| - --self.output[i]:apply(function(x) return self.bigrams[input[i]]['index'][x] end) |
115 | 115 | end
|
116 | 116 |
|
117 | 117 | return self.output
|
118 | 118 | end
|
119 | 119 |
|
120 | 120 | function Bigrams:updateGradInput(input, gradOutput)
|
121 |
| - return torch.LongTensor(input:size()):fill(0) |
| 121 | + self.gradInput = torch.type(self.gradInput) == 'torch.LongTensor' or torch.LongTensor() |
| 122 | + self.gradInput:resizeAs(input):fill(0) |
| 123 | + return self.gradInput |
122 | 124 | end
|
123 | 125 |
|
124 |
| - |
| 126 | +function Bigrams:statistics() |
| 127 | + local sum, count = 0, 0 |
| 128 | + for uniI, map in pairs(self.bigrams) do |
| 129 | + sum = sum + map.prob:nElement() |
| 130 | + count = count + 1 |
| 131 | + end |
| 132 | + local meansize = sum/count |
| 133 | + return meansize |
| 134 | +end |
0 commit comments