-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcsv2torch-datasets.lua
executable file
·88 lines (66 loc) · 1.9 KB
/
csv2torch-datasets.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env torch
require 'torch'
require 'torch-env'
require 'dataset'
require 'dataset/TableDataset'
require 'Csv'
require 'util'
require 'util/arg'
require 'sys'
--- Parse the command line into an options table.
-- Declares the `-csv`, `-out` and `-N` options on a torch.CmdLine parser and
-- returns the table produced by `cmd:parse`.
-- @param arg list of command-line arguments handed to `cmd:parse`
-- @param initNLparams unused; kept so existing call sites stay valid
-- @return the parsed options table (fields read elsewhere in this file:
--   `csv`, `out`, `N`)
-- Raises an error when `-csv` or `-out` is left at its empty-string default.
local function parse_arg(arg, initNLparams)
  local cmd = torch.CmdLine()
  cmd:text('Options:')
  cmd:option('-csv', '', 'input csv to convert')
  cmd:option('-out', '', 'output torch-dataset filename')
  cmd:option('-N', 20000, 'size of block to read at once')

  local options = cmd:parse(arg)

  -- Both paths default to ''. Fail here with a clear message instead of
  -- letting an empty path reach the CSV reader or torch.save, where the
  -- failure would surface later (for -out, only after the data is loaded).
  if options.csv == '' then
    error('missing required option -csv (input csv to convert)')
  end
  if options.out == '' then
    error('missing required option -out (output torch-dataset filename)')
  end

  return options
end
--- Open the input CSV named by the parsed options.
-- @param options table whose `csv` field is passed to `Csv` as the path
-- @return the single value produced by `Csv(options.csv, "r")`
local function get_csv(options)
  local path = options.csv
  -- Bind to a local first so exactly one value is returned.
  local reader = Csv(path, "r")
  return reader
end
--- Write `data` to the output file named by the options.
-- @param data value handed to `torch.save`
-- @param options table whose `out` field is the destination filename
local function save(data, options)
  local destination = options.out
  torch.save(destination, data)
end
--- Read up to `options.N` rows from `csv` into a tensor.
-- Each row is obtained with `csv:read()`; its first `options.nfeatures`
-- fields are converted with `tonumber` and stored in one tensor row.
-- @param csv reader object; `csv:read()` must return a row (indexable by
--   column number) or a false value once the input is exhausted
-- @param options table providing `N` (row capacity of the buffer) and
--   `nfeatures` (number of columns to copy per row)
-- @return the `options.N x options.nfeatures` buffer narrowed along
--   dimension 1 to the rows actually read
-- Raises an error if the input holds more than `options.N` rows.
local function get_N_samples(csv, options)
  local capacity, nfeatures = options.N, options.nfeatures
  local data = torch.Tensor(capacity, nfeatures)

  local count = 0
  while count < capacity do
    local line = csv:read()
    if not line then break end
    count = count + 1
    local row = data[count]
    for j = 1, nfeatures do
      row[j] = tonumber(line[j])
    end
  end

  -- Only a full buffer can hide unread rows. Probe for one more row rather
  -- than inspecting the last row read: the previous check did the latter and
  -- therefore rejected inputs holding exactly `capacity` rows.
  if count == capacity then
    assert(not csv:read(), 'Skipped all but the first '..options.N..' samples! Increase buffer size to fit all your data, using the -N command line option.')
  end

  return data:narrow(1, 1, count)
end
--- Build a TableDataset from an opened CSV reader.
-- The first row is treated as the header: its length is stored in
-- `options.nfeatures` (this mutates `options`) and the remaining rows are
-- loaded via `get_N_samples`. When the first header field is 'label', the
-- first column is split off as `class` and the other columns become `data`.
-- Otherwise every column is `data` and `class` is nil.
-- @param csv reader object exposing `csv:read()`
-- @param options parsed options table; must provide `N`, gains `nfeatures`
-- @return the result of `dataset.TableDataset({data=..., class=...})`
-- Raises an error if no header row can be read.
local function get_dataset(csv, options)
  -- get header
  local header = csv:read()
  -- An empty input yields no header; report that instead of failing on `#nil`.
  assert(header, 'Empty csv: could not read a header line.')
  options.nfeatures = #header

  local data = get_N_samples(csv, options)
  local class

  -- separate labels from data
  if header[1] == 'label' then
    class = data:narrow(2, 1, 1):clone()
    data = data:narrow(2, 2, data:size(2) - 1)
  end

  -- Clone so the stored tensor does not share the read buffer.
  data = data:clone()

  return dataset.TableDataset({data=data, class=class})
end
--- Script entry point: parse options, load the CSV into a dataset, save it.
local function main()
  local options = parse_arg(arg, true)
  local csv = get_csv(options)
  -- `local` keeps the dataset out of the global environment; it was
  -- previously assigned to an undeclared (global) name.
  local dset = get_dataset(csv, options)
  save(dset, options)
end

main()