Skip to content

Commit dab4268

Browse files
authored
[Fix] Adapt Slimming Tools to New APIs (#2916)
* Adapt to new API
* Optimize distill_train
1 parent 82a550e commit dab4268

File tree

7 files changed

+78
-167
lines changed

7 files changed

+78
-167
lines changed

deploy/python/infer_onnx_trt.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import onnx
2525
import onnxruntime
2626

27-
from paddleseg.cvlibs import Config
27+
from paddleseg.cvlibs import Config, SegBuilder
2828
from paddleseg.utils import logger, utils
2929
"""
3030
Export the Paddle model to ONNX, infer the ONNX model by TRT.
@@ -395,7 +395,8 @@ def export_load_infer(args, model=None):
395395
# 1. prepare
396396
if model is None:
397397
cfg = Config(args.config)
398-
model = cfg.model
398+
builder = SegBuilder(cfg)
399+
model = builder.model
399400
if args.model_path is not None:
400401
utils.load_entire_model(model, args.model_path)
401402
logger.info('Loaded trained params of model successfully')

deploy/slim/distill/distill_train.py

+21-51
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@
1313
# limitations under the License.
1414

1515
import argparse
16-
import random
1716

18-
import paddle
19-
import numpy as np
2017
from paddleslim.dygraph.dist import Distill
2118

22-
from paddleseg.cvlibs import manager, Config
23-
from paddleseg.utils import get_sys_env, logger, utils
19+
from paddleseg.cvlibs import Config, SegBuilder
20+
from paddleseg.utils import logger, utils
2421
from distill_utils import distill_train
2522
from distill_config import prepare_distill_adaptor, prepare_distill_config
2623

@@ -117,47 +114,29 @@ def prepare_envs(args):
117114
"""
118115
Set random seed and the device.
119116
"""
120-
if args.seed is not None:
121-
paddle.seed(args.seed)
122-
np.random.seed(args.seed)
123-
random.seed(args.seed)
124117

125-
env_info = get_sys_env()
126-
info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
127-
info = '\n'.join(['', format('Environment Information', '-^48s')] + info +
128-
['-' * 48])
129-
logger.info(info)
118+
utils.set_seed(args.seed)
119+
utils.show_env_info()
130120

131-
place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
132-
'GPUs used'] else 'cpu'
121+
env_info = utils.get_sys_env()
122+
place = 'gpu' if env_info['GPUs used'] else 'cpu'
123+
utils.set_device(place)
133124

134-
paddle.set_device(place)
135125

126+
def main(args):
127+
128+
prepare_envs(args)
136129

137-
def prepare_config(args):
138-
"""
139-
Create and check the config of student and teacher model.
140-
Note: we only use the dataset generated by the student config.
141-
"""
142130
if args.teather_config is None or args.student_config is None:
143131
raise RuntimeError('No configuration file specified.')
144-
145132
t_cfg = Config(args.teather_config)
146133
s_cfg = Config(
147134
args.student_config,
148135
learning_rate=args.learning_rate,
149136
iters=args.iters,
150137
batch_size=args.batch_size)
151-
152-
train_dataset = s_cfg.train_dataset
153-
val_dataset = s_cfg.val_dataset if args.do_eval else None
154-
if train_dataset is None:
155-
raise RuntimeError(
156-
'The training dataset is not specified in the configuration file.')
157-
elif len(train_dataset) == 0:
158-
raise ValueError(
159-
'The length of train_dataset is 0. Please check if your dataset is valid'
160-
)
138+
t_builder = SegBuilder(t_cfg)
139+
s_builder = SegBuilder(s_cfg)
161140

162141
msg = '\n---------------Teacher Config Information---------------\n'
163142
msg += str(t_cfg)
@@ -169,21 +148,12 @@ def prepare_config(args):
169148
msg += '------------------------------------------------'
170149
logger.info(msg)
171150

172-
return t_cfg, s_cfg, train_dataset, val_dataset
173-
174-
175-
def main(args):
176-
177-
prepare_envs(args)
178-
179-
t_cfg, s_cfg, train_dataset, val_dataset = prepare_config(args)
180-
181151
distill_config = prepare_distill_config()
182152

183153
s_adaptor, t_adaptor = prepare_distill_adaptor()
184154

185-
t_model = t_cfg.model
186-
s_model = s_cfg.model
155+
t_model = t_builder.model
156+
s_model = s_builder.model
187157
t_model.eval()
188158
s_model.train()
189159

@@ -192,19 +162,19 @@ def main(args):
192162

193163
distill_train(
194164
distill_model=distill_model,
195-
train_dataset=train_dataset,
196-
val_dataset=val_dataset,
197-
optimizer=s_cfg.optimizer,
165+
train_dataset=s_builder.train_dataset,
166+
val_dataset=s_builder.val_dataset,
167+
optimizer=s_builder.optimizer,
198168
save_dir=args.save_dir,
199-
iters=s_cfg.iters,
200-
batch_size=s_cfg.batch_size,
169+
iters=s_builder.iters,
170+
batch_size=s_builder.batch_size,
201171
resume_model=args.resume_model,
202172
save_interval=args.save_interval,
203173
log_iters=args.log_iters,
204174
num_workers=args.num_workers,
205175
use_vdl=args.use_vdl,
206-
losses=s_cfg.loss,
207-
distill_losses=s_cfg.distill_loss,
176+
losses=s_builder.loss,
177+
distill_losses=s_builder.distill_loss,
208178
keep_checkpoint_max=args.keep_checkpoint_max,
209179
test_config=s_cfg.test_config, )
210180

deploy/slim/prune/prune.py

+18-25
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
from functools import partial
1919

2020
import yaml
21-
2221
import paddle
2322
from paddleslim.dygraph import L1NormFilterPruner
2423
from paddleslim.analysis import dygraph_flops
25-
from paddleseg.cvlibs.config import Config
24+
25+
from paddleseg.cvlibs.config import Config, SegBuilder
2626
from paddleseg.core.val import evaluate
2727
from paddleseg.core.train import train
28-
from paddleseg.utils import get_sys_env, logger
28+
from paddleseg.utils import logger, utils
2929

3030

3131
def parse_args():
@@ -87,9 +87,9 @@ def eval_fn(net, eval_dataset, num_workers):
8787
return miou
8888

8989

90-
def export_model(net, cfg, save_dir):
90+
def export_model(net, val_dataset, cfg, save_dir):
9191
net.forward = paddle.jit.to_static(net.forward)
92-
input_shape = [1] + list(cfg.val_dataset[0]['img'].shape)
92+
input_shape = [1] + list(val_dataset[0]['img'].shape)
9393
input_var = paddle.ones(input_shape)
9494
out = net(input_var)
9595

@@ -98,7 +98,7 @@ def export_model(net, cfg, save_dir):
9898

9999
yml_file = os.path.join(save_dir, 'deploy.yaml')
100100
with open(yml_file, 'w') as file:
101-
transforms = cfg.dic['val_dataset']['transforms']
101+
transforms = cfg.val_dataset_cfg['transforms']
102102
data = {
103103
'Deploy': {
104104
'transforms': transforms,
@@ -110,11 +110,10 @@ def export_model(net, cfg, save_dir):
110110

111111

112112
def main(args):
113-
env_info = get_sys_env()
113+
env_info = utils.get_sys_env()
114114

115-
place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
116-
'GPUs used'] else 'cpu'
117-
paddle.set_device(place)
115+
place = 'gpu' if env_info['GPUs used'] else 'cpu'
116+
utils.set_device(place)
118117

119118
if not (0.0 < args.pruning_ratio < 1.0):
120119
raise RuntimeError(
@@ -123,24 +122,18 @@ def main(args):
123122
if not os.path.exists(args.save_dir):
124123
os.makedirs(args.save_dir)
125124

125+
os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
126+
126127
cfg = Config(
127128
args.cfg,
128129
iters=args.retraining_iters,
129130
batch_size=args.batch_size,
130131
learning_rate=args.learning_rate)
132+
builder = SegBuilder(cfg)
131133

132-
train_dataset = cfg.train_dataset
133-
if not train_dataset:
134-
raise RuntimeError(
135-
'The training dataset is not specified in the configuration file.')
136-
137-
val_dataset = cfg.val_dataset
138-
if not val_dataset:
139-
raise RuntimeError(
140-
'The validation dataset is not specified in the c;onfiguration file.'
141-
)
142-
os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
143-
net = cfg.model
134+
train_dataset = builder.train_dataset
135+
val_dataset = builder.val_dataset
136+
net = builder.model
144137

145138
if args.model_path:
146139
para_state_dict = paddle.load(args.model_path)
@@ -180,17 +173,17 @@ def main(args):
180173
train(
181174
net,
182175
train_dataset,
183-
optimizer=cfg.optimizer,
176+
optimizer=builder.optimizer,
184177
save_dir=args.save_dir,
185178
num_workers=args.num_workers,
186179
iters=cfg.iters,
187180
batch_size=cfg.batch_size,
188-
losses=cfg.loss)
181+
losses=builder.loss)
189182

190183
evaluate(net, val_dataset)
191184

192185
if paddle.distributed.get_rank() == 0:
193-
export_model(net, cfg, args.save_dir)
186+
export_model(net, val_dataset, cfg, args.save_dir)
194187

195188
ckpt = os.path.join(args.save_dir, f'iter_{args.retraining_iters}')
196189
if os.path.exists(ckpt):

deploy/slim/quant/ptq.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717

1818
import numpy as np
1919
import shutil
20-
from paddleseg.cvlibs import manager, Config
2120
import paddle
2221
from paddleslim.quant import quant_post_static
2322

23+
from paddleseg.cvlibs import Config, SegBuilder
24+
2425
paddle.enable_static()
2526

2627

@@ -63,16 +64,11 @@ def __reader__():
6364
def main(args):
6465
fp32_model_dir = args.model_dir
6566
quant_output_dir = 'quant_model'
67+
6668
cfg = Config(args.cfg)
67-
val_dataset = cfg.val_dataset
68-
if val_dataset is None:
69-
raise RuntimeError(
70-
'The verification dataset is not specified in the configuration file.'
71-
)
72-
elif len(val_dataset) == 0:
73-
raise ValueError(
74-
'The length of val_dataset is 0. Please check if your dataset is valid'
75-
)
69+
builder = SegBuilder(cfg)
70+
71+
val_dataset = builder.val_dataset
7672

7773
use_gpu = True
7874
place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()

deploy/slim/quant/qat_export.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
import argparse
1616
import os
1717

18-
import paddle
1918
import yaml
19+
import paddle
2020
from paddleslim import QAT
2121

22-
from paddleseg.cvlibs import Config
22+
from paddleseg.cvlibs import Config, SegBuilder
2323
from paddleseg.utils import logger, utils
2424
from paddleseg.deploy.export import WrappedModel
2525
from qat_config import quant_config
@@ -59,7 +59,9 @@ def parse_args():
5959
def main(args):
6060
os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
6161
cfg = Config(args.config)
62-
net = cfg.model
62+
builder = SegBuilder(cfg)
63+
64+
net = builder.model
6365

6466
skip_quant(net)
6567
quantizer = QAT(config=quant_config)
@@ -92,7 +94,7 @@ def main(args):
9294

9395
yml_file = os.path.join(args.save_dir, 'deploy.yaml')
9496
with open(yml_file, 'w') as file:
95-
transforms = cfg.val_dataset_config.get('transforms', [{
97+
transforms = cfg.val_dataset_cfg.get('transforms', [{
9698
'type': 'Normalize'
9799
}])
98100
data = {

0 commit comments

Comments (0)