Skip to content

Commit 9d0422e

Browse files
committed
modify lstm policy
1 parent 58b416e commit 9d0422e

8 files changed

+47
-41
lines changed

MUJOCO_LOG.TXT

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Wed Jan 8 11:15:37 2020
2+
ERROR: Expired activation key
3+
4+
Wed Jan 8 11:16:43 2020
5+
ERROR: Expired activation key
6+
7+
Wed Jan 8 11:20:15 2020
8+
ERROR: Expired activation key
9+

POMDP/sac_v2_lstm.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,16 @@ def plot(rewards):
242242
next_state, reward, done, _ = env.step(action)
243243
# env.render()
244244

245-
if step>0:
245+
if step == 0:
246246
ini_hidden_in = hidden_in
247247
ini_hidden_out = hidden_out
248-
episode_state.append(state)
249-
episode_action.append(action)
250-
episode_last_action.append(last_action)
251-
episode_reward.append(reward)
252-
episode_next_state.append(next_state)
253-
episode_done.append(done)
254-
248+
episode_state.append(state)
249+
episode_action.append(action)
250+
episode_last_action.append(last_action)
251+
episode_reward.append(reward)
252+
episode_next_state.append(next_state)
253+
episode_done.append(done)
254+
255255
state = next_state
256256
last_action = action
257257
frame_idx += 1

POMDP/td3_lstm.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -239,15 +239,15 @@ def plot(rewards):
239239
next_state, reward, done, _ = env.step(action)
240240
# env.render()
241241

242-
if step>0:
242+
if step == 0:
243243
ini_hidden_in = hidden_in
244244
ini_hidden_out = hidden_out
245-
episode_state.append(state)
246-
episode_action.append(action)
247-
episode_last_action.append(last_action)
248-
episode_reward.append(reward)
249-
episode_next_state.append(next_state)
250-
episode_done.append(done)
245+
episode_state.append(state)
246+
episode_action.append(action)
247+
episode_last_action.append(last_action)
248+
episode_reward.append(reward)
249+
episode_next_state.append(next_state)
250+
episode_done.append(done)
251251

252252
state = next_state
253253
last_action = action

ppo_continuous.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
##################### hyper parameters ####################
5757

58-
ENV_NAME = 'Pendulum-v0' # environment name HalfCheetah-v2 Pendulum-v0
58+
ENV_NAME = 'HalfCheetah-v2' # environment name HalfCheetah-v2 Pendulum-v0
5959
RANDOMSEED = 2 # random seed
6060

6161
EP_MAX = 1000 # total number of episodes for training

ppo_continuous_multiprocess.py

-2
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,6 @@ def main():
455455

456456
ppo.save_model(MODEL_PATH)
457457

458-
459-
460458
if args.test:
461459
ppo.load_model(MODEL_PATH)
462460
while True:

sac_v2_gru.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -262,15 +262,15 @@ def plot(rewards):
262262
next_state, reward, done, _ = env.step(action)
263263
env.render()
264264

265-
if step>0:
265+
if step == 0:
266266
ini_hidden_in = hidden_in
267267
ini_hidden_out = hidden_out
268-
episode_state.append(state)
269-
episode_action.append(action)
270-
episode_last_action.append(last_action)
271-
episode_reward.append(reward)
272-
episode_next_state.append(next_state)
273-
episode_done.append(done)
268+
episode_state.append(state)
269+
episode_action.append(action)
270+
episode_last_action.append(last_action)
271+
episode_reward.append(reward)
272+
episode_next_state.append(next_state)
273+
episode_done.append(done)
274274

275275
state = next_state
276276
last_action = action

sac_v2_lstm.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -259,17 +259,17 @@ def plot(rewards):
259259
next_state, reward, done, _ = env.step(action, SPARSE_REWARD, SCREEN_SHOT)
260260
else:
261261
next_state, reward, done, _ = env.step(action)
262-
env.render()
262+
# env.render()
263263

264-
if step>0:
264+
if step == 0:
265265
ini_hidden_in = hidden_in
266266
ini_hidden_out = hidden_out
267-
episode_state.append(state)
268-
episode_action.append(action)
269-
episode_last_action.append(last_action)
270-
episode_reward.append(reward)
271-
episode_next_state.append(next_state)
272-
episode_done.append(done)
267+
episode_state.append(state)
268+
episode_action.append(action)
269+
episode_last_action.append(last_action)
270+
episode_reward.append(reward)
271+
episode_next_state.append(next_state)
272+
episode_done.append(done)
273273

274274
state = next_state
275275
last_action = action

td3_lstm.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,15 @@ def plot(rewards):
248248
else:
249249
next_state, reward, done, _ = env.step(action)
250250
# env.render()
251-
252-
if step>0:
251+
if step == 0:
253252
ini_hidden_in = hidden_in
254253
ini_hidden_out = hidden_out
255-
episode_state.append(state)
256-
episode_action.append(action)
257-
episode_last_action.append(last_action)
258-
episode_reward.append(reward)
259-
episode_next_state.append(next_state)
260-
episode_done.append(done)
254+
episode_state.append(state)
255+
episode_action.append(action)
256+
episode_last_action.append(last_action)
257+
episode_reward.append(reward)
258+
episode_next_state.append(next_state)
259+
episode_done.append(done)
261260

262261
state = next_state
263262
last_action = action

0 commit comments

Comments (0)