
Commit 7605a82

Merge pull request #455 from kengz/resume
`train` mode with resume; `enjoy` mode refactor
2 parents b608395 + 111af12 commit 7605a82

24 files changed: +162 -203 lines changed
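
The new mode strings extend the third CLI argument of run_lab.py: plain `train` starts a new trial, `train@{predir}` resumes training from a saved run directory (with `train@latest` resolving to the most recent run for the spec name), and `enjoy@{session_spec_file}` replays a saved session spec. Illustrative invocations (the spec file path below is an example, not taken from this commit):

python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole train
python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole train@latest
python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole enjoy@data/reinforce_cartpole_2020_04_13_232521/reinforce_cartpole_t0_s0_spec.json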

run_lab.py

+37-14
@@ -1,13 +1,12 @@
 # The SLM Lab entrypoint
+from glob import glob
 from slm_lab import EVAL_MODES, TRAIN_MODES
 from slm_lab.experiment import search
 from slm_lab.experiment.control import Session, Trial, Experiment
 from slm_lab.lib import logger, util
 from slm_lab.spec import spec_util
 import os
-import pydash as ps
 import sys
-import torch
 import torch.multiprocessing as mp


@@ -19,34 +18,58 @@
 logger = logger.get_logger(__name__)


+def get_spec(spec_file, spec_name, lab_mode, pre_):
+    '''Get spec using args processed from inputs'''
+    if lab_mode in TRAIN_MODES:
+        if pre_ is None:  # new train trial
+            spec = spec_util.get(spec_file, spec_name)
+        else:
+            # for resuming with train@{predir}
+            # e.g. train@latest (fill find the latest predir)
+            # e.g. train@data/reinforce_cartpole_2020_04_13_232521
+            predir = pre_
+            if predir == 'latest':
+                predir = sorted(glob(f'data/{spec_name}*/'))[-1]  # get the latest predir with spec_name
+            _, _, _, _, experiment_ts = util.prepath_split(predir)  # get experiment_ts to resume train spec
+            logger.info(f'Resolved to train@{predir}')
+            spec = spec_util.get(spec_file, spec_name, experiment_ts)
+    elif lab_mode == 'enjoy':
+        # for enjoy@{session_spec_file}
+        # e.g. enjoy@data/reinforce_cartpole_2020_04_13_232521/reinforce_cartpole_t0_s0_spec.json
+        session_spec_file = pre_
+        assert session_spec_file is not None, 'enjoy mode must specify a `enjoy@{session_spec_file}`'
+        spec = util.read(f'{session_spec_file}')
+    else:
+        raise ValueError(f'Unrecognizable lab_mode not of {TRAIN_MODES} or {EVAL_MODES}')
+    return spec
+
+
 def run_spec(spec, lab_mode):
     '''Run a spec in lab_mode'''
-    os.environ['lab_mode'] = lab_mode
+    os.environ['lab_mode'] = lab_mode  # set lab_mode
+    spec = spec_util.override_spec(spec, lab_mode)  # conditionally override spec
     if lab_mode in TRAIN_MODES:
         spec_util.save(spec)  # first save the new spec
-        if lab_mode == 'dev':
-            spec = spec_util.override_dev_spec(spec)
         if lab_mode == 'search':
             spec_util.tick(spec, 'experiment')
             Experiment(spec).run()
         else:
             spec_util.tick(spec, 'trial')
             Trial(spec).run()
     elif lab_mode in EVAL_MODES:
-        spec = spec_util.override_enjoy_spec(spec)
         Session(spec).run()
     else:
         raise ValueError(f'Unrecognizable lab_mode not of {TRAIN_MODES} or {EVAL_MODES}')


-def read_spec_and_run(spec_file, spec_name, lab_mode):
+def get_spec_and_run(spec_file, spec_name, lab_mode):
     '''Read a spec and run it in lab mode'''
     logger.info(f'Running lab spec_file:{spec_file} spec_name:{spec_name} in mode:{lab_mode}')
-    if lab_mode in TRAIN_MODES:
-        spec = spec_util.get(spec_file, spec_name)
-    else:  # eval mode
-        lab_mode, prename = lab_mode.split('@')
-        spec = spec_util.get_eval_spec(spec_file, prename)
+    if '@' in lab_mode:  # process lab_mode@{predir/prename}
+        lab_mode, pre_ = lab_mode.split('@')
+    else:
+        pre_ = None
+    spec = get_spec(spec_file, spec_name, lab_mode, pre_)

     if 'spec_params' not in spec:
         run_spec(spec, lab_mode)
@@ -62,10 +85,10 @@ def main():
         job_file = args[0] if len(args) == 1 else 'job/experiments.json'
         for spec_file, spec_and_mode in util.read(job_file).items():
             for spec_name, lab_mode in spec_and_mode.items():
-                read_spec_and_run(spec_file, spec_name, lab_mode)
+                get_spec_and_run(spec_file, spec_name, lab_mode)
     else:  # run single spec
         assert len(args) == 3, f'To use sys args, specify spec_file, spec_name, lab_mode'
-        read_spec_and_run(*args)
+        get_spec_and_run(*args)


 if __name__ == '__main__':
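
For reference, a minimal standalone sketch of how the new mode string is interpreted, mirroring `get_spec_and_run` and `get_spec` above; the helper name `parse_lab_mode` is hypothetical, and the run-directory layout `data/{spec_name}_{experiment_ts}/` is assumed from the examples in the code comments:

from glob import glob

def parse_lab_mode(lab_mode, spec_name):
    '''Hypothetical sketch: split lab_mode@pre_ and resolve train@latest to the newest run dir'''
    if '@' in lab_mode:
        lab_mode, pre_ = lab_mode.split('@')
    else:
        pre_ = None
    if lab_mode == 'train' and pre_ == 'latest':
        # run dirs are named data/{spec_name}_{timestamp}/, so lexicographic sort is chronological
        pre_ = sorted(glob(f'data/{spec_name}*/'))[-1]
    return lab_mode, pre_

# parse_lab_mode('train@latest', 'reinforce_cartpole')
# -> e.g. ('train', 'data/reinforce_cartpole_2020_04_13_232521/')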

slm_lab/agent/__init__.py

+17-9
@@ -47,7 +47,7 @@ def act(self, state):
     def update(self, state, action, reward, next_state, done):
         '''Update per timestep after env transitions, e.g. memory, algorithm, update agent params, train net'''
         self.body.update(state, action, reward, next_state, done)
-        if util.in_eval_lab_modes():  # eval does not update agent for training
+        if util.in_eval_lab_mode():  # eval does not update agent for training
             return
         self.body.memory.update(state, action, reward, next_state, done)
         loss = self.algorithm.train()
@@ -59,7 +59,7 @@ def update(self, state, action, reward, next_state, done):
     @lab_api
     def save(self, ckpt=None):
         '''Save agent'''
-        if util.in_eval_lab_modes():  # eval does not save new models
+        if util.in_eval_lab_mode():  # eval does not save new models
             return
         self.algorithm.save(ckpt=ckpt)

@@ -103,8 +103,16 @@ def __init__(self, env, spec, aeb=(0, 0, 0)):
         self.train_df = pd.DataFrame(columns=[
             'epi', 't', 'wall_t', 'opt_step', 'frame', 'fps', 'total_reward', 'total_reward_ma', 'loss', 'lr',
             'explore_var', 'entropy_coef', 'entropy', 'grad_norm'])
+
+        # in train@ mode, override from saved train_df if exists
+        if util.in_train_lab_mode() and self.spec['meta']['resume']:
+            train_df_filepath = util.get_session_df_path(self.spec, 'train')
+            if os.path.exists(train_df_filepath):
+                self.train_df = util.read(train_df_filepath)
+                self.env.clock.load(self.train_df)
+
         # track eval data within run_eval. the same as train_df except for reward
-        if ps.get(self.spec, 'meta.rigorous_eval'):
+        if self.spec['meta']['rigorous_eval']:
             self.eval_df = self.train_df.copy()
         else:
             self.eval_df = self.train_df
@@ -178,6 +186,7 @@ def ckpt(self, env, df_mode):
         df = getattr(self, f'{df_mode}_df')
         df.loc[len(df)] = row  # append efficiently to df
         df.iloc[-1]['total_reward_ma'] = total_reward_ma = df[-viz.PLOT_MA_WINDOW:]['total_reward'].mean()
+        df.drop_duplicates('frame', inplace=True)  # remove any duplicates by the same frame
         self.total_reward_ma = total_reward_ma

     def get_mean_lr(self):
@@ -192,10 +201,9 @@ def get_mean_lr(self):

     def get_log_prefix(self):
         '''Get the prefix for logging'''
-        spec = self.agent.spec
-        spec_name = spec['name']
-        trial_index = spec['meta']['trial']
-        session_index = spec['meta']['session']
+        spec_name = self.spec['name']
+        trial_index = self.spec['meta']['trial']
+        session_index = self.spec['meta']['session']
         prefix = f'Trial {trial_index} session {session_index} {spec_name}_t{trial_index}_s{session_index}'
         return prefix

@@ -232,8 +240,8 @@ def log_tensorboard(self):
             self.tb_actions = []  # store actions for tensorboard
             logger.info(f'Using TensorBoard logging for dev mode. Run `tensorboard --logdir={log_prepath}` to start TensorBoard.')

-        trial_index = self.agent.spec['meta']['trial']
-        session_index = self.agent.spec['meta']['session']
+        trial_index = self.spec['meta']['trial']
+        session_index = self.spec['meta']['session']
         if session_index != 0:  # log only session 0
             return
         idx_suffix = f'trial{trial_index}_session{session_index}'
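
The new `drop_duplicates('frame', ...)` line in `ckpt()` complements the resume logic above: when a resumed session's first checkpoint lands on the same frame as the last row of the reloaded `train_df`, only one row per frame is kept. A minimal pandas sketch of that behaviour, with the column set trimmed for illustration:

import pandas as pd

# pretend this train_df was reloaded from the saved session df on resume
train_df = pd.DataFrame([[1000, 20.0]], columns=['frame', 'total_reward'])
row = pd.Series({'frame': 1000, 'total_reward': 20.0})  # first checkpoint after resuming hits the same frame
train_df.loc[len(train_df)] = row  # append efficiently to df
train_df.drop_duplicates('frame', inplace=True)  # keep a single row per frame
print(train_df)  # one row left for frame 1000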

slm_lab/agent/algorithm/actor_critic.py

+1-3
@@ -162,7 +162,7 @@ def init_nets(self, global_nets=None):
         self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
         self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
         net_util.set_global_nets(self, global_nets)
-        self.post_init_nets()
+        self.end_init_nets()

     @lab_api
     def calc_pdparam(self, x, net=None):
@@ -278,8 +278,6 @@ def calc_val_loss(self, v_preds, v_targets):

     def train(self):
         '''Train actor critic by computing the loss in batch efficiently'''
-        if util.in_eval_lab_modes():
-            return np.nan
         clock = self.body.env.clock
         if self.to_train == 1:
             batch = self.sample()
slm_lab/agent/algorithm/base.py

+10-11
@@ -22,7 +22,7 @@ def __init__(self, agent, global_nets=None):
         self.body = self.agent.body
         self.init_algorithm_params()
         self.init_nets(global_nets)
-        logger.info(util.self_desc(self))
+        logger.info(util.self_desc(self, omit=['algorithm_spec', 'name', 'memory_spec', 'net_spec', 'body']))

     @abstractmethod
     @lab_api
@@ -37,19 +37,20 @@ def init_nets(self, global_nets=None):
         raise NotImplementedError

     @lab_api
-    def post_init_nets(self):
-        '''
-        Method to conditionally load models.
-        Call at the end of init_nets() after setting self.net_names
-        '''
+    def end_init_nets(self):
+        '''Checkers and conditional loaders called at the end of init_nets()'''
+        # check all nets naming
         assert hasattr(self, 'net_names')
         for net_name in self.net_names:
             assert net_name.endswith('net'), f'Naming convention: net_name must end with "net"; got {net_name}'
-        if util.in_eval_lab_modes():
+
+        # load algorithm if is in train@ resume or enjoy mode
+        lab_mode = util.get_lab_mode()
+        if self.agent.spec['meta']['resume'] or lab_mode == 'enjoy':
             self.load()
-            logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}')
+            logger.info(f'Loaded algorithm models for lab_mode: {lab_mode}')
         else:
-            logger.info(f'Initialized algorithm models for lab_mode: {util.get_lab_mode()}')
+            logger.info(f'Initialized algorithm models for lab_mode: {lab_mode}')

     @lab_api
     def calc_pdparam(self, x, net=None):
@@ -76,8 +77,6 @@ def sample(self):
     @lab_api
     def train(self):
         '''Implement algorithm train, or throw NotImplementedError'''
-        if util.in_eval_lab_modes():
-            return np.nan
         raise NotImplementedError

     @abstractmethod

slm_lab/agent/algorithm/dqn.py

+3-7
@@ -1,11 +1,9 @@
 from slm_lab.agent import net
-from slm_lab.agent.algorithm import policy_util
 from slm_lab.agent.algorithm.sarsa import SARSA
 from slm_lab.agent.net import net_util
-from slm_lab.lib import logger, math_util, util
+from slm_lab.lib import logger, util
 from slm_lab.lib.decorator import lab_api
 import numpy as np
-import pydash as ps
 import torch

 logger = logger.get_logger(__name__)
@@ -87,7 +85,7 @@ def init_nets(self, global_nets=None):
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
         net_util.set_global_nets(self, global_nets)
-        self.post_init_nets()
+        self.end_init_nets()

     def calc_q_loss(self, batch):
         '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
@@ -130,8 +128,6 @@ def train(self):
         For each of the batches, the target Q values (q_targets) are computed and a single training step is taken k times
         Otherwise this function does nothing.
         '''
-        if util.in_eval_lab_modes():
-            return np.nan
         clock = self.body.env.clock
         if self.to_train == 1:
             total_loss = torch.tensor(0.0)
@@ -187,7 +183,7 @@ def init_nets(self, global_nets=None):
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
         net_util.set_global_nets(self, global_nets)
-        self.post_init_nets()
+        self.end_init_nets()
         self.online_net = self.target_net
         self.eval_net = self.target_net
193189

slm_lab/agent/algorithm/policy_util.py

+3-4
@@ -5,7 +5,6 @@
 from slm_lab.lib import distribution, logger, math_util, util
 from torch import distributions
 import numpy as np
-import pydash as ps
 import torch

 logger = logger.get_logger(__name__)
@@ -61,7 +60,7 @@ def guard_tensor(state, body):
     if isinstance(state, LazyFrames):
         state = state.__array__()  # realize data
     state = torch.from_numpy(state.astype(np.float32))
-    if not body.env.is_venv or util.in_eval_lab_modes():
+    if not body.env.is_venv:
         # singleton state, unsqueeze as minibatch for net input
         state = state.unsqueeze(dim=0)
     return state
@@ -142,7 +141,7 @@ def default(state, algorithm, body):

 def random(state, algorithm, body):
     '''Random action using gym.action_space.sample(), with the same format as default()'''
-    if body.env.is_venv and not util.in_eval_lab_modes():
+    if body.env.is_venv:
         _action = [body.action_space.sample() for _ in range(body.env.num_envs)]
     else:
         _action = [body.action_space.sample()]
@@ -269,7 +268,7 @@ def __init__(self, var_decay_spec=None):

     def update(self, algorithm, clock):
         '''Get an updated value for var'''
-        if (util.in_eval_lab_modes()) or self._updater_name == 'no_decay':
+        if (util.in_eval_lab_mode()) or self._updater_name == 'no_decay':
             return self.end_val
         step = clock.get()
         val = self._updater(self.start_val, self.end_val, self.start_step, self.end_step, step)

slm_lab/agent/algorithm/ppo.py

-4
@@ -1,13 +1,11 @@
 from copy import deepcopy
-from slm_lab.agent import net
 from slm_lab.agent.algorithm import policy_util
 from slm_lab.agent.algorithm.actor_critic import ActorCritic
 from slm_lab.agent.net import net_util
 from slm_lab.lib import logger, math_util, util
 from slm_lab.lib.decorator import lab_api
 import math
 import numpy as np
-import pydash as ps
 import torch

 logger = logger.get_logger(__name__)
@@ -168,8 +166,6 @@ def calc_policy_loss(self, batch, pdparams, advs):
         return policy_loss

     def train(self):
-        if util.in_eval_lab_modes():
-            return np.nan
         clock = self.body.env.clock
         if self.to_train == 1:
             net_util.copy(self.net, self.old_net)  # update old net

slm_lab/agent/algorithm/random.py

+2-2
@@ -1,7 +1,7 @@
 # The random agent algorithm
 # For basic dev purpose
 from slm_lab.agent.algorithm.base import Algorithm
-from slm_lab.lib import logger, util
+from slm_lab.lib import logger
 from slm_lab.lib.decorator import lab_api
 import numpy as np

@@ -29,7 +29,7 @@ def init_nets(self, global_nets=None):
     def act(self, state):
         '''Random action'''
         body = self.body
-        if body.env.is_venv and not util.in_eval_lab_modes():
+        if body.env.is_venv:
             action = np.array([body.action_space.sample() for _ in range(body.env.num_envs)])
         else:
             action = body.action_space.sample()

slm_lab/agent/algorithm/reinforce.py

+1-3
@@ -87,7 +87,7 @@ def init_nets(self, global_nets=None):
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
         net_util.set_global_nets(self, global_nets)
-        self.post_init_nets()
+        self.end_init_nets()

     @lab_api
     def calc_pdparam(self, x, net=None):
@@ -145,8 +145,6 @@ def calc_policy_loss(self, batch, pdparams, advs):

     @lab_api
     def train(self):
-        if util.in_eval_lab_modes():
-            return np.nan
         clock = self.body.env.clock
         if self.to_train == 1:
             batch = self.sample()

slm_lab/agent/algorithm/sac.py

+1-3
@@ -90,7 +90,7 @@ def init_nets(self, global_nets=None):
         self.alpha_optim = net_util.get_optim(self.log_alpha, self.net.optim_spec)
         self.alpha_lr_scheduler = net_util.get_lr_scheduler(self.alpha_optim, self.net.lr_scheduler_spec)
         net_util.set_global_nets(self, global_nets)
-        self.post_init_nets()
+        self.end_init_nets()

     @lab_api
     def act(self, state):
@@ -187,8 +187,6 @@ def train_alpha(self, alpha_loss):

     def train(self):
         '''Train actor critic by computing the loss in batch efficiently'''
-        if util.in_eval_lab_modes():
-            return np.nan
         clock = self.body.env.clock
         if self.to_train == 1:
             for _ in range(self.training_iter):
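
Note that the per-algorithm `if util.in_eval_lab_modes(): return np.nan` guards removed from the `train()` methods above (actor_critic, base, dqn, ppo, reinforce, sac) appear redundant after this refactor: `Agent.update()` in slm_lab/agent/__init__.py already returns early in eval lab mode before `self.algorithm.train()` is ever called.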
