Skip to content
This repository was archived by the owner on Apr 17, 2023. It is now read-only.

Commit 6458939

Browse files
Songki ChoiharimkangJihwanEom
authored
Add EfficientNetV2 (#9)
* Apply MPA public repo changes to inner source (#8) * Move tasks & model templates to OTE repo * Update external/training_extension to MPA PR * Remove external/mmdetection * Update init_venv.sh * Update README.md * Enable non-incremental learning * TODO: Remove OTE SDK/Task dependency * Update mpa's minor issue and add model (#10) * Add det/seg OTELoggerHook & OTEProgressHook * Add effnetv2 template & seg export config update * Add ote hooks * Add metrics score to best_acc * Fix some pretrained var conf * Separate mpa changes and apis task changes * Update SamClassifier & seg weight mixing * Minor fix * Update segmentation configuration * Update sam_classifier.py * Merge remote-tracking branch 'origin/ote' into ote-public Co-authored-by: Harim Kang <[email protected]> Co-authored-by: Jihwan Eom <[email protected]>
1 parent ebee675 commit 6458939

26 files changed

+229
-82
lines changed

README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ related NN *model* architecures.
4444
### TL Tasks
4545
* Classification
4646
* Detection
47-
* Semantic segmentation
47+
* Segmentation
48+
* (WIP) Instance segmentation
49+
* Semantic segmentation
50+
* (TBD) Panoptic segmentation
4851
> Train / Infer / Evaluate / Export operations are supported for each task
4952
5053
### TL Methods

init_venv.sh

+4-4
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ if [ -e "$CUDA_HOME" ]; then
6767
fi
6868
fi
6969

70-
# install PyTorch and MMCV.
70+
# Install PyTorch and MMCV.
7171
export TORCH_VERSION=1.8.2
7272
export TORCHVISION_VERSION=0.9.2
7373
export MMCV_VERSION=1.3.14
@@ -109,10 +109,10 @@ pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} -f https
109109

110110
# Install mmcv
111111
pip install --no-cache-dir mmcv-full==${MMCV_VERSION} -c ${CONSTRAINTS_FILE} || exit 1
112+
sed -i "s/force=False/force=True/g" ${venv_dir}/lib/python${PYTHON_VERSION}/site-packages/mmcv/utils/registry.py # Patch: remedy for MMCV registry collision from mmdet/mmseg
112113

113-
# Install other requirements.
114114
# Install mmpycocotools from source to make sure it is compatible with installed numpy version.
115-
pip install --no-cache-dir --no-binary=mmpycocotools mmpycocotools || exit 1
115+
pip install --no-cache-dir --no-binary=mmpycocotools mmpycocotools -c ${CONSTRAINTS_FILE} || exit 1
116116
cat requirements.txt | xargs -n 1 -L 1 pip install --no-cache || exit 1
117117

118118
# Install external modules
@@ -134,7 +134,7 @@ else
134134
fi
135135

136136
# Install MPA
137-
pip install -e . || exit 1
137+
pip install -e . -c ${CONSTRAINTS_FILE} || exit 1
138138
MPA_DIR=`realpath .`
139139
echo "export MPA_DIR=${MPA_DIR}" >> ${venv_dir}/bin/activate
140140

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# model settings
2+
model = dict(
3+
type='ImageClassifier',
4+
backbone=dict(
5+
type='OTEEfficientNetV2',
6+
version='s_21k'),
7+
neck=dict(type='GlobalAveragePooling'),
8+
head=dict(
9+
type='LinearClsHead',
10+
num_classes=1000,
11+
in_channels=1280,
12+
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
13+
))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
_base_: _base_/ote_efficientnet_v2.py
2+
3+
model:
4+
type: SAMImageClassifier
5+
task: classification
6+
backbone:
7+
version: s_21k

mpa/cls/exporter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
7878
input_names=['data'],
7979
output_names=['logits', 'features', 'vector'],
8080
dynamic_axes={},
81-
opset_version=9,
81+
opset_version=11,
8282
operator_export_type=torch.onnx.OperatorExportTypes.ONNX
8383
)
8484

mpa/cls/stage.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44

55
from mmcv import ConfigDict
6-
from mmcv import Config
76
from mmcv import build_from_cfg
87

98
from mpa.stage import Stage
@@ -24,12 +23,7 @@ def configure(self, model_cfg, model_ckpt, data_cfg, training=True, **kwargs):
2423
logger.info(f'configure: training={training}')
2524

2625
# Recipe + model
27-
cfg = Config(
28-
self.cfg._cfg_dict,
29-
self.cfg.text,
30-
self.cfg.filename
31-
)
32-
26+
cfg = self.cfg
3327
if model_cfg:
3428
if hasattr(cfg, 'model'):
3529
cfg.merge_from_dict(model_cfg._cfg_dict)
@@ -54,9 +48,9 @@ def configure(self, model_cfg, model_ckpt, data_cfg, training=True, **kwargs):
5448
cfg.model.backbone.model_path = ir_path
5549

5650
pretrained = kwargs.get('pretrained', None)
57-
if pretrained:
51+
if pretrained and isinstance(pretrained, str):
5852
logger.info(f'Overriding cfg.load_from -> {pretrained}')
59-
cfg.load_from = pretrained # Overriding by stage input
53+
cfg.load_from = pretrained
6054

6155
# Data
6256
if data_cfg:

mpa/det/inferrer.py

+12
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,20 @@ def infer(self, cfg):
127127
# Inference
128128
model = MMDataParallel(model, device_ids=[0])
129129
detections = single_gpu_test(model, data_loader)
130+
131+
eval_cfg = cfg.evaluation.copy()
132+
eval_cfg.pop('interval', None)
133+
eval_cfg.pop('save_best', None)
134+
135+
metric = dataset.evaluate(
136+
detections,
137+
logger='silent',
138+
**eval_cfg
139+
)[cfg.evaluation.metric]
140+
130141
outputs = dict(
131142
classes=target_classes,
132143
detections=detections,
144+
metric=metric,
133145
)
134146
return outputs

mpa/det/stage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def configure(self, model_cfg, model_ckpt, data_cfg, training=True, **kwargs):
3737
if model_ckpt:
3838
cfg.load_from = self.get_model_ckpt(model_ckpt)
3939
pretrained = kwargs.get('pretrained', None)
40-
if pretrained:
40+
if pretrained and isinstance(pretrained, str):
4141
logger.info(f'Overriding cfg.load_from -> {pretrained}')
4242
cfg.load_from = pretrained # Overriding by stage input
4343

mpa/det/trainer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
9090
# Metadata
9191
meta = dict()
9292
meta['env_info'] = env_info
93-
meta['config'] = cfg.pretty_text
93+
# meta['config'] = cfg.pretty_text
9494
meta['seed'] = cfg.seed
9595
meta['exp_name'] = cfg.work_dir
9696
if cfg.checkpoint_config is not None:
@@ -108,8 +108,8 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
108108
cfg.optimizer.lr = new_lr
109109

110110
# Save config
111-
cfg.dump(osp.join(cfg.work_dir, 'config.py'))
112-
logger.info(f'Config:\n{cfg.pretty_text}')
111+
# cfg.dump(osp.join(cfg.work_dir, 'config.py'))
112+
# logger.info(f'Config:\n{cfg.pretty_text}')
113113

114114
if distributed:
115115
os.environ['MASTER_ADDR'] = cfg.dist_params.get('master_addr', 'localhost')

mpa/modules/datasets/mpa_det_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from mmdet.datasets.builder import DATASETS
2-
from detection_tasks.extension.datasets import OTEDataset
2+
from detection_tasks.extension.datasets import OTEDataset # TODO: Remove OTE SDK/Task dependency
33
from mpa.utils.logger import get_logger
44

55
logger = get_logger()

mpa/modules/datasets/mpa_seg_incr_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from mmseg.datasets.builder import DATASETS
2-
from segmentation_tasks.extension.datasets import OTEDataset
2+
from segmentation_tasks.extension.datasets import OTEDataset # TODO: Remove OTE SDK/Task dependency
33
from mpa.utils.logger import get_logger
44

55
logger = get_logger()

mpa/modules/datasets/pipelines/mpa_cls_pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from mmcls.datasets import PIPELINES
2-
from detection_tasks.extension.utils import LoadImageFromOTEDataset
2+
from detection_tasks.extension.utils import LoadImageFromOTEDataset # TODO: Remove OTE SDK/Task dependency
33

44

55
@PIPELINES.register_module()

mpa/modules/hooks/early_stopping_hook.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from mmcv.runner.hooks import HOOKS
2-
from detection_tasks.extension.utils.hooks import EarlyStoppingHook
2+
from detection_tasks.extension.utils.hooks import EarlyStoppingHook # TODO: Remove OTE SDK/Task dependency
33

44

55
@HOOKS.register_module()
+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# flake8: noqa
22
from . import wideresnet
33
from . import mobilenetv3
4-
from . import efficientnet
4+
from . import efficientnet
5+
from . import efficientnetv2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
EfficientNet for ImageNet-1K, implemented in PyTorch.
3+
Original papers:
4+
- 'EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,' https://arxiv.org/abs/1905.11946,
5+
- 'Adversarial Examples Improve Image Recognition,' https://arxiv.org/abs/1911.09665.
6+
"""
7+
8+
import os
9+
10+
import timm
11+
import torch.nn as nn
12+
from mmcls.models.builder import BACKBONES
13+
from mmcv.runner import load_checkpoint
14+
from mpa.utils.logger import get_logger
15+
16+
logger = get_logger()
17+
18+
pretrained_root = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/"
19+
pretrained_urls = {
20+
"efficientnetv2_s_21k": pretrained_root + "tf_efficientnetv2_s_21k-6337ad01.pth",
21+
"efficientnetv2_s_1k": pretrained_root + "tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
22+
}
23+
24+
NAME_DICT = {
25+
'mobilenetv3_large_21k': 'mobilenetv3_large_100_miil_in21k',
26+
'mobilenetv3_large_1k': 'mobilenetv3_large_100_miil',
27+
'tresnet': 'tresnet_m',
28+
'efficientnetv2_s_21k': 'tf_efficientnetv2_s_in21k',
29+
'efficientnetv2_s_1k': 'tf_efficientnetv2_s_in21ft1k',
30+
'efficientnetv2_m_21k': 'tf_efficientnetv2_m_in21k',
31+
'efficientnetv2_m_1k': 'tf_efficientnetv2_m_in21ft1k',
32+
'efficientnetv2_b0': 'tf_efficientnetv2_b0',
33+
}
34+
35+
36+
class TimmModelsWrapper(nn.Module):
37+
def __init__(self,
38+
model_name,
39+
pretrained=True,
40+
pooling_type='avg',
41+
**kwargs):
42+
super().__init__(**kwargs)
43+
self.model_name = model_name
44+
self.pretrained = pretrained
45+
self.is_mobilenet = True if model_name in [
46+
"mobilenetv3_large_100_miil_in21k", "mobilenetv3_large_100_miil"
47+
] else False
48+
self.model = timm.create_model(NAME_DICT[self.model_name],
49+
pretrained=pretrained,
50+
num_classes=1000)
51+
self.model.classifier = None # Detach classifier. Only use 'backbone' part in mpa.
52+
self.num_head_features = self.model.num_features
53+
self.num_features = (self.model.conv_head.in_channels if self.is_mobilenet
54+
else self.model.num_features)
55+
self.pooling_type = pooling_type
56+
57+
def forward(self, x, return_featuremaps=True, **kwargs):
58+
y = self.extract_features(x)
59+
if return_featuremaps:
60+
return y
61+
62+
def extract_features(self, x):
63+
if self.is_mobilenet:
64+
x = self.model.conv_stem(x)
65+
x = self.model.bn1(x)
66+
x = self.model.act1(x)
67+
y = self.model.blocks(x)
68+
return y
69+
return self.model.forward_features(x)
70+
71+
def get_config_optim(self, lrs):
72+
parameters = [
73+
{'params': self.model.named_parameters()},
74+
]
75+
if isinstance(lrs, list):
76+
assert len(lrs) == len(parameters)
77+
for lr, param_dict in zip(lrs, parameters):
78+
param_dict['lr'] = lr
79+
else:
80+
assert isinstance(lrs, float)
81+
for param_dict in parameters:
82+
param_dict['lr'] = lrs
83+
84+
return parameters
85+
86+
87+
@BACKBONES.register_module()
88+
class OTEEfficientNetV2(TimmModelsWrapper):
89+
def __init__(self, version="s_21k", **kwargs):
90+
self.model_name = "efficientnetv2_" + version
91+
super().__init__(model_name=self.model_name, **kwargs)
92+
93+
def init_weights(self, pretrained=None):
94+
if isinstance(pretrained, str) and os.path.exists(pretrained):
95+
load_checkpoint(self, pretrained)
96+
logger.info(f"init weight - {pretrained}")
97+
elif pretrained is not None:
98+
load_checkpoint(self, pretrained_urls[self.model_name])
99+
logger.info(f"init weight - {pretrained_urls[self.model_name]}")

0 commit comments

Comments
 (0)