# limitations under the License.

import argparse
- import random

- import paddle
- import numpy as np
from paddleslim.dygraph.dist import Distill

- from paddleseg.cvlibs import manager, Config
- from paddleseg.utils import get_sys_env, logger, utils
+ from paddleseg.cvlibs import Config, SegBuilder
+ from paddleseg.utils import logger, utils

from distill_utils import distill_train
from distill_config import prepare_distill_adaptor, prepare_distill_config

@@ -117,47 +114,29 @@ def prepare_envs(args):
    """
    Set random seed and the device.
    """
-     if args.seed is not None:
-         paddle.seed(args.seed)
-         np.random.seed(args.seed)
-         random.seed(args.seed)

-     env_info = get_sys_env()
-     info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
-     info = '\n'.join(['', format('Environment Information', '-^48s')] + info +
-                      ['-' * 48])
-     logger.info(info)
+     utils.set_seed(args.seed)
+     utils.show_env_info()

-     place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
-         'GPUs used'] else 'cpu'
+     env_info = utils.get_sys_env()
+     place = 'gpu' if env_info['GPUs used'] else 'cpu'
+     utils.set_device(place)

-     paddle.set_device(place)

+ def main(args):
+
+     prepare_envs(args)

- def prepare_config(args):
-     """
-     Create and check the config of student and teacher model.
-     Note: we only use the dataset generated by the student config.
-     """
    if args.teather_config is None or args.student_config is None:
        raise RuntimeError('No configuration file specified.')
-
    t_cfg = Config(args.teather_config)
    s_cfg = Config(
        args.student_config,
        learning_rate=args.learning_rate,
        iters=args.iters,
        batch_size=args.batch_size)
-
-     train_dataset = s_cfg.train_dataset
-     val_dataset = s_cfg.val_dataset if args.do_eval else None
-     if train_dataset is None:
-         raise RuntimeError(
-             'The training dataset is not specified in the configuration file.')
-     elif len(train_dataset) == 0:
-         raise ValueError(
-             'The length of train_dataset is 0. Please check if your dataset is valid'
-         )
+     t_builder = SegBuilder(t_cfg)
+     s_builder = SegBuilder(s_cfg)

    msg = '\n---------------Teacher Config Information---------------\n'
    msg += str(t_cfg)
@@ -169,21 +148,12 @@ def prepare_config(args):
    msg += '------------------------------------------------'
    logger.info(msg)

-     return t_cfg, s_cfg, train_dataset, val_dataset
-
-
- def main(args):
-
-     prepare_envs(args)
-
-     t_cfg, s_cfg, train_dataset, val_dataset = prepare_config(args)
-
    distill_config = prepare_distill_config()

    s_adaptor, t_adaptor = prepare_distill_adaptor()

-     t_model = t_cfg.model
-     s_model = s_cfg.model
+     t_model = t_builder.model
+     s_model = s_builder.model
    t_model.eval()
    s_model.train()

@@ -192,19 +162,19 @@ def main(args):

    distill_train(
        distill_model=distill_model,
-         train_dataset=train_dataset,
-         val_dataset=val_dataset,
-         optimizer=s_cfg.optimizer,
+         train_dataset=s_builder.train_dataset,
+         val_dataset=s_builder.val_dataset,
+         optimizer=s_builder.optimizer,
        save_dir=args.save_dir,
-         iters=s_cfg.iters,
-         batch_size=s_cfg.batch_size,
+         iters=s_builder.iters,
+         batch_size=s_builder.batch_size,
        resume_model=args.resume_model,
        save_interval=args.save_interval,
        log_iters=args.log_iters,
        num_workers=args.num_workers,
        use_vdl=args.use_vdl,
-         losses=s_cfg.loss,
-         distill_losses=s_cfg.distill_loss,
+         losses=s_builder.loss,
+         distill_losses=s_builder.distill_loss,
        keep_checkpoint_max=args.keep_checkpoint_max,
        test_config=s_cfg.test_config, )

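For context, a consolidated sketch of the flow this patch moves to, assembled only from calls visible in the hunks above. The helper name build_models is invented for this sketch, and the Distill wrapper plus the distill_train(...) invocation are omitted because their construction is only partially shown in this diff.

# Sketch of the refactored setup, using only identifiers that appear in the hunks.
from paddleseg.cvlibs import Config, SegBuilder
from paddleseg.utils import utils


def prepare_envs(args):
    # Seeding and environment reporting collapse to helpers in paddleseg.utils.
    utils.set_seed(args.seed)
    utils.show_env_info()

    env_info = utils.get_sys_env()
    place = 'gpu' if env_info['GPUs used'] else 'cpu'
    utils.set_device(place)


def build_models(args):  # hypothetical helper, not part of the patch
    # Config still parses the YAML files; SegBuilder now materializes the
    # model, datasets, optimizer and losses that used to hang off Config.
    t_cfg = Config(args.teather_config)
    s_cfg = Config(
        args.student_config,
        learning_rate=args.learning_rate,
        iters=args.iters,
        batch_size=args.batch_size)
    t_builder = SegBuilder(t_cfg)
    s_builder = SegBuilder(s_cfg)
    return t_builder.model, s_builder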