Commit b77cf5a

Add DeepLabV3 implementation (#149)

* Add DeepLabV3 implementation
* Add DeepLabV3 to README
* Add DeepLabV3 docstring

1 parent b74d91f commit b77cf5a
File tree

7 files changed: +231 -13 lines changed

README.md (+1)

```diff
@@ -68,6 +68,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 - [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
 - [PSPNet](https://arxiv.org/abs/1612.01105)
 - [PAN](https://arxiv.org/abs/1805.10180)
+- [DeepLabV3](https://arxiv.org/abs/1706.05587)
 
 
 #### Encoders <a name="encoders"></a>
```
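With this entry, DeepLabV3 sits alongside the other architectures exported from the package root. A minimal usage sketch (the parameter values are illustrative defaults taken from the docstring added in model.py below):

```python
import segmentation_models_pytorch as smp

# Build a DeepLabV3 model with a ResNet-34 encoder (the constructor's default)
# and ImageNet-pretrained encoder weights.
model = smp.DeepLabV3(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=2,           # number of output mask channels
    activation=None,     # return raw logits; "sigmoid"/"softmax2d" also accepted
)
```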

segmentation_models_pytorch/__init__.py (+1)

```diff
@@ -2,6 +2,7 @@
 from .linknet import Linknet
 from .fpn import FPN
 from .pspnet import PSPNet
+from .deeplabv3 import DeepLabV3
 from .pan import PAN
 
 from . import encoders
```
segmentation_models_pytorch/deeplabv3/__init__.py (new file, +1)

```python
from .model import DeepLabV3
```
segmentation_models_pytorch/deeplabv3/decoder.py (new file, +107)

```python
"""
BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ["DeepLabV3Decoder"]


class DeepLabV3Decoder(nn.Sequential):
    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
        super().__init__(
            ASPP(in_channels, out_channels, atrous_rates),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.out_channels = out_channels

    def forward(self, *features):
        return super().forward(features[-1])


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
```
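The decoder consumes only the deepest encoder feature map and preserves its spatial resolution: the five ASPP branches (one 1x1 conv, three dilated 3x3 convs, one global pooling branch) are concatenated to `5 * out_channels` channels, projected back to `out_channels`, and refined by a stride-1 3x3 conv. A quick shape check, as a sketch with an illustrative 512-channel input:

```python
import torch

# Hypothetical input: a batch of 2 feature maps with 512 channels at 16x16,
# e.g. the last stage of a dilated ResNet on a 128x128 image (output stride 8).
features = torch.rand(2, 512, 16, 16)

decoder = DeepLabV3Decoder(in_channels=512, out_channels=256)
out = decoder(features)   # forward(*features) uses only the last feature map
print(out.shape)          # -> torch.Size([2, 256, 16, 16]): channels reduced, spatial size kept
```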
segmentation_models_pytorch/deeplabv3/model.py (new file, +81)

```python
import torch.nn as nn

from typing import Optional
from .decoder import DeepLabV3Decoder
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
from ..encoders import get_encoder


class DeepLabV3(SegmentationModel):
    """DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation"

    Args:
        encoder_name: name of classification model (without last dense layers) used as feature
            extractor to build segmentation model.
        encoder_depth: number of encoder stages used; a larger depth generates more features.
            E.g. for depth=3 the encoder generates a list of features with the following spatial
            shapes [(H, W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature
            has spatial resolution (H/(2^depth), W/(2^depth)).
        encoder_weights: one of ``None`` (random initialization) or ``imagenet`` (pre-training on ImageNet).
        decoder_channels: number of convolution filters in the ASPP module (default 256).
        in_channels: number of input channels for the model, default is 3.
        classes: number of output classes (output shape - ``(batch, classes, h, w)``).
        activation (str, callable): activation function used in the ``.predict(x)`` method for inference.
            One of [``sigmoid``, ``softmax2d``, callable, None].
        upsampling: optional, final upsampling factor
            (default is 8 to preserve input -> output spatial shape identity).
        aux_params: if specified, the model has an additional classification auxiliary output
            built on top of the encoder; supported params:
                - classes (int): number of classes
                - pooling (str): one of 'max', 'avg'. Default is 'avg'.
                - dropout (float): dropout factor in [0, 1)
                - activation (str): activation function to apply, "sigmoid"/"softmax"
                  (could be None to return logits)

    Returns:
        ``torch.nn.Module``: **DeepLabV3**

    .. _DeepLabV3:
        https://arxiv.org/abs/1706.05587
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 8,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        self.encoder.make_dilated(
            stage_list=[4, 5],
            dilation_list=[2, 4]
        )

        self.decoder = DeepLabV3Decoder(
            in_channels=self.encoder.out_channels[-1],
            out_channels=decoder_channels,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None
```
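Because ``make_dilated`` replaces the striding of the last two encoder stages with dilation, the deepest feature map sits at stride 8, and the segmentation head's default ``upsampling=8`` restores the input resolution. An end-to-end sketch (the 2x3x128x128 sample matches the size the new test helper uses for this model):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.DeepLabV3("resnet34", encoder_weights=None, classes=3,
                      aux_params=dict(classes=5))  # adds the auxiliary classifier head
model.eval()

x = torch.rand(2, 3, 128, 128)
with torch.no_grad():
    mask, label = model(x)   # with aux_params set, the model returns (mask, label)
print(mask.shape)            # -> torch.Size([2, 3, 128, 128]): input spatial shape preserved
print(label.shape)           # -> torch.Size([2, 5])
```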

segmentation_models_pytorch/encoders/_base.py (+2 -1)

```diff
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from typing import List
+from collections import OrderedDict
 
 from . import _utils as utils
 
@@ -12,7 +13,7 @@ class EncoderMixin:
     """
 
     @property
-    def out_channels(self) -> List:
+    def out_channels(self):
         """Return channels dimensions for each tensor of forward output of encoder"""
         return self._out_channels[: self._depth + 1]
```
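The ``out_channels`` property is what the new model reads via ``self.encoder.out_channels[-1]``: one channel count per encoder stage, truncated to ``depth + 1`` entries. A sketch of what it returns (the resnet34 channel list shown is illustrative):

```python
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
print(encoder.out_channels)      # e.g. (3, 64, 64, 128, 256, 512) - depth+1 entries
print(encoder.out_channels[-1])  # 512: the ASPP input channels for DeepLabV3
```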

tests/test_models.py (+38 -12)

```diff
@@ -27,18 +27,32 @@ def get_encoders():
 
 ENCODERS = get_encoders()
 DEFAULT_ENCODER = "resnet18"
-DEFAULT_SAMPLE = torch.ones([1, 3, 64, 64])
-DEFAULT_PAN_SAMPLE = torch.ones([2, 3, 256, 256])
 
 
-def _test_forward(model):
+def get_sample(model_class):
+    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet]:
+        sample = torch.ones([1, 3, 64, 64])
+    elif model_class == smp.PAN:
+        sample = torch.ones([2, 3, 256, 256])
+    elif model_class == smp.DeepLabV3:
+        sample = torch.ones([2, 3, 128, 128])
+    else:
+        raise ValueError("Not supported model class {}".format(model_class))
+    return sample
+
+
+def _test_forward(model, sample, test_shape=False):
     with torch.no_grad():
-        model(DEFAULT_SAMPLE)
+        out = model(sample)
+        if test_shape:
+            assert out.shape[2:] == sample.shape[2:]
 
 
-def _test_forward_backward(model, sample):
+def _test_forward_backward(model, sample, test_shape=False):
     out = model(sample)
     out.mean().backward()
+    if test_shape:
+        assert out.shape[2:] == sample.shape[2:]
 
 
 @pytest.mark.parametrize("encoder_name", ENCODERS)
@@ -50,12 +64,22 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
     model = model_class(
         encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
     )
-    _test_forward(model)
+    sample = get_sample(model_class)
 
+    if encoder_depth == 5 and model_class != smp.PSPNet:
+        test_shape = True
+    else:
+        test_shape = False
 
-@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
+    _test_forward(model, sample, test_shape)
+
+
+@pytest.mark.parametrize(
+    "model_class",
+    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.DeepLabV3]
+)
 def test_forward_backward(model_class):
-    sample = DEFAULT_PAN_SAMPLE if model_class is smp.PAN else DEFAULT_SAMPLE
+    sample = get_sample(model_class)
     model = model_class(DEFAULT_ENCODER, encoder_weights=None)
     _test_forward_backward(model, sample)
 
@@ -65,8 +89,8 @@ def test_aux_output(model_class):
     model = model_class(
         DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
     )
-    sample = DEFAULT_PAN_SAMPLE if model_class is smp.PAN else DEFAULT_SAMPLE
-    label_size = (2, 2) if model_class is smp.PAN else (1, 2)
+    sample = get_sample(model_class)
+    label_size = (sample.shape[0], 2)
     mask, label = model(sample)
     assert label.size() == label_size
 
@@ -76,7 +100,8 @@ def test_upsample(model_class, upsampling):
     default_upsampling = 4 if model_class is smp.FPN else 8
     model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling)
-    mask = model(DEFAULT_SAMPLE)
+    sample = get_sample(model_class)
+    mask = model(sample)
     assert mask.size()[-1] / 64 == upsampling / default_upsampling
 
@@ -106,7 +131,8 @@ def test_dilation(encoder_name):
 
     encoder.eval()
     with torch.no_grad():
-        output = encoder(DEFAULT_SAMPLE)
+        sample = torch.ones([1, 3, 64, 64])
+        output = encoder(sample)
 
     shapes = [out.shape[-1] for out in output]
     assert shapes == [64, 32, 16, 8, 4, 4]  # last downsampling replaced with dilation
```
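The new ``get_sample`` helper replaces the two hard-coded sample constants and centralizes per-architecture input sizes; DeepLabV3 gets a 2x3x128x128 batch, whose side is divisible by the model's output stride of 8. The ``test_shape`` flag then asserts the input -> output spatial identity promised by the default upsampling factor. A sketch of the round trip these tests exercise for the new model:

```python
import torch
import segmentation_models_pytorch as smp

model = smp.DeepLabV3("resnet18", encoder_weights=None)
sample = torch.ones([2, 3, 128, 128])     # the size get_sample() returns for DeepLabV3

out = model(sample)
assert out.shape[2:] == sample.shape[2:]  # spatial shape preserved (stride 8 * upsampling 8)
out.mean().backward()                     # same forward/backward pass as _test_forward_backward
```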
