Graphgen improvements #10

Open · wants to merge 7 commits into master
82 changes: 75 additions & 7 deletions graphgen.lua
@@ -102,6 +102,9 @@ local function generateGraph(net, input, opts)

local storageHash = {}
local nodes = {}
local trickyNodes = {}
local current_module = {__input=input}
local stack_visited_modules = {}

local g = graph.Graph()

@@ -158,6 +161,45 @@ local function generateGraph(net, input, opts)
end
end

local origTorchFuncs = {DoubleTensor={},FloatTensor={}}
-- also hack the cuda counterparts if cutorch is loaded
if package.loaded.cutorch then
origTorchFuncs.CudaTensor = {}
end
-- list of functions to hack; extending this list further seems to
-- cause stack overflows
local hackableTorchFuncs = {'select','__index'}

-- we will temporarily overwrite torch functions to keep track
-- of all created tensors during the forward call. This will
-- allow us to handle some corner cases where the input tensor is
-- not part of the state of a module (i.e., it's not the output
-- of another module)
local function hackTorch()
for torchType, t in pairs(origTorchFuncs) do
for _, func in ipairs(hackableTorchFuncs) do
local oldFunc = torch[torchType][func]
t[func] = oldFunc
torch[torchType][func] = function(...)
local res = oldFunc(...)
if res then
-- heavy use of upvalues
trickyNodes[torch.pointer(res)] = {current_module, 'torch.'..func}
end
return res
end
end
end
end

local function unhackTorch()
for torchType, t in pairs(origTorchFuncs) do
for _, func in ipairs(hackableTorchFuncs) do
torch[torchType][func] = t[func]
end
end
end
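
As a standalone illustration of the hack above, here is a minimal sketch of the same monkey-patching idea applied to a single function (assumes only torch is loaded; not part of this diff):

local orig = torch.FloatTensor.select
torch.FloatTensor.select = function(...)
  local res = orig(...)
  -- graphgen records torch.pointer(res) in trickyNodes; here we just print it
  if res then print('tracked', torch.pointer(res)) end
  return res
end
local slice = torch.FloatTensor(4, 4):select(1, 1) -- prints the slice's pointer
torch.FloatTensor.select = orig -- always restore, as unhackTorch() does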

-- create edge "from" -> "to", creating "to" on the way with "name"
-- the edges can be seen as linking modules, but in fact they link the
-- output tensors of the modules
@@ -168,8 +210,19 @@

nodes[toPtr] = nodes[toPtr] or createNode(name,to)

- assert(nodes[fromPtr], 'Parent node inexistant for module '.. name)

-- if "from" tensor is not present in "nodes" table, this means that
-- "from" is not the output of a module, and was created on the fly
-- during for example a slicing of a tensor. "trickyNodes" contains
-- all tensors that were generated on the fly
if not nodes[fromPtr] then
local trickyNode = trickyNodes[fromPtr]
assert(trickyNode, "Couldn't handle previous node to "..name)
local trickyNodeName = trickyNode[2]

local trickyParentFrom = trickyNode[1].__input
addEdge(trickyParentFrom,from,trickyNodeName)
end
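-- (example) nn.SplitTable builds its outputs with input:select(), so the
-- tensors fed to ParallelTable branches are not any module's output; the
-- hacked torch.select recorded them in trickyNodes above (exercised by
-- models.basic_splitTable in models.lua)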

-- insert edge
g:add(graph.Edge(nodes[fromPtr],nodes[toPtr]))
elseif torch.isTensor(from) then
@@ -188,6 +241,14 @@
local function apply_func(m)
local basefunc = m.updateOutput
m.updateOutput = function(self, input)
-- add input to self to help keep track of it
self.__input = input
-- keeps a stack of visited modules
table.insert(stack_visited_modules, current_module)
current_module = self
local output = basefunc(self, input)
current_module = table.remove(stack_visited_modules)
-- add edges to the graph according to the node type
if isSingleOperationModule(m) then
local name = tostring(m)
if m.inplace then -- handle it differently ?
@@ -199,7 +260,13 @@
-- those containers effectively do some computation, so they have their
-- place in the graph
for i,branch in ipairs(m.modules) do
- local last_module = branch:get(branch:size())
local last_module
if branch.modules then
last_module = branch:get(branch:size())
else
last_module = branch
end
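-- (example) nn.Parallel may hold a bare module such as nn.Linear, which
-- has no .modules table and no :get() method, so the old
-- branch:get(branch:size()) would error on it (exercised by
-- models.basic_parallel_middle in models.lua)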

local out = last_module.output
local ptr = torch.pointer(out)

@@ -208,20 +275,20 @@
addEdge(out, self.output, torch.typename(m))
end
end
- return basefunc(self, input)
return output
end
end

createBoundaryNode(input, 'Input')

- -- fill the states from each tensor
- net:forward(input)

hackTorch()
-- overwriting the standard functions to generate our graph
net:apply(apply_func)
-- generate the graph
net:forward(input)

unhackTorch()

if opts.addOutputNode then
-- add dummy output node and link the last module to it
local output = utils.recursiveClone(net.output)
@@ -245,6 +312,7 @@
-- clean up the modified function
net:apply(function(x)
x.updateOutput = nil
x.__input = nil
end)

return g
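For context, a usage sketch of the entry point these changes go through (the require paths and the rendering call are assumptions based on the torch graph package; none of this appears in the diff):

local graph = require 'graph'
local generateGraph = require 'optnet.graphgen' -- assumed module path
local models = require 'optnet.models'          -- assumed module path
local model, input = models.siamese_parallel()
local g = generateGraph(model, input)
graph.dot(g, 'siamese_parallel', 'siamese_parallel') -- renders via graphviz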
56 changes: 56 additions & 0 deletions models.lua
@@ -72,6 +72,62 @@ models.siamese = function()
return m, input
end

models.siamese_parallel = function()
local fSize = {1, 32, 64}
local featuresOut = 128

local desc = nn.Sequential()
desc:add(nn.Reshape(1,64,64))
desc:add(nn.SpatialAveragePooling(2,2,2,2))
desc:add(nn.SpatialConvolution(fSize[1], fSize[2], 7,7))
desc:add(nn.ReLU())
desc:add(nn.SpatialMaxPooling(2,2,2,2))
desc:add(nn.SpatialConvolution(fSize[2], fSize[3], 6,6))
desc:add(nn.ReLU())
desc:add(nn.View(-1):setNumInputDims(3))
desc:add(nn.Linear(4096, 128))
desc:add(nn.Contiguous())

local siamese = nn.Parallel(2,2)
local siam = desc:clone()
desc:share(siam, 'weight', 'bias', 'gradWeight', 'gradBias')
siamese:add(desc)
siamese:add(siam)

local top = nn.Sequential()
top:add(nn.Linear(featuresOut*2, featuresOut*2))
top:add(nn.ReLU())
top:add(nn.Linear(featuresOut*2, 1))

local model = nn.Sequential():add(siamese):add(top)

local input = torch.rand(1,2,64,64)

return model, input
end
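
A quick sanity check for this model (a sketch, assuming nn and torch are loaded as in the rest of this file):

local model, input = models.siamese_parallel()
local out = model:forward(input)
print(out:size()) -- expected 1x1: a single similarity score for the patch pair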

models.basic_parallel_middle = function()
local model = nn.Sequential():add(nn.Linear(2,2))
local prl = nn.Parallel(2,1)
prl:add(nn.Linear(2,2))
prl:add(nn.Linear(2,2))
model:add(prl)
local input = torch.rand(2,2)
return model, input
end

models.basic_splitTable = function()
local model = nn.Sequential():add(nn.Linear(2,2))
model:add(nn.SplitTable(2))
local prl = nn.ParallelTable()
prl:add(nn.ReLU())
prl:add(nn.Sigmoid())
model:add(prl)
model:add(nn.JoinTable(1))
local input = torch.rand(2,2)
return model, input
end
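
These two small models appear to target the new code paths (a bare module inside nn.Parallel, and slices produced by nn.SplitTable). A hedged forward-pass check:

local m1, i1 = models.basic_parallel_middle()
print(m1:forward(i1):size()) -- Parallel(2,1) joins two size-2 outputs -> 4
local m2, i2 = models.basic_splitTable()
print(m2:forward(i2):size()) -- JoinTable(1) over two size-2 slices -> 4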

models.basic_concat = function()
local m = nn.Sequential()
local cat = nn.ConcatTable()