Skip to content

Commit 4b71ed4

Browse files
committed
Move pretrained weights to Releases
1 parent 9b64d35 commit 4b71ed4

File tree

2 files changed

+33
-29
lines changed

2 files changed

+33
-29
lines changed

README.md

+6-7
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
3737
</tr>
3838
<tr>
3939
<td rowspan="2"><strong>This repo</strong></td>
40-
<td rowspan="2"><a href='https://drive.google.com/file/d/1Cgbl3Q_tHPFPyqfx2hx-9FZYBSbG5Rhy/view?usp=sharing'>Download</a></td>
40+
<td rowspan="2"><a href="https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff10k-20000.pth">Download</a></td>
4141
<td></td>
4242
<td><strong>65.8</td>
4343
<td><strong>45.7</strong></td>
@@ -57,7 +57,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
5757
</td>
5858
<td rowspan="2">164k <i>val</i></td>
5959
<td rowspan="2"><strong>This repo</strong></td>
60-
<td rowspan="2"><a href='https://drive.google.com/file/d/18kR928yl9Hz4xxuxnYgg7Hpi36hM8J2d/view?usp=sharing'>Download</a> &Dagger;</td>
60+
<td rowspan="2"><a href="https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth">Download</a> &Dagger;</td>
6161
<td></td>
6262
<td>66.8</td>
6363
<td>51.2</td>
@@ -112,7 +112,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
112112
</tr>
113113
<tr>
114114
<td rowspan="2"><strong>This repo</strong></td>
115-
<td rowspan="2"><a href='https://drive.google.com/file/d/1FaW2Sp7Jj3eaoyZtbabM1IWZnuScN-u6/view?usp=sharing'>Download</a></td>
115+
<td rowspan="2"><a href="https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-vocaug-20000.pth">Download</a></td>
116116
<td></td>
117117
<td>94.64</td>
118118
<td>86.50</td>
@@ -240,7 +240,7 @@ python demo.py single \
240240

241241
To run on a webcam:
242242

243-
```console
243+
```bash
244244
python demo.py live \
245245
--config-path configs/voc12.yaml \
246246
--model-path deeplabv2_resnet101_msc-vocaug-20000.pth
@@ -252,12 +252,11 @@ To run a CRF post-processing, add `--crf`. To run on a CPU, add `--cpu`.
252252

253253
### torch.hub
254254

255-
Model setup with 3 lines
255+
Model setup with two lines
256256

257257
```python
258258
import torch.hub
259-
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=182)
260-
model.load_state_dict(torch.load("deeplabv2_resnet101_msc-cocostuff164k-100000.pth"))
259+
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", pretrained='cocostuff164k', n_classes=182)
261260
```
262261

263262
### Difference with Caffe version

hubconf.py

+27-22
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,41 @@
77

88
from __future__ import print_function
99

10+
from torch.hub import load_state_dict_from_url
1011

11-
def deeplabv2_resnet101(pretrained=False, **kwargs):
12-
"""
13-
DeepLab v2 model with ResNet-101 backbone
14-
n_classes (int): the number of classes
15-
"""
12+
model_url_root = "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/"
13+
model_dict = {
14+
"cocostuff10k": ("deeplabv2_resnet101_msc-cocostuff10k-20000.pth", 182),
15+
"cocostuff164k": ("deeplabv2_resnet101_msc-cocostuff164k-100000.pth", 182),
16+
"voc12": ("deeplabv2_resnet101_msc-vocaug-20000.pth", 21),
17+
}
1618

17-
if pretrained:
18-
raise NotImplementedError(
19-
"Please download from "
20-
"https://github.com/kazuto1011/deeplab-pytorch/tree/master#performance"
21-
)
19+
20+
def deeplabv2_resnet101(pretrained=None, n_classes=182, scales=None):
2221

2322
from libs.models.deeplabv2 import DeepLabV2
2423
from libs.models.msc import MSC
2524

26-
base = DeepLabV2(n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24], **kwargs)
27-
model = MSC(base=base, scales=[0.5, 0.75])
25+
# Model parameters
26+
n_blocks = [3, 4, 23, 3]
27+
atrous_rates = [6, 12, 18, 24]
28+
if scales is None:
29+
scales = [0.5, 0.75]
2830

29-
return model
31+
base = DeepLabV2(n_classes=n_classes, n_blocks=n_blocks, atrous_rates=atrous_rates)
32+
model = MSC(base=base, scales=scales)
3033

34+
# Load pretrained models
35+
if isinstance(pretrained, str):
3136

32-
if __name__ == "__main__":
33-
import torch.hub
37+
assert pretrained in model_dict, list(model_dict.keys())
38+
expected = model_dict[pretrained][1]
39+
error_message = "Expected: n_classes={}".format(expected)
40+
assert n_classes == expected, error_message
3441

35-
model = torch.hub.load(
36-
"kazuto1011/deeplab-pytorch",
37-
"deeplabv2_resnet101",
38-
n_classes=182,
39-
force_reload=True,
40-
)
42+
model_url = model_url_root + model_dict[pretrained][0]
43+
state_dict = load_state_dict_from_url(model_url)
44+
model.load_state_dict(state_dict)
45+
46+
return model
4147

42-
print(model)

0 commit comments

Comments
 (0)