Skip to content

Commit dad4b05

Browse files
authored
[feat] Adding UPerNet (#926)
* add UPerNet * update paper link * update tests add UPerNet for test_models * update readme and doc * rename variable * fix format and lint checks * update UPerNet decoder Resize all FPN output features to 1/4 of the original resolution. * update UPerNet decoder 1. Use `SegmentationHead` for upsampling, set `upsampling=4` 2. Remove the additional variable `out_channels` from `UPerNetDecoder` 3. Fix `SegmentationHead` kernel size to 1
1 parent d989faa commit dad4b05

File tree

7 files changed

+249
-2
lines changed

7 files changed

+249
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
1919
The main features of this library are:
2020

2121
- High-level API (just two lines to create a neural network)
22-
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
22+
- 10 models architectures for binary and multi class segmentation (including legendary Unet)
2323
- 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
2424
- All encoders have pre-trained weights for faster and better convergence
2525
- Popular metrics and losses for training routines
@@ -94,6 +94,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
9494
- PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
9595
- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
9696
- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
97+
- UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
9798

9899
#### Encoders <a name="encoders"></a>
99100

docs/models.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,10 @@ MAnet
6666
PAN
6767
~~~
6868
.. autoclass:: segmentation_models_pytorch.PAN
69+
70+
71+
.. _upernet:
72+
73+
UPerNet
74+
~~~
75+
.. autoclass:: segmentation_models_pytorch.UPerNet

segmentation_models_pytorch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .decoders.pspnet import PSPNet
1515
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
1616
from .decoders.pan import PAN
17+
from .decoders.upernet import UPerNet
1718
from .base.hub_mixin import from_pretrained
1819

1920
from .__version__ import __version__
@@ -48,6 +49,7 @@ def create_model(
4849
DeepLabV3,
4950
DeepLabV3Plus,
5051
PAN,
52+
UPerNet,
5153
]
5254
archs_dict = {a.__name__.lower(): a for a in archs}
5355
try:
@@ -82,6 +84,7 @@ def create_model(
8284
"DeepLabV3",
8385
"DeepLabV3Plus",
8486
"PAN",
87+
"UPerNet",
8588
"from_pretrained",
8689
"create_model",
8790
"__version__",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .model import UPerNet
2+
3+
__all__ = ["UPerNet"]
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
from segmentation_models_pytorch.base import modules as md
6+
7+
8+
class PSPModule(nn.Module):
9+
def __init__(
10+
self,
11+
in_channels,
12+
out_channels,
13+
sizes=(1, 2, 3, 6),
14+
use_batchnorm=True,
15+
):
16+
super().__init__()
17+
self.blocks = nn.ModuleList(
18+
[
19+
nn.Sequential(
20+
nn.AdaptiveAvgPool2d(size),
21+
md.Conv2dReLU(
22+
in_channels,
23+
in_channels // len(sizes),
24+
kernel_size=1,
25+
use_batchnorm=use_batchnorm,
26+
),
27+
)
28+
for size in sizes
29+
]
30+
)
31+
self.out_conv = md.Conv2dReLU(
32+
in_channels=in_channels * 2,
33+
out_channels=out_channels,
34+
kernel_size=1,
35+
use_batchnorm=True,
36+
)
37+
38+
def forward(self, x):
39+
_, _, height, weight = x.shape
40+
out = [x] + [
41+
F.interpolate(
42+
block(x), size=(height, weight), mode="bilinear", align_corners=False
43+
)
44+
for block in self.blocks
45+
]
46+
out = self.out_conv(torch.cat(out, dim=1))
47+
return out
48+
49+
50+
class FPNBlock(nn.Module):
51+
def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True):
52+
super().__init__()
53+
self.skip_conv = (
54+
md.Conv2dReLU(
55+
skip_channels,
56+
pyramid_channels,
57+
kernel_size=1,
58+
use_batchnorm=use_bathcnorm,
59+
)
60+
if skip_channels != 0
61+
else nn.Identity()
62+
)
63+
64+
def forward(self, x, skip):
65+
_, channels, height, weight = skip.shape
66+
x = F.interpolate(
67+
x, size=(height, weight), mode="bilinear", align_corners=False
68+
)
69+
if channels != 0:
70+
skip = self.skip_conv(skip)
71+
x = x + skip
72+
return x
73+
74+
75+
class UPerNetDecoder(nn.Module):
76+
def __init__(
77+
self,
78+
encoder_channels,
79+
encoder_depth=5,
80+
pyramid_channels=256,
81+
segmentation_channels=64,
82+
):
83+
super().__init__()
84+
85+
if encoder_depth < 3:
86+
raise ValueError(
87+
"Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format(
88+
encoder_depth
89+
)
90+
)
91+
92+
encoder_channels = encoder_channels[::-1]
93+
94+
# PSP Module
95+
self.psp = PSPModule(
96+
in_channels=encoder_channels[0],
97+
out_channels=pyramid_channels,
98+
sizes=(1, 2, 3, 6),
99+
use_batchnorm=True,
100+
)
101+
102+
# FPN Module
103+
self.fpn_stages = nn.ModuleList(
104+
[FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]]
105+
)
106+
107+
self.fpn_bottleneck = md.Conv2dReLU(
108+
in_channels=(len(encoder_channels) - 1) * pyramid_channels,
109+
out_channels=segmentation_channels,
110+
kernel_size=3,
111+
padding=1,
112+
use_batchnorm=True,
113+
)
114+
115+
def forward(self, *features):
116+
output_size = features[0].shape[2:]
117+
target_size = [size // 4 for size in output_size]
118+
119+
features = features[1:] # remove first skip with same spatial resolution
120+
features = features[::-1] # reverse channels to start from head of encoder
121+
122+
psp_out = self.psp(features[0])
123+
124+
fpn_features = [psp_out]
125+
for feature, stage in zip(features[1:], self.fpn_stages):
126+
fpn_feature = stage(fpn_features[-1], feature)
127+
fpn_features.append(fpn_feature)
128+
129+
# Resize all FPN features to 1/4 of the original resolution.
130+
resized_fpn_features = []
131+
for feature in fpn_features:
132+
resized_feature = F.interpolate(
133+
feature, size=target_size, mode="bilinear", align_corners=False
134+
)
135+
resized_fpn_features.append(resized_feature)
136+
137+
output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))
138+
139+
return output
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from typing import Optional, Union
2+
3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import (
5+
SegmentationModel,
6+
SegmentationHead,
7+
ClassificationHead,
8+
)
9+
from .decoder import UPerNetDecoder
10+
11+
12+
class UPerNet(SegmentationModel):
13+
"""UPerNet is a unified perceptual parsing network for image segmentation.
14+
15+
Args:
16+
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
17+
to extract features of different spatial resolution
18+
encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
19+
two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
20+
with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
21+
Default is 5
22+
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
23+
other pretrained weights (see table with available weights for each encoder_name)
24+
decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256
25+
decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64
26+
in_channels: A number of input channels for the model, default is 3 (RGB images)
27+
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
28+
activation: An activation function to apply after the final convolution layer.
29+
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
30+
**callable** and **None**.
31+
Default is **None**
32+
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
33+
on top of encoder if **aux_params** is not **None** (default). Supported params:
34+
- classes (int): A number of classes
35+
- pooling (str): One of "max", "avg". Default is "avg"
36+
- dropout (float): Dropout factor in [0, 1)
37+
- activation (str): An activation function to apply "sigmoid"/"softmax"
38+
(could be **None** to return logits)
39+
40+
Returns:
41+
``torch.nn.Module``: **UPerNet**
42+
43+
.. _UPerNet:
44+
https://arxiv.org/abs/1807.10221
45+
46+
"""
47+
48+
def __init__(
49+
self,
50+
encoder_name: str = "resnet34",
51+
encoder_depth: int = 5,
52+
encoder_weights: Optional[str] = "imagenet",
53+
decoder_pyramid_channels: int = 256,
54+
decoder_segmentation_channels: int = 64,
55+
in_channels: int = 3,
56+
classes: int = 1,
57+
activation: Optional[Union[str, callable]] = None,
58+
aux_params: Optional[dict] = None,
59+
):
60+
super().__init__()
61+
62+
self.encoder = get_encoder(
63+
encoder_name,
64+
in_channels=in_channels,
65+
depth=encoder_depth,
66+
weights=encoder_weights,
67+
)
68+
69+
self.decoder = UPerNetDecoder(
70+
encoder_channels=self.encoder.out_channels,
71+
encoder_depth=encoder_depth,
72+
pyramid_channels=decoder_pyramid_channels,
73+
segmentation_channels=decoder_segmentation_channels,
74+
)
75+
76+
self.segmentation_head = SegmentationHead(
77+
in_channels=decoder_segmentation_channels,
78+
out_channels=classes,
79+
activation=activation,
80+
kernel_size=1,
81+
upsampling=4,
82+
)
83+
84+
if aux_params is not None:
85+
self.classification_head = ClassificationHead(
86+
in_channels=self.encoder.out_channels[-1], **aux_params
87+
)
88+
else:
89+
self.classification_head = None
90+
91+
self.name = "upernet-{}".format(encoder_name)
92+
self.initialize()

tests/test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def get_sample(model_class):
2929
smp.PSPNet,
3030
smp.UnetPlusPlus,
3131
smp.MAnet,
32+
smp.UPerNet,
3233
]:
3334
sample = torch.ones([1, 3, 64, 64])
3435
elif model_class == smp.PAN:
@@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False):
5758
@pytest.mark.parametrize("encoder_name", ENCODERS)
5859
@pytest.mark.parametrize("encoder_depth", [3, 5])
5960
@pytest.mark.parametrize(
60-
"model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
61+
"model_class",
62+
[smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
6163
)
6264
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
6365
if (

0 commit comments

Comments
 (0)