buildData.lua
-- Build batched Torch tensors from preprocessed JSON data.
require 'torch'
JSON = (loadfile "model/JSON.lua")()

-- Allocate an empty batch of size `bs` with target sequence length `sl`.
-- Padding convention: token id 1 is used for padding (and as the start symbol in y).
function new_batch(bs, sl)
  local batch = {}
  batch.ids = torch.zeros(bs)                        -- example ids
  batch.x = torch.ones(opt.max_code_length, bs)      -- code token ids, padded with 1
  batch.mask = torch.zeros(opt.max_code_length, bs)  -- 1 only at each example's last code token
  batch.fmask = torch.zeros(opt.max_code_length, bs) -- 1 at every valid code token position
  batch.y = torch.ones(sl + 1, bs)                   -- NL token ids, shifted down one row (row 1 acts as the start symbol)
  batch.xsizes = torch.ones(bs)                      -- per-example code sizes (from the JSON's code_sizes field)
  batch.maxX = 0                                     -- longest (truncated) code sequence in the batch
  batch.maxY = 0                                     -- longest (truncated) NL sequence in the batch, plus one
  batch.code = {}                                    -- raw code strings
  return batch
end
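
-- Expected shape of one record in the input JSON, inferred from the field
-- accesses in get_data below (the values here are illustrative only):
--   { id = 1, code = "def f(): ...", code_sizes = 5,
--     code_num = {12, 7, 3, 9, 4},   -- code tokens mapped to vocab ids
--     nl_num = {22, 8, 15} }         -- natural-language tokens mapped to vocab ids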
-- Read a JSON file of examples and pack it into fixed-size batches.
-- Note: `vocab` and `dont_skip` are accepted but not used in this body;
-- judging from the call sites in main(), dont_skip marks test data.
function get_data(filename, vocab, bs, dont_skip)
  local dataFile = io.open(filename, 'r')
  local data = JSON:decode(dataFile:read())
  dataFile:close()
  local count = 0
  local dataset = {}
  dataset.size = #data
  dataset.batches = {}
  dataset.batch_size = bs
  local currBatch = nil
  for i = 1, #data do
    -- Start a new batch every `bs` examples, flushing the previous one.
    if count % bs == 0 then
      if currBatch ~= nil then
        table.insert(dataset.batches, currBatch)
      end
      currBatch = new_batch(bs, opt.max_nl_length)
      count = 0
    end
    count = count + 1
    currBatch.ids[count] = data[i].id
    currBatch.code[count] = data[i].code
    currBatch.xsizes[count] = data[i].code_sizes
    -- Truncate sequences to the configured maximum lengths.
    local apparentXSize = math.min(#data[i].code_num, opt.max_code_length)
    local apparentYSize = math.min(#data[i].nl_num, opt.max_nl_length)
    if apparentXSize > currBatch.maxX then
      currBatch.maxX = apparentXSize
    end
    if (apparentYSize + 1) > currBatch.maxY then
      currBatch.maxY = apparentYSize + 1
    end
    for j = 1, apparentXSize do
      currBatch.x[j][count] = data[i].code_num[j]
      currBatch.fmask[j][count] = 1
    end
    currBatch.mask[apparentXSize][count] = 1
    -- Targets are shifted by one row so row 1 stays the start symbol (id 1).
    for j = 1, apparentYSize do
      currBatch.y[j + 1][count] = data[i].nl_num[j]
    end
  end
  -- Flush the final (possibly partially filled) batch.
  if currBatch ~= nil then
    table.insert(dataset.batches, currBatch)
  end
  print('Total size = ' .. dataset.size)
  dataset.max_code_length = opt.max_code_length
  dataset.max_nl_length = opt.max_nl_length
  return dataset
end
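
-- Usage sketch (hypothetical; the 'foo' prefix and the narrowing are
-- illustrative, not part of this script): a consumer can drop padded rows
-- beyond the longest real sequence in each batch.
--   local ds = torch.load('foo.python.train.data')
--   for _, b in ipairs(ds.batches) do
--     local x = b.x:narrow(1, 1, b.maxX)  -- keep only rows with real code tokens
--   end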
function main()
  local cmd = torch.CmdLine()
  cmd:option('-max_nl_length', 100, 'maximum natural-language sequence length')
  cmd:option('-max_code_length', 100, 'maximum code sequence length')
  cmd:option('-batch_size', 100, 'batch size')
  cmd:option('-language', 'python', 'dataset language, e.g. python')
  cmd:option('-dataset', '', 'name of the file/dataset')
  cmd:text()
  opt = cmd:parse(arg)  -- intentionally global: new_batch and get_data read opt
  local working_dir = os.getenv("WORK_DIR") .. '/preprocessing'
  -- A comma-separated dataset list is joined with '_' to form the file prefix.
  local dataset = ""
  if string.match(opt.dataset, ",") then
    dataset = opt.dataset:gsub(",", "_")
  else
    dataset = opt.dataset
  end
  -- Convert the JSON vocabulary into Torch's native serialization format.
  local vocabFileName = working_dir .. "/" .. dataset .. '.' .. opt.language .. '.vocab.json'
  local vocabFile = io.open(vocabFileName, 'r')
  local vocab = JSON:decode(vocabFile:read())
  vocabFile:close()
  torch.save(working_dir .. '/' .. dataset .. '.' .. opt.language .. '.vocab.data', vocab)
  -- Batch and serialize the train, valid, and test splits.
  local trainFileName = working_dir .. '/' .. dataset .. '.' .. opt.language .. '.train'
  torch.save(trainFileName .. '.data', get_data(trainFileName .. '.json', vocab, opt.batch_size, false))
  local validFileName = working_dir .. '/' .. dataset .. '.' .. opt.language .. '.valid'
  torch.save(validFileName .. '.data', get_data(validFileName .. '.json', vocab, opt.batch_size, false))
  -- only test data has "dont_skip" set to true
  -- test data uses batch size = 1
  local testFileName = working_dir .. '/' .. dataset .. '.' .. opt.language .. '.test'
  torch.save(testFileName .. '.data', get_data(testFileName .. '.json', vocab, 1, true))
end

main()
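
-- Example invocation (assumes WORK_DIR is set and that
-- $WORK_DIR/preprocessing/<dataset>.<language>.{vocab,train,valid,test}.json
-- already exist; the dataset name below is a placeholder):
--   WORK_DIR=/path/to/work th buildData.lua -dataset mydata -language python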