|
1 |
| -local SpatialPyramidPooling, parent = torch.class('inn.SpatialPyramidPooling', 'nn.Concat') |
| 1 | +local SpatialPyramidPooling, parent = torch.class('inn.SpatialPyramidPooling', 'nn.Module') |
2 | 2 |
|
3 | 3 | -- allows nn.SpatialPyramidPooling({4,4},{3,3}) or {{4,4},{3,3}}
|
4 | 4 | function SpatialPyramidPooling:__init(...)
|
5 |
| - parent.__init(self, 1) |
| 5 | + parent.__init(self) |
6 | 6 | local pyr = {...}
|
7 | 7 | self.pyr = torch.type(pyr[1][1]) == 'table' and pyr[1] or pyr
|
| 8 | + self._modules = nn.Concat(1) |
8 | 9 | for k,v in ipairs(self.pyr) do
|
9 |
| - parent.add(self, nn.Sequential() |
| 10 | + self._modules:add(nn.Sequential() |
10 | 11 | :add(nn.SpatialAdaptiveMaxPooling(v[1], v[2]))
|
11 | 12 | :add(nn.View(-1):setNumInputDims(3))
|
12 | 13 | :add(nn.Contiguous())
|
|
16 | 17 |
|
17 | 18 | function SpatialPyramidPooling:updateOutput(input)
|
18 | 19 | assert(input:dim() == 4 or input:dim() == 3, 'unsupported dimensionality')
|
19 |
| - self.dimension = input:dim() - 2 |
20 |
| - return parent.updateOutput(self, input) |
| 20 | + self._modules.dimension = input:dim() - 2 |
| 21 | + self.output = self._modules:updateOutput(input) |
| 22 | + return self.output |
| 23 | +end |
| 24 | + |
| 25 | +function SpatialPyramidPooling:updateGradInput(input, gradOutput) |
| 26 | + assert(input:dim() == 4 or input:dim() == 3, 'unsupported dimensionality') |
| 27 | + self._modules.dimension = input:dim() - 2 |
| 28 | + self.gradInput = self._modules:updateGradInput(input, gradOutput) |
| 29 | + return self.gradInput |
21 | 30 | end
|
22 | 31 |
|
0 commit comments