diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py
index 06b2a301b6d..bad5b57b25b 100644
--- a/torchvision/models/efficientnet.py
+++ b/torchvision/models/efficientnet.py
@@ -32,17 +32,25 @@
 
 
 class SqueezeExcitation(nn.Module):
-    def __init__(self, input_channels: int, squeeze_channels: int):
+    def __init__(
+        self,
+        input_channels: int,
+        squeeze_channels: int,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
+    ) -> None:
         super().__init__()
         self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
         self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
 
     def _scale(self, input: Tensor) -> Tensor:
         scale = F.adaptive_avg_pool2d(input, 1)
         scale = self.fc1(scale)
-        scale = F.silu(scale, inplace=True)
+        scale = self.activation(scale)
         scale = self.fc2(scale)
-        return scale.sigmoid()
+        return self.scale_activation(scale)
 
     def forward(self, input: Tensor) -> Tensor:
         scale = self._scale(input)
@@ -108,7 +116,7 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer:
 
         # squeeze and excitation
         squeeze_channels = max(1, cnf.input_channels // 4)
-        layers.append(se_layer(expanded_channels, squeeze_channels))
+        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
 
         # project
         layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,