Skip to content

Add StochasticDepth implementation #4301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ torchvision.ops
.. autofunction:: ps_roi_pool
.. autofunction:: deform_conv2d
.. autofunction:: sigmoid_focal_loss
.. autofunction:: stochastic_depth

.. autoclass:: RoIAlign
.. autoclass:: PSRoIAlign
Expand All @@ -31,3 +32,4 @@ torchvision.ops
.. autoclass:: DeformConv2d
.. autoclass:: MultiScaleRoIAlign
.. autoclass:: FeaturePyramidNetwork
.. autoclass:: StochasticDepth
28 changes: 28 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,5 +1000,33 @@ def gen_iou_check(box, expected, tolerance=1e-4):
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)


class TestStochasticDepth:
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
@pytest.mark.parametrize('mode', ["batch", "row"])
def test_stochastic_depth(self, mode, p):
stats = pytest.importorskip("scipy.stats")
batch_size = 5
x = torch.ones(size=(batch_size, 3, 4, 4))
layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype)
layer.__repr__()

trials = 250
num_samples = 0
counts = 0
for _ in range(trials):
out = layer(x)
non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
if mode == "batch":
if non_zero_count == 0:
counts += 1
num_samples += 1
elif mode == "row":
counts += batch_size - non_zero_count
num_samples += batch_size

p_value = stats.binom_test(counts, num_samples, p=p)
assert p_value > 0.0001


if __name__ == '__main__':
pytest.main([__file__])
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .stochastic_depth import stochastic_depth, StochasticDepth

from ._register_onnx_ops import _register_custom_op

Expand All @@ -20,5 +21,5 @@
'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork',
'sigmoid_focal_loss'
'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth'
]
56 changes: 56 additions & 0 deletions torchvision/ops/stochastic_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
from torch import nn, Tensor


def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor:
"""
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
branches of residual architectures.

Args:
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
being its batch i.e. a batch with ``N`` rows.
p (float): probability of the input to be zeroed.
mode (str): ``"batch"`` or ``"row"``.
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
randomly selected rows from the batch.
training: apply stochastic depth if is ``True``. Default: ``True``

Returns:
Tensor[N, ...]: The randomly zeroed tensor.
"""
if p < 0.0 or p > 1.0:
raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p))
if not training or p == 0.0:
return input

survival_rate = 1.0 - p
if mode not in ["batch", "row"]:
raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode))
size = [1] * input.ndim
if mode == "row":
size[0] = input.shape[0]
noise = torch.empty(size, dtype=input.dtype, device=input.device)
noise = noise.bernoulli_(survival_rate).div_(survival_rate)
return input * noise


class StochasticDepth(nn.Module):
"""
See :func:`stochastic_depth`.
"""
def __init__(self, p: float, mode: str) -> None:
super().__init__()
self.p = p
self.mode = mode

def forward(self, input: Tensor) -> Tensor:
return stochastic_depth(input, self.p, self.mode, self.training)

def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'p=' + str(self.p)
tmpstr += ', mode=' + str(self.mode)
tmpstr += ')'
return tmpstr