|
7 | 7 |
|
8 | 8 | from __future__ import print_function
|
9 | 9 |
|
| 10 | +from torch.hub import load_state_dict_from_url |
10 | 11 |
|
11 |
| -def deeplabv2_resnet101(pretrained=False, **kwargs): |
12 |
| - """ |
13 |
| - DeepLab v2 model with ResNet-101 backbone |
14 |
| - n_classes (int): the number of classes |
15 |
| - """ |
| 12 | +model_url_root = "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/" |
| 13 | +model_dict = { |
| 14 | + "cocostuff10k": ("deeplabv2_resnet101_msc-cocostuff10k-20000.pth", 182), |
| 15 | + "cocostuff164k": ("deeplabv2_resnet101_msc-cocostuff164k-100000.pth", 182), |
| 16 | + "voc12": ("deeplabv2_resnet101_msc-vocaug-20000.pth", 21), |
| 17 | +} |
16 | 18 |
|
17 |
| - if pretrained: |
18 |
| - raise NotImplementedError( |
19 |
| - "Please download from " |
20 |
| - "https://github.com/kazuto1011/deeplab-pytorch/tree/master#performance" |
21 |
| - ) |
| 19 | + |
| 20 | +def deeplabv2_resnet101(pretrained=None, n_classes=182, scales=None): |
22 | 21 |
|
23 | 22 | from libs.models.deeplabv2 import DeepLabV2
|
24 | 23 | from libs.models.msc import MSC
|
25 | 24 |
|
26 |
| - base = DeepLabV2(n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24], **kwargs) |
27 |
| - model = MSC(base=base, scales=[0.5, 0.75]) |
| 25 | + # Model parameters |
| 26 | + n_blocks = [3, 4, 23, 3] |
| 27 | + atrous_rates = [6, 12, 18, 24] |
| 28 | + if scales is None: |
| 29 | + scales = [0.5, 0.75] |
28 | 30 |
|
29 |
| - return model |
| 31 | + base = DeepLabV2(n_classes=n_classes, n_blocks=n_blocks, atrous_rates=atrous_rates) |
| 32 | + model = MSC(base=base, scales=scales) |
30 | 33 |
|
| 34 | + # Load pretrained models |
| 35 | + if isinstance(pretrained, str): |
31 | 36 |
|
32 |
| -if __name__ == "__main__": |
33 |
| - import torch.hub |
| 37 | + assert pretrained in model_dict, list(model_dict.keys()) |
| 38 | + expected = model_dict[pretrained][1] |
| 39 | + error_message = "Expected: n_classes={}".format(expected) |
| 40 | + assert n_classes == expected, error_message |
34 | 41 |
|
35 |
| - model = torch.hub.load( |
36 |
| - "kazuto1011/deeplab-pytorch", |
37 |
| - "deeplabv2_resnet101", |
38 |
| - n_classes=182, |
39 |
| - force_reload=True, |
40 |
| - ) |
| 42 | + model_url = model_url_root + model_dict[pretrained][0] |
| 43 | + state_dict = load_state_dict_from_url(model_url) |
| 44 | + model.load_state_dict(state_dict) |
| 45 | + |
| 46 | + return model |
41 | 47 |
|
42 |
| - print(model) |
|
0 commit comments