Skip to content

Commit fe28b14

Browse files
datumboxfmassa
authored andcommitted
[fbsync] Add StochasticDepth implementation (#4301)
Summary: * Adding operator. * Adding tests * switching order of `p` and `mode`. * Remove seed setting. * Replace stats import with pytest.importorskip. * Fix doc * Apply suggestions from code review * Fixing indentation. * Adding operator in the documentation. * Fixing lint Reviewed By: fmassa Differential Revision: D30525891 fbshipit-source-id: da8300c8428efa8f74d79ae06c19ea2e040c88c9 Co-authored-by: Francisco Massa <[email protected]> Co-authored-by: Francisco Massa <[email protected]>
1 parent 5b524cd commit fe28b14

File tree

4 files changed

+88
-1
lines changed

4 files changed

+88
-1
lines changed

docs/source/ops.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ torchvision.ops
2323
.. autofunction:: ps_roi_pool
2424
.. autofunction:: deform_conv2d
2525
.. autofunction:: sigmoid_focal_loss
26+
.. autofunction:: stochastic_depth
2627

2728
.. autoclass:: RoIAlign
2829
.. autoclass:: PSRoIAlign
@@ -31,3 +32,4 @@ torchvision.ops
3132
.. autoclass:: DeformConv2d
3233
.. autoclass:: MultiScaleRoIAlign
3334
.. autoclass:: FeaturePyramidNetwork
35+
.. autoclass:: StochasticDepth

test/test_ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,5 +1000,33 @@ def gen_iou_check(box, expected, tolerance=1e-4):
10001000
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
10011001

10021002

1003+
class TestStochasticDepth:
1004+
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
1005+
@pytest.mark.parametrize('mode', ["batch", "row"])
1006+
def test_stochastic_depth(self, mode, p):
1007+
stats = pytest.importorskip("scipy.stats")
1008+
batch_size = 5
1009+
x = torch.ones(size=(batch_size, 3, 4, 4))
1010+
layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype)
1011+
layer.__repr__()
1012+
1013+
trials = 250
1014+
num_samples = 0
1015+
counts = 0
1016+
for _ in range(trials):
1017+
out = layer(x)
1018+
non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
1019+
if mode == "batch":
1020+
if non_zero_count == 0:
1021+
counts += 1
1022+
num_samples += 1
1023+
elif mode == "row":
1024+
counts += batch_size - non_zero_count
1025+
num_samples += batch_size
1026+
1027+
p_value = stats.binom_test(counts, num_samples, p=p)
1028+
assert p_value > 0.0001
1029+
1030+
10031031
if __name__ == '__main__':
10041032
pytest.main([__file__])

torchvision/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .poolers import MultiScaleRoIAlign
99
from .feature_pyramid_network import FeaturePyramidNetwork
1010
from .focal_loss import sigmoid_focal_loss
11+
from .stochastic_depth import stochastic_depth, StochasticDepth
1112

1213
from ._register_onnx_ops import _register_custom_op
1314

@@ -20,5 +21,5 @@
2021
'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
2122
'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
2223
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork',
23-
'sigmoid_focal_loss'
24+
'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth'
2425
]

torchvision/ops/stochastic_depth.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
from torch import nn, Tensor
3+
4+
5+
def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor:
6+
"""
7+
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
8+
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
9+
branches of residual architectures.
10+
11+
Args:
12+
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
13+
being its batch i.e. a batch with ``N`` rows.
14+
p (float): probability of the input to be zeroed.
15+
mode (str): ``"batch"`` or ``"row"``.
16+
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
17+
randomly selected rows from the batch.
18+
training: apply stochastic depth if is ``True``. Default: ``True``
19+
20+
Returns:
21+
Tensor[N, ...]: The randomly zeroed tensor.
22+
"""
23+
if p < 0.0 or p > 1.0:
24+
raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p))
25+
if not training or p == 0.0:
26+
return input
27+
28+
survival_rate = 1.0 - p
29+
if mode not in ["batch", "row"]:
30+
raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode))
31+
size = [1] * input.ndim
32+
if mode == "row":
33+
size[0] = input.shape[0]
34+
noise = torch.empty(size, dtype=input.dtype, device=input.device)
35+
noise = noise.bernoulli_(survival_rate).div_(survival_rate)
36+
return input * noise
37+
38+
39+
class StochasticDepth(nn.Module):
40+
"""
41+
See :func:`stochastic_depth`.
42+
"""
43+
def __init__(self, p: float, mode: str) -> None:
44+
super().__init__()
45+
self.p = p
46+
self.mode = mode
47+
48+
def forward(self, input: Tensor) -> Tensor:
49+
return stochastic_depth(input, self.p, self.mode, self.training)
50+
51+
def __repr__(self) -> str:
52+
tmpstr = self.__class__.__name__ + '('
53+
tmpstr += 'p=' + str(self.p)
54+
tmpstr += ', mode=' + str(self.mode)
55+
tmpstr += ')'
56+
return tmpstr

0 commit comments

Comments
 (0)