diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 063d51749b4..e67c2a67acd 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -40,6 +40,8 @@ class ConvStemConfig(NamedTuple): class MLPBlock(MLP): """Transformer MLP block.""" + _version = 2 + def __init__(self, in_dim: int, mlp_dim: int, dropout: float): super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)