|
| 1 | +""" |
| 2 | + EfficientNet for ImageNet-1K, implemented in PyTorch. |
| 3 | + Original papers: |
| 4 | + - 'EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,' https://arxiv.org/abs/1905.11946, |
| 5 | + - 'Adversarial Examples Improve Image Recognition,' https://arxiv.org/abs/1911.09665. |
| 6 | +""" |
| 7 | + |
| 8 | +import os |
| 9 | + |
| 10 | +import timm |
| 11 | +import torch.nn as nn |
| 12 | +from mmcls.models.builder import BACKBONES |
| 13 | +from mmcv.runner import load_checkpoint |
| 14 | +from mpa.utils.logger import get_logger |
| 15 | + |
| 16 | +logger = get_logger() |
| 17 | + |
| 18 | +pretrained_root = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/" |
| 19 | +pretrained_urls = { |
| 20 | + "efficientnetv2_s_21k": pretrained_root + "tf_efficientnetv2_s_21k-6337ad01.pth", |
| 21 | + "efficientnetv2_s_1k": pretrained_root + "tf_efficientnetv2_s_21ft1k-d7dafa41.pth", |
| 22 | +} |
| 23 | + |
| 24 | +NAME_DICT = { |
| 25 | + 'mobilenetv3_large_21k': 'mobilenetv3_large_100_miil_in21k', |
| 26 | + 'mobilenetv3_large_1k': 'mobilenetv3_large_100_miil', |
| 27 | + 'tresnet': 'tresnet_m', |
| 28 | + 'efficientnetv2_s_21k': 'tf_efficientnetv2_s_in21k', |
| 29 | + 'efficientnetv2_s_1k': 'tf_efficientnetv2_s_in21ft1k', |
| 30 | + 'efficientnetv2_m_21k': 'tf_efficientnetv2_m_in21k', |
| 31 | + 'efficientnetv2_m_1k': 'tf_efficientnetv2_m_in21ft1k', |
| 32 | + 'efficientnetv2_b0': 'tf_efficientnetv2_b0', |
| 33 | + } |
| 34 | + |
| 35 | + |
| 36 | +class TimmModelsWrapper(nn.Module): |
| 37 | + def __init__(self, |
| 38 | + model_name, |
| 39 | + pretrained=True, |
| 40 | + pooling_type='avg', |
| 41 | + **kwargs): |
| 42 | + super().__init__(**kwargs) |
| 43 | + self.model_name = model_name |
| 44 | + self.pretrained = pretrained |
| 45 | + self.is_mobilenet = True if model_name in [ |
| 46 | + "mobilenetv3_large_100_miil_in21k", "mobilenetv3_large_100_miil" |
| 47 | + ] else False |
| 48 | + self.model = timm.create_model(NAME_DICT[self.model_name], |
| 49 | + pretrained=pretrained, |
| 50 | + num_classes=1000) |
| 51 | + self.model.classifier = None # Detach classifier. Only use 'backbone' part in mpa. |
| 52 | + self.num_head_features = self.model.num_features |
| 53 | + self.num_features = (self.model.conv_head.in_channels if self.is_mobilenet |
| 54 | + else self.model.num_features) |
| 55 | + self.pooling_type = pooling_type |
| 56 | + |
| 57 | + def forward(self, x, return_featuremaps=True, **kwargs): |
| 58 | + y = self.extract_features(x) |
| 59 | + if return_featuremaps: |
| 60 | + return y |
| 61 | + |
| 62 | + def extract_features(self, x): |
| 63 | + if self.is_mobilenet: |
| 64 | + x = self.model.conv_stem(x) |
| 65 | + x = self.model.bn1(x) |
| 66 | + x = self.model.act1(x) |
| 67 | + y = self.model.blocks(x) |
| 68 | + return y |
| 69 | + return self.model.forward_features(x) |
| 70 | + |
| 71 | + def get_config_optim(self, lrs): |
| 72 | + parameters = [ |
| 73 | + {'params': self.model.named_parameters()}, |
| 74 | + ] |
| 75 | + if isinstance(lrs, list): |
| 76 | + assert len(lrs) == len(parameters) |
| 77 | + for lr, param_dict in zip(lrs, parameters): |
| 78 | + param_dict['lr'] = lr |
| 79 | + else: |
| 80 | + assert isinstance(lrs, float) |
| 81 | + for param_dict in parameters: |
| 82 | + param_dict['lr'] = lrs |
| 83 | + |
| 84 | + return parameters |
| 85 | + |
| 86 | + |
| 87 | +@BACKBONES.register_module() |
| 88 | +class OTEEfficientNetV2(TimmModelsWrapper): |
| 89 | + def __init__(self, version="s_21k", **kwargs): |
| 90 | + self.model_name = "efficientnetv2_" + version |
| 91 | + super().__init__(model_name=self.model_name, **kwargs) |
| 92 | + |
| 93 | + def init_weights(self, pretrained=None): |
| 94 | + if isinstance(pretrained, str) and os.path.exists(pretrained): |
| 95 | + load_checkpoint(self, pretrained) |
| 96 | + logger.info(f"init weight - {pretrained}") |
| 97 | + elif pretrained is not None: |
| 98 | + load_checkpoint(self, pretrained_urls[self.model_name]) |
| 99 | + logger.info(f"init weight - {pretrained_urls[self.model_name]}") |
0 commit comments