Skip to content

Commit 5238b45

Browse files
committed
Merge pull request #15 from szagoruyko/master
imagenet classification demo
2 parents 872ece5 + bf81cba commit 5238b45

File tree

2 files changed

+1086
-0
lines changed

2 files changed

+1086
-0
lines changed
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
-- Imagenet classification with Torch7 demo
2+
require 'loadcaffe'
3+
require 'image'
4+
5+
-- Helper functions
6+
7+
-- Loads the mapping from net outputs to human readable labels
8+
function load_synset()
9+
local file = io.open 'synset_words.txt'
10+
local list = {}
11+
while true do
12+
local line = file:read()
13+
if not line then break end
14+
table.insert(list, string.sub(line,11))
15+
end
16+
return list
17+
end
18+
19+
20+
-- Converts an image from RGB to BGR format and subtracts mean
21+
function preprocess(im, img_mean)
22+
-- rescale the image
23+
local im3 = image.scale(im,227,227,'bilinear')*255
24+
-- RGB2BGR
25+
local im4 = im3:clone()
26+
im4[{1,{},{}}] = im3[{3,{},{}}]
27+
im4[{3,{},{}}] = im3[{1,{},{}}]
28+
29+
-- subtract imagenet mean
30+
return im4 - image.scale(img_mean, 227, 227, 'bilinear')
31+
end
32+
33+
34+
35+
-- Setting up networks and downloading stuff if needed
36+
proto_name = 'deploy.prototxt'
37+
model_name = 'bvlc_reference_caffenet.caffemodel'
38+
img_mean_name = 'ilsvrc_2012_mean.t7'
39+
image_name = 'Goldfish3.jpg'
40+
41+
prototxt_url = 'https://raw.githubusercontent.com/BVLC/caffe/master/models/bvlc_reference_caffenet/'..proto_name
42+
model_url = 'http://dl.caffe.berkeleyvision.org/'..model_name
43+
img_mean_url = 'https://www.dropbox.com/s/p33rheie3xjx6eu/'..img_mean_name
44+
image_url = 'http://upload.wikimedia.org/wikipedia/commons/e/e9/Goldfish3.jpg'
45+
46+
if not paths.filep(proto_name) then os.execute('wget '..prototxt_url) end
47+
if not paths.filep(model_name) then os.execute('wget '..model_url) end
48+
if not paths.filep(img_mean_name) then os.execute('wget '..img_mean_url) end
49+
if not paths.filep(image_name) then os.execute('wget '..image_url) end
50+
51+
52+
53+
print '==> Loading network'
54+
-- we'll use the fastest CUDA ConvNet implementation available, cuda-convnet2
55+
-- this loads the network in Caffe format and returns in Torch format, ready to use!
56+
net = loadcaffe.load(proto_name, model_name, 'ccn2')
57+
58+
-- as we want to classify, let's disable dropouts by enabling evaluation mode
59+
net:evaluate()
60+
61+
print '==> Loading synsets'
62+
synset_words = load_synset()
63+
64+
print '==> Loading image and imagenet mean'
65+
im = image.load(image_name)
66+
img_mean = torch.load(img_mean_name).img_mean:transpose(3,1)
67+
68+
print '==> Preprocessing'
69+
-- Have to resize and convert from RGB to BGR and subtract mean
70+
I = preprocess(im, img_mean)
71+
72+
-- cuda-convnet2 implementation support only batched routines, so
73+
-- we have to allocate memory for 32 inputs and then put crops to 10 of them.
74+
-- let's however use just one image for simplicity.
75+
-- note that for other networks that use cunn ore cudnn that might not be needed
76+
batch = torch.CudaTensor(32,3,227,227)
77+
batch[1]:copy(I)
78+
79+
print '==> Propagating through the network'
80+
net:forward(batch)
81+
82+
-- for the outputs of SoftMax layer sort them in decreasing order
83+
_,classes = net:get(25).output[{1,{}}]:float():sort(true)
84+
for i=1,5 do
85+
print('predicted class '..tostring(i)..': ', synset_words[classes[i]])
86+
end

0 commit comments

Comments
 (0)