diff --git a/docs/source/ops.rst b/docs/source/ops.rst index cdebe9721c3..ecef74dd8a6 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -23,6 +23,7 @@ torchvision.ops .. autofunction:: ps_roi_pool .. autofunction:: deform_conv2d .. autofunction:: sigmoid_focal_loss +.. autofunction:: stochastic_depth .. autoclass:: RoIAlign .. autoclass:: PSRoIAlign @@ -31,3 +32,4 @@ torchvision.ops .. autoclass:: DeformConv2d .. autoclass:: MultiScaleRoIAlign .. autoclass:: FeaturePyramidNetwork +.. autoclass:: StochasticDepth diff --git a/test/test_ops.py b/test/test_ops.py index 5c2fc882902..c64ba1fd0bb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1000,5 +1000,33 @@ def gen_iou_check(box, expected, tolerance=1e-4): gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) +class TestStochasticDepth: + @pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) + @pytest.mark.parametrize('mode', ["batch", "row"]) + def test_stochastic_depth(self, mode, p): + stats = pytest.importorskip("scipy.stats") + batch_size = 5 + x = torch.ones(size=(batch_size, 3, 4, 4)) + layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype) + layer.__repr__() + + trials = 250 + num_samples = 0 + counts = 0 + for _ in range(trials): + out = layer(x) + non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0) + if mode == "batch": + if non_zero_count == 0: + counts += 1 + num_samples += 1 + elif mode == "row": + counts += batch_size - non_zero_count + num_samples += batch_size + + p_value = stats.binom_test(counts, num_samples, p=p) + assert p_value > 0.0001 + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 0ec189dbc2a..606c27abcbe 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -8,6 +8,7 @@ from .poolers import MultiScaleRoIAlign from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss +from .stochastic_depth import stochastic_depth, StochasticDepth from ._register_onnx_ops import _register_custom_op @@ -20,5 +21,5 @@ 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', - 'sigmoid_focal_loss' + 'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth' ] diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py new file mode 100644 index 00000000000..f3338242a76 --- /dev/null +++ b/torchvision/ops/stochastic_depth.py @@ -0,0 +1,56 @@ +import torch +from torch import nn, Tensor + + +def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor: + """ + Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" + `_ used for randomly dropping residual + branches of residual architectures. + + Args: + input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): probability of the input to be zeroed. + mode (str): ``"batch"`` or ``"row"``. + ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes + randomly selected rows from the batch. + training: apply stochastic depth if is ``True``. Default: ``True`` + + Returns: + Tensor[N, ...]: The randomly zeroed tensor. + """ + if p < 0.0 or p > 1.0: + raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p)) + if not training or p == 0.0: + return input + + survival_rate = 1.0 - p + if mode not in ["batch", "row"]: + raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) + size = [1] * input.ndim + if mode == "row": + size[0] = input.shape[0] + noise = torch.empty(size, dtype=input.dtype, device=input.device) + noise = noise.bernoulli_(survival_rate).div_(survival_rate) + return input * noise + + +class StochasticDepth(nn.Module): + """ + See :func:`stochastic_depth`. + """ + def __init__(self, p: float, mode: str) -> None: + super().__init__() + self.p = p + self.mode = mode + + def forward(self, input: Tensor) -> Tensor: + return stochastic_depth(input, self.p, self.mode, self.training) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'p=' + str(self.p) + tmpstr += ', mode=' + str(self.mode) + tmpstr += ')' + return tmpstr