Skip to content

Commit 9cb983b

Browse files
committed
minor changes
1 parent 97f6e8f commit 9cb983b

File tree

6 files changed

+21
-28
lines changed

6 files changed

+21
-28
lines changed

code/checkpoints.lua

+3-6
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@ function checkpoint.latest(opt)
1414
end
1515

1616
function checkpoint.load(opt)
17-
if opt.useCheckpoint == false then
18-
return nil
19-
end
2017
--print(opt.epochNumber)
21-
local epoch = opt.epochNumber
22-
if epoch == 0 then
18+
local epoch = opt.useCheckpoint
19+
if epoch == -1 then
2320
return nil
24-
elseif epoch < 0 then
21+
if epoch == 0 then
2522
-- finding the latest epoch, requiring 'latest.t7'
2623
return checkpoint.latest(opt)
2724
end

code/datasets/floorplan-representation.lua

+2-5
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@ local FloorplanDataset = torch.class('FloorplanDataset', M)
1212
function FloorplanDataset:__init(imageInfo, opt, split)
1313
self.imageInfo = imageInfo[split]
1414
if split == 'train' then
15-
self.imageInfo.floorplanPaths = self.imageInfo.floorplanPaths:repeatTensor(opt.nRepetitionsPerEpochTrain, 1)
16-
self.imageInfo.representationPaths = self.imageInfo.representationPaths:repeatTensor(opt.nRepetitionsPerEpochTrain, 1)
17-
else
18-
self.imageInfo.floorplanPaths = self.imageInfo.floorplanPaths:repeatTensor(opt.nRepetitionsPerEpochTest, 1)
19-
self.imageInfo.representationPaths = self.imageInfo.representationPaths:repeatTensor(opt.nRepetitionsPerEpochTest, 1)
15+
self.imageInfo.floorplanPaths = self.imageInfo.floorplanPaths:repeatTensor(opt.checkpointEpochInterval, 1)
16+
self.imageInfo.representationPaths = self.imageInfo.representationPaths:repeatTensor(opt.checkpointEpochInterval, 1)
2017
end
2118
self.opt = opt
2219
self.split = split

code/evaluate.lua

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package.path = '../util/lua/?.lua;' .. package.path
22
local fp_ut = require 'floorplan_utils'
33

4-
dataPath = '../data/'
4+
local dataPath = '../data/'
55

6-
local imageInfo = csvigo.load({path=dataPath + '/test.txt', mode="large", header=false, separator='\t'})
6+
local imageInfo = csvigo.load({path=dataPath .. '/test.txt', mode="large", header=false, separator='\t'})
77

88
local result = {}
99
for _, mode in pairs({'Wall Junction', 'Door', 'Object', 'Room'}) do
@@ -19,11 +19,11 @@ local filenames = {}
1919
--local finalExamples = {1, 2, 4, 5, 6}
2020
--for _, i in pairs(finalExamples) do
2121
local results = {}
22-
for k, v in pairs(photo_info) do
22+
for k, v in pairs(imageInfo) do
2323
local floorplanFilename = dataPath .. v[1]
2424
local representationFilename = dataPath .. v[2]
2525

26-
26+
2727
representationPrediction = fp_ut.invertFloorplan(floorplan, false)
2828
local singleResult = fp_ut.evaluateResult(floorplan:size(3), floorplan:size(2), representationTarget, representationPrediction, {pointDistanceThreshold = 0.02, doorDistanceThreshold = 0.02, iconIOUThreshold = 0.5, segmentIOUThreshold = 0.7}, result)
2929
table.insert(results, singleResult)

code/models/heatmap-segmentation.lua

+7-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ local nn = require 'nn'
66

77
local function createModel(opt)
88

9+
if opt.loadModel ~= '' then
10+
local model = torch.load(opt.loadModel)
11+
model:cuda()
12+
print(model)
13+
return model
14+
end
15+
916
if opt.loadPoseEstimationModel ~= '' then
1017
local nOutput = 51 --This is slightly more than our final number of output channels. The actually used number of channels is 13 (wall corner) + 4 (opening corner) + 4 (icon corner) + 10 (opening/icon/empty segmentation) + 11 (wall/room segmentation)
1118

@@ -33,12 +40,6 @@ local function createModel(opt)
3340
return model
3441
end
3542

36-
if opt.loadModel ~= '' then
37-
local model = torch.load(opt.loadModel)
38-
model:cuda()
39-
print(model)
40-
return model
41-
end
4243
assert(false, 'Please specify either opt.loadPoseEstimationModel or opt.loadModel')
4344
end
4445

code/opts.lua

+4-6
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,8 @@ function M.parse(arg)
3737
cmd:option('-maskProb', 0.5, 'The probability of training masks')
3838
cmd:option('-IOUThreshold', 0.95, 'The probability of training masks')
3939
------------- Training/testing options ----------
40-
cmd:option('-nEpochs', 100, 'Number of total epochs to run')
41-
cmd:option('-nRepetitionsPerEpochTrain', 10, 'Number of total epochs to run')
42-
cmd:option('-nRepetitionsPerEpochTest', 10, 'Number of total epochs to run')
43-
cmd:option('-epochNumber', -1, 'Manual epoch number (useful on restarts)')
40+
cmd:option('-nEpochs', 1000, 'Number of total epochs to run')
41+
cmd:option('-checkpointEpochInterval', 1, 'Number of epochs between two saved checkpoints')
4442
cmd:option('-batchSize', 16, 'mini-batch size (1 = pure stochastic)')
4543
cmd:option('-valOnly', false, 'Run on validation set only')
4644
cmd:option('-testOnly', false, 'Run on validation set only')
@@ -65,7 +63,7 @@ function M.parse(arg)
6563
cmd:option('-segmentationDim', 500, 'Input dimensions')
6664
cmd:option('-retrain', 'none', 'Path to model to retrain with')
6765
cmd:option('-optimState', 'none', 'Path to an optimState to reload from')
68-
cmd:option('-useCheckpoint', true, 'Load checkpoint or not')
66+
cmd:option('-useCheckpoint', 0, 'Load checkpoint or not')
6967
------------- Other model options ---------------
7068
cmd:option('-nClasses', 13, 'The number of classes in the dataset')
7169
cmd:option('-nSegmentationClasses', 26, 'Number of classes of segmentation')
@@ -87,7 +85,7 @@ function M.parse(arg)
8785
cmd:option('-resnetDepth', 34, 'ResNet depth')
8886
cmd:option('-weight_1', 3, 'Weight 1')
8987

90-
cmd:option('-loadModel', '../checkpoint/model_best.t7', 'load trained model')
88+
cmd:option('-loadModel', '', 'load trained model')
9189
cmd:option('-loadPoseEstimationModel', '../PoseEstimation/human_pose_mpii.t7', 'load pretrained model')
9290
------------- pre-process options ---------------
9391
cmd:option('-scaleProb', 0.5, 'The probability of using scaling instead of random cropping')

util/lua/ginit.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function ginit(opt)
2727
-- py = require 'fb.python'
2828

2929
--gp = require 'gpath'
30-
--ut = require 'utils'
30+
ut = require 'utils'
3131
fp_ut = require 'floorplan_utils'
3232
--nnut = require 'nn_utils'
3333
--wut = require 'www_utils'

0 commit comments

Comments
 (0)