Skip to content

Commit ab295ee

Browse files
committed
Revert back SPP and LRN to heritate from Module
1 parent 4c2dc1c commit ab295ee

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

SpatialPyramidPooling.lua

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
local SpatialPyramidPooling, parent = torch.class('inn.SpatialPyramidPooling', 'nn.Concat')
1+
local SpatialPyramidPooling, parent = torch.class('inn.SpatialPyramidPooling', 'nn.Module')
22

33
-- allows nn.SpatialPyramidPooling({4,4},{3,3}) or {{4,4},{3,3}}
44
function SpatialPyramidPooling:__init(...)
5-
parent.__init(self, 1)
5+
parent.__init(self)
66
local pyr = {...}
77
self.pyr = torch.type(pyr[1][1]) == 'table' and pyr[1] or pyr
8+
self._modules = nn.Concat(1)
89
for k,v in ipairs(self.pyr) do
9-
parent.add(self, nn.Sequential()
10+
self._modules:add(nn.Sequential()
1011
:add(nn.SpatialAdaptiveMaxPooling(v[1], v[2]))
1112
:add(nn.View(-1):setNumInputDims(3))
1213
:add(nn.Contiguous())
@@ -16,7 +17,15 @@ end
1617

1718
function SpatialPyramidPooling:updateOutput(input)
1819
assert(input:dim() == 4 or input:dim() == 3, 'unsupported dimensionality')
19-
self.dimension = input:dim() - 2
20-
return parent.updateOutput(self, input)
20+
self._modules.dimension = input:dim() - 2
21+
self.output = self._modules:updateOutput(input)
22+
return self.output
23+
end
24+
25+
function SpatialPyramidPooling:updateGradInput(input, gradOutput)
26+
assert(input:dim() == 4 or input:dim() == 3, 'unsupported dimensionality')
27+
self._modules.dimension = input:dim() - 2
28+
self.gradInput = self._modules:updateGradInput(input, gradOutput)
29+
return self.gradInput
2130
end
2231

SpatialSameResponseNormalization.lua

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
local SpatialSameResponseNormalization, parent = torch.class('inn.SpatialSameResponseNormalization', 'nn.Sequential')
1+
local SpatialSameResponseNormalization, parent = torch.class('inn.SpatialSameResponseNormalization', 'nn.Module')
22

33
function SpatialSameResponseNormalization:__init(size, alpha, beta)
44
parent.__init(self)
@@ -21,7 +21,17 @@ function SpatialSameResponseNormalization:__init(size, alpha, beta)
2121
:add(numerator)
2222
:add(denominator)
2323

24-
self:add(divide)
25-
self:add(nn.CDivTable())
24+
self._modules = nn.Sequential()
25+
self._modules:add(divide)
26+
self._modules:add(nn.CDivTable())
2627
end
2728

29+
function SpatialSameResponseNormalization:updateOutput(input)
30+
self.output = self._modules:updateOutput(input)
31+
return self.output
32+
end
33+
34+
function SpatialSameResponseNormalization:updateGradInput(input, gradOutput)
35+
self.gradInput = self._modules:updateGradInput(input, gradOutput)
36+
return self.gradInput
37+
end

0 commit comments

Comments
 (0)