Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 833a384

Browse files
author
s9xie
committed
Initial commit
0 parents  commit 833a384

21 files changed

+1994
-0
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
gen/
2+
libnccl.so
3+
model_best.t7
4+
checkpoints

CONTRIBUTING.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Contributing to ResNeXt
2+
We want to make contributing to this project as easy and transparent as
3+
possible.
4+
5+
6+
## Pull Requests
7+
We actively welcome your pull requests.
8+
9+
1. Fork the repo and create your branch from `master`.
10+
2. If you haven't already, complete the Contributor License Agreement ("CLA").
11+
12+
## Contributor License Agreement ("CLA")
13+
In order to accept your pull request, we need you to submit a CLA. You only need
14+
to do this once to work on any of Facebook's open source projects.
15+
16+
Complete your CLA here: <https://code.facebook.com/cla>
17+
18+
## Issues
19+
We use GitHub issues to track public bugs. Please ensure your description is
20+
clear and has sufficient instructions to be able to reproduce the issue.
21+
22+
## Coding Style
23+
* 3 spaces for indentation rather than tabs
24+
* 80 character line length
25+
26+
## License
27+
By contributing to ResNeXt, you agree that your contributions will be licensed
28+
under its [BSD license](https://github.com/facebookresearch/ResNeXt/blob/master/LICENSE).

LICENSE

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
BSD License
2+
3+
For ResNeXt software
4+
5+
Copyright (c) 2017, Facebook, Inc. All rights reserved.
6+
7+
Redistribution and use in source and binary forms, with or without modification,
8+
are permitted provided that the following conditions are met:
9+
10+
* Redistributions of source code must retain the above copyright notice, this
11+
list of conditions and the following disclaimer.
12+
13+
* Redistributions in binary form must reproduce the above copyright notice,
14+
this list of conditions and the following disclaimer in the documentation
15+
and/or other materials provided with the distribution.
16+
17+
* Neither the name Facebook nor the names of its contributors may be used to
18+
endorse or promote products derived from this software without specific
19+
prior written permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

PATENTS

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
Additional Grant of Patent Rights Version 2
2+
3+
"Software" means the ResNeXt software distributed by Facebook, Inc.
4+
5+
Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
6+
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
7+
(subject to the termination provision below) license under any Necessary
8+
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
9+
transfer the Software. For avoidance of doubt, no license is granted under
10+
Facebook’s rights in any patent claims that are infringed by (i) modifications
11+
to the Software made by you or any third party or (ii) the Software in
12+
combination with any software or other technology.
13+
14+
The license granted hereunder will terminate, automatically and without notice,
15+
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
16+
directly or indirectly, or take a direct financial interest in, any Patent
17+
Assertion: (i) against Facebook or any of its subsidiaries or corporate
18+
affiliates, (ii) against any party if such Patent Assertion arises in whole or
19+
in part from any software, technology, product or service of Facebook or any of
20+
its subsidiaries or corporate affiliates, or (iii) against any party relating
21+
to the Software. Notwithstanding the foregoing, if Facebook or any of its
22+
subsidiaries or corporate affiliates files a lawsuit alleging patent
23+
infringement against you in the first instance, and you respond by filing a
24+
patent infringement counterclaim in that lawsuit against that party that is
25+
unrelated to the Software, the license granted hereunder will not terminate
26+
under section (i) of this paragraph due to such counterclaim.
27+
28+
A "Necessary Claim" is a claim of a patent owned by Facebook that is
29+
necessarily infringed by the Software standing alone.
30+
31+
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
32+
or contributory infringement or inducement to infringe any patent, including a
33+
cross-claim or counterclaim.

README.md

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Introduction
2+
This repository contains a [Torch](http://torch.ch) implementation for both the [ResNeXt](https://arxiv.org/abs/1611.05431) algorithm for image classification. The code is based on [fb.resnet.torch] (https://github.com/facebook/fb.resnet.torch).
3+
4+
[ResNeXt](https://arxiv.org/abs/1611.05431) is a simple, highly modularized network architecture for image classification. Our network is constructed by repeating a building block that aggregates a set of transformations with the same topology. Our simple design results in a homogeneous, multi-branch architecture that has only a few hyper-parameters to set. This strategy exposes a new dimension, which we call “cardinality” (the size of the set of transformations), as an essential factor in addition to the dimensions of depth and width.
5+
6+
7+
![teaser](http://vcl.ucsd.edu/~sxie/teaser.png)
8+
##### Figure: Training curves on ImageNet-1K. (Left): ResNet/ResNeXt-50 with the same complexity (~4.1 billion FLOPs, ~25 million parameters); (Right): ResNet/ResNeXt-101 with the same complexity (~7.8 billion FLOPs, ~44 million parameters).
9+
-----
10+
11+
If you use ResNeXt in your research, please cite the paper:
12+
```
13+
@article{Xie2016,
14+
title={Aggregated Residual Transformations for Deep Neural Networks},
15+
author={Saining Xie and Ross Girshick and Piotr Dollár and Zhuowen Tu and Kaiming He},
16+
journal={arXiv preprint arXiv:1611.05431},
17+
year={2016}
18+
}
19+
```
20+
21+
# Requirements and Dependencies
22+
See the fb.resnet.torch [installation instructions](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md) for a step-by-step guide.
23+
- Install [Torch](http://torch.ch/docs/getting-started.html) on a machine with CUDA GPU
24+
- Install [cuDNN v4 or v5](https://developer.nvidia.com/cudnn) and the Torch [cuDNN bindings](https://github.com/soumith/cudnn.torch/tree/R4)
25+
- Download the [ImageNet](http://image-net.org/download-images) dataset and [move validation images](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset) to labeled subfolders
26+
27+
## Training
28+
29+
Please follow [fb.resnet.torch] (https://github.com/facebook/fb.resnet.torch) for the general usage of the code, including [how](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) to use pretrained ResNeXt models for your own task.
30+
31+
There are two new hyperparameters need to be specified to determine the bottleneck template:
32+
33+
**-baseWidth** and **-cardinality**
34+
35+
###1x Complexity Configurations Reference Table
36+
| baseWidth | cardinality |
37+
|---------- | ----------- |
38+
| 64 | 1 |
39+
| 40 | 2 |
40+
| 24 | 4 |
41+
| 14 | 8 |
42+
| 4 | 32 |
43+
44+
45+
To train ResNeXt-50 (32x4d) on 8 GPUs for ImageNet:
46+
```bash
47+
th main.lua -dataset imagenet -bottleneckType resnext_C -depth 50 -baseWidth 4 -cardinality 32 -batchSize 256 -nGPU 8 -nThreads 8 -shareGradInput true -data [imagenet-folder]
48+
```
49+
50+
To reproduce CIFAR results (e.g. ResNeXt 16x64d for cifar10) on 8 GPUs:
51+
```bash
52+
th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29 -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 128 -nGPU 8 -nThreads 8 -shareGradInput true
53+
```
54+
To get comparable results using 2/4 GPUs, you should change the batch size and the corresponding learning rate:
55+
```bash
56+
th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29 -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 64 -nGPU 4 -LR 0.05 -nThreads 8 -shareGradInput true
57+
th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29 -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 32 -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true
58+
```
59+
Note: CIFAR datasets will be automatically downloaded and processed for the first time. We found that better CIFAR test acurracy can be achieved using a (on 8 GPUs) batch size of 128.
60+
61+
# ImageNet Pretrained Models
62+
ImageNet pretrained models are licensed under CC BY-NC 4.0.
63+
64+
[![CC BY-NC 4.0](https://i.creativecommons.org/l/by-nc/4.0/88x31.png)](https://creativecommons.org/licenses/by-nc/4.0/)
65+
66+
###Single-crop (224x224) validation error rate
67+
| Network | GFLOPS | Top-1 Error | Download |
68+
| ------------------- | ------ | ----------- | ------------|
69+
| ResNet-50 (1x64d) | ~4.1 | 23.9 | [Original ResNet-50](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) |
70+
| ResNeXt-50 (32x4d) | ~4.1 | 22.2 | [Download (191MB)](https://s3.amazonaws.com/resnext/imagenet_models/resnext_50_32x4d.t7) |
71+
| ResNet-101 (1x64d) | ~7.8 | 22.0 | [Original ResNet-101](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) |
72+
| ResNeXt-101 (32x4d) | ~7.8 | 21.2 | [Download (338MB)] (https://s3.amazonaws.com/resnext/imagenet_models/resnext_101_32x4d.t7) |
73+
| ResNeXt-101 (64x4d) | ~15.6 | 20.4 | [Download (638MB)](https://s3.amazonaws.com/resnext/imagenet_models/resnext_101_64x4d.t7) |
74+

checkpoints.lua

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
--
2+
-- Copyright (c) 2017, Facebook, Inc.
3+
-- All rights reserved.
4+
--
5+
-- This source code is licensed under the BSD-style license found in the
6+
-- LICENSE file in the root directory of this source tree. An additional grant
7+
-- of patent rights can be found in the PATENTS file in the same directory.
8+
--
9+
local checkpoint = {}
10+
11+
local function deepCopy(tbl)
12+
-- creates a copy of a network with new modules and the same tensors
13+
local copy = {}
14+
for k, v in pairs(tbl) do
15+
if type(v) == 'table' then
16+
copy[k] = deepCopy(v)
17+
else
18+
copy[k] = v
19+
end
20+
end
21+
if torch.typename(tbl) then
22+
torch.setmetatable(copy, torch.typename(tbl))
23+
end
24+
return copy
25+
end
26+
27+
function checkpoint.latest(opt)
28+
if opt.resume == 'none' then
29+
return nil
30+
end
31+
32+
local latestPath = paths.concat(opt.resume, 'latest.t7')
33+
if not paths.filep(latestPath) then
34+
return nil
35+
end
36+
37+
print('=> Loading checkpoint ' .. latestPath)
38+
local latest = torch.load(latestPath)
39+
local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))
40+
41+
return latest, optimState
42+
end
43+
44+
function checkpoint.save(epoch, model, optimState, isBestModel, opt)
45+
-- don't save the DataParallelTable for easier loading on other machines
46+
if torch.type(model) == 'nn.DataParallelTable' then
47+
model = model:get(1)
48+
end
49+
50+
-- create a clean copy on the CPU without modifying the original network
51+
model = deepCopy(model):float():clearState()
52+
53+
local modelFile = 'model_' .. epoch .. '.t7'
54+
local optimFile = 'optimState_' .. epoch .. '.t7'
55+
56+
torch.save(paths.concat(opt.save, modelFile), model)
57+
torch.save(paths.concat(opt.save, optimFile), optimState)
58+
torch.save(paths.concat(opt.save, 'latest.t7'), {
59+
epoch = epoch,
60+
modelFile = modelFile,
61+
optimFile = optimFile,
62+
})
63+
64+
if isBestModel then
65+
torch.save(paths.concat(opt.save, 'model_best.t7'), model)
66+
end
67+
end
68+
69+
return checkpoint

dataloader.lua

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
--
2+
-- Copyright (c) 2017, Facebook, Inc.
3+
-- All rights reserved.
4+
--
5+
-- This source code is licensed under the BSD-style license found in the
6+
-- LICENSE file in the root directory of this source tree. An additional grant
7+
-- of patent rights can be found in the PATENTS file in the same directory.
8+
--
9+
-- Multi-threaded data loader
10+
--
11+
12+
local datasets = require 'datasets/init'
13+
local Threads = require 'threads'
14+
Threads.serialization('threads.sharedserialize')
15+
16+
local M = {}
17+
local DataLoader = torch.class('resnet.DataLoader', M)
18+
19+
function DataLoader.create(opt)
20+
-- The train and val loader
21+
local loaders = {}
22+
23+
for i, split in ipairs{'train', 'val'} do
24+
local dataset = datasets.create(opt, split)
25+
loaders[i] = M.DataLoader(dataset, opt, split)
26+
end
27+
28+
return table.unpack(loaders)
29+
end
30+
31+
function DataLoader:__init(dataset, opt, split)
32+
local manualSeed = opt.manualSeed
33+
local function init()
34+
require('datasets/' .. opt.dataset)
35+
end
36+
local function main(idx)
37+
if manualSeed ~= 0 then
38+
torch.manualSeed(manualSeed + idx)
39+
end
40+
torch.setnumthreads(1)
41+
_G.dataset = dataset
42+
_G.preprocess = dataset:preprocess()
43+
return dataset:size()
44+
end
45+
46+
local threads, sizes = Threads(opt.nThreads, init, main)
47+
self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
48+
self.threads = threads
49+
self.__size = sizes[1][1]
50+
self.batchSize = math.floor(opt.batchSize / self.nCrops)
51+
local function getCPUType(tensorType)
52+
if tensorType == 'torch.CudaHalfTensor' then
53+
return 'HalfTensor'
54+
elseif tensorType == 'torch.CudaDoubleTensor' then
55+
return 'DoubleTensor'
56+
else
57+
return 'FloatTensor'
58+
end
59+
end
60+
self.cpuType = getCPUType(opt.tensorType)
61+
end
62+
63+
function DataLoader:size()
64+
return math.ceil(self.__size / self.batchSize)
65+
end
66+
67+
function DataLoader:run()
68+
local threads = self.threads
69+
local size, batchSize = self.__size, self.batchSize
70+
local perm = torch.randperm(size)
71+
72+
local idx, sample = 1, nil
73+
local function enqueue()
74+
while idx <= size and threads:acceptsjob() do
75+
local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
76+
threads:addjob(
77+
function(indices, nCrops, cpuType)
78+
local sz = indices:size(1)
79+
local batch, imageSize
80+
local target = torch.IntTensor(sz)
81+
for i, idx in ipairs(indices:totable()) do
82+
local sample = _G.dataset:get(idx)
83+
local input = _G.preprocess(sample.input)
84+
if not batch then
85+
imageSize = input:size():totable()
86+
if nCrops > 1 then table.remove(imageSize, 1) end
87+
batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
88+
end
89+
batch[i]:copy(input)
90+
target[i] = sample.target
91+
end
92+
collectgarbage()
93+
return {
94+
input = batch:view(sz * nCrops, table.unpack(imageSize)),
95+
target = target,
96+
}
97+
end,
98+
function(_sample_)
99+
sample = _sample_
100+
end,
101+
indices,
102+
self.nCrops,
103+
self.cpuType
104+
)
105+
idx = idx + batchSize
106+
end
107+
end
108+
109+
local n = 0
110+
local function loop()
111+
enqueue()
112+
if not threads:hasjob() then
113+
return nil
114+
end
115+
threads:dojob()
116+
if threads:haserror() then
117+
threads:synchronize()
118+
end
119+
enqueue()
120+
n = n + 1
121+
return n, sample
122+
end
123+
124+
return loop
125+
end
126+
127+
return M.DataLoader

0 commit comments

Comments
 (0)