Skip to content

Commit 5af278e

Browse files
author
Nicholas Leonard
committed
removed nnx dependency
1 parent a501c2f commit 5af278e

File tree

6 files changed

+18
-93
lines changed

6 files changed

+18
-93
lines changed

SoftMaxForest.lua

-55
This file was deleted.

SoftMaxTree.lua

-26
This file was deleted.

SpatialGlimpse.lua

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
local SpatialGlimpse, parent = torch.class("nn.SpatialGlimpse", "nn.Module")
1515

1616
function SpatialGlimpse:__init(size, depth, scale)
17-
require 'nnx'
17+
dpnn.require('nnx')
1818
if torch.type(size)=='table' then
1919
self.height = size[1]
2020
self.width = size[2]
@@ -42,6 +42,7 @@ end
4242
-- a bandwidth limited sensor which focuses on a location.
4343
-- locations index the x,y coord of the center of the output glimpse
4444
function SpatialGlimpse:updateOutput(inputTable)
45+
dpnn.require('nnx')
4546
assert(torch.type(inputTable) == 'table')
4647
assert(#inputTable >= 2)
4748
local input, location = unpack(inputTable)

SpatialUniformCrop.lua

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
local SpatialUniformCrop, parent = torch.class("nn.SpatialUniformCrop", "nn.Module")
22

33
function SpatialUniformCrop:__init(oheight, owidth, scale)
4+
dpnn.require('nnx')
45
parent.__init(self)
56
self.scale = scale or nil
67
if self.scale ~= nil then
@@ -12,8 +13,9 @@ function SpatialUniformCrop:__init(oheight, owidth, scale)
1213
end
1314

1415
function SpatialUniformCrop:updateOutput(input)
16+
dpnn.require('nnx')
1517
input = self:toBatch(input, 3)
16-
18+
1719
self.output:resize(input:size(1), input:size(2), self.oheight, self.owidth)
1820
self.coord = self.coord or torch.IntTensor()
1921
self.coord:resize(input:size(1), 2)
@@ -22,7 +24,7 @@ function SpatialUniformCrop:updateOutput(input)
2224
self.scales = self.scales or torch.FloatTensor()
2325
self.scales:resize(input:size(1))
2426
end
25-
27+
2628
local iH, iW = input:size(3), input:size(4)
2729
if self.train ~= false then
2830
if self.scale ~= nil then
@@ -34,7 +36,7 @@ function SpatialUniformCrop:updateOutput(input)
3436

3537
local h = math.ceil(torch.uniform(1e-2, iH-soheight))
3638
local w = math.ceil(torch.uniform(1e-2, iW-sowidth))
37-
39+
3840
local ch = math.ceil(iH/2 - (iH-soheight)/2 + h)
3941
local cw = math.ceil(iW/2 - (iH-sowidth)/2 + w)
4042

@@ -70,15 +72,15 @@ function SpatialUniformCrop:updateOutput(input)
7072
local crop = input:narrow(3,h1,self.oheight):narrow(4,w1,self.owidth)
7173
self.output:copy(crop)
7274
end
73-
75+
7476
self.output = self:fromBatch(self.output, 1)
7577
return self.output
7678
end
7779

7880
function SpatialUniformCrop:updateGradInput(input, gradOutput)
7981
input = self:toBatch(input, 3)
8082
gradOutput = self:toBatch(gradOutput, 3)
81-
83+
8284
self.gradInput:resizeAs(input):zero()
8385
if self.scale ~= nil then
8486
local iH, iW = input:size(3), input:size(4)
@@ -88,7 +90,7 @@ function SpatialUniformCrop:updateGradInput(input, gradOutput)
8890
local sowidth = math.ceil(s*self.owidth)
8991

9092
local h, w = self.coord[{i,1}], self.coord[{i,2}]
91-
93+
9294
local ch = math.ceil(iH/2 - (iH-soheight)/2 + h)
9395
local cw = math.ceil(iW/2 - (iH-sowidth)/2 + w)
9496

@@ -108,7 +110,7 @@ function SpatialUniformCrop:updateGradInput(input, gradOutput)
108110
self.gradInput[i]:narrow(2,h1,self.oheight):narrow(3,w1,self.owidth):copy(gradOutput[i])
109111
end
110112
end
111-
113+
112114
self.gradInput = self:fromBatch(self.gradInput, 1)
113115
return self.gradInput
114116
end

init.lua

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
require 'torch'
22
require 'nn'
3-
require 'nnx'
43
local _ = require 'moses'
54

65
-- create global dpnn table
@@ -9,6 +8,11 @@ dpnn.version = 2
98

109
unpack = unpack or table.unpack -- lua 5.2 compat
1110

11+
function dpnn.require(packagename)
12+
assert(torch.type(packagename) == 'string')
13+
assert(pcall(function() require(packagename) end), "missing package "..packagename..": run 'luarocks install nnx'")
14+
end
15+
1216
-- for testing:
1317
require('dpnn.test')
1418

@@ -49,8 +53,6 @@ require('dpnn.CAddTensorTable')
4953
require('dpnn.ReverseTable')
5054
require('dpnn.Dictionary')
5155
require('dpnn.Inception')
52-
require('dpnn.SoftMaxTree')
53-
require('dpnn.SoftMaxForest')
5456
require('dpnn.Clip')
5557
require('dpnn.SpatialUniformCrop')
5658
require('dpnn.SpatialGlimpse')

test/test.lua

+2-1
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,7 @@ end
898898

899899
function dpnntest.SpatialGlimpse()
900900
if not pcall(function() require "image" end) then return end -- needs the image package
901+
if not pcall(function() require "nnx" end) then return end -- needs the nnx package
901902
local batchSize = 1
902903
local inputSize = {2,8,8}
903904
local glimpseSize = 4
@@ -1118,12 +1119,12 @@ function dpnntest.SpatialGlimpse()
11181119
end
11191120

11201121
function dpnntest.SpatialGlimpse_backwardcompat()
1122+
if not pcall(function() require "nnx" end) then return end -- needs the nnx package
11211123
-- this is ugly, but I know this verson of the module works.
11221124
-- So we try to match the newer versions to it
11231125
local SG, parent = torch.class("nn.SG", "nn.Module")
11241126

11251127
function SG:__init(size, depth, scale)
1126-
require 'nnx'
11271128
self.size = size -- height == width
11281129
self.depth = depth or 3
11291130
self.scale = scale or 2

0 commit comments

Comments
 (0)