Skip to content

Commit e15ed90

Browse files
committed
modify ppo std
1 parent 613186b commit e15ed90

File tree

2 files changed

+60
-20
lines changed

2 files changed

+60
-20
lines changed

ppo_continuous.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
##################### hyper parameters ####################
5757

58-
ENV_NAME = 'HalfCheetah-v2' # environment name HalfCheetah-v2 Pendulum-v0
58+
ENV_NAME = 'Pendulum-v0' # environment name, e.g. HalfCheetah-v2 or Pendulum-v0
5959
RANDOMSEED = 2 # random seed
6060

6161
EP_MAX = 1000 # total number of episodes for training
@@ -74,6 +74,20 @@
7474

7575
############################### PPO ####################################
7676

77+
class AddBias(nn.Module):
    """Add a learnable bias along dimension 1 of the input.

    The bias is stored as an ``nn.Parameter`` of shape ``(C, 1)``, built by
    unsqueezing the tensor passed to the constructor (expected to be 1-D of
    length ``C``).
    """

    def __init__(self, bias):
        """Register ``bias`` as the trainable parameter ``_bias``."""
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        """Return ``x + bias`` with the bias broadcast over all dims but 1.

        For a 2-D input ``(N, C)`` the bias is viewed as ``(1, C)``; for an
        input with more dimensions one singleton axis is appended per extra
        dimension, e.g. ``(1, C, 1, 1)`` for a 4-D input. The previous
        implementation hard-coded the 4-D layout for every non-2-D input, so
        3-D or 5-D inputs were broadcast incorrectly.
        """
        # Inputs with fewer than two dims fall back to the 2-D layout.
        trailing_dims = max(x.dim() - 2, 0)
        bias = self._bias.t().view(1, -1, *([1] * trailing_dims))
        return x + bias
89+
90+
7791
class ValueNetwork(nn.Module):
7892
def __init__(self, state_dim, hidden_dim, init_w=3e-3):
7993
super(ValueNetwork, self).__init__()
@@ -106,12 +120,10 @@ def __init__(self, num_inputs, num_actions, hidden_dim, action_range=1., init_w=
106120
# self.linear4 = nn.Linear(hidden_dim, hidden_dim)
107121

108122
self.mean_linear = nn.Linear(hidden_dim, num_actions)
109-
# self.mean_linear.weight.data.uniform_(-init_w, init_w)
110-
# self.mean_linear.bias.data.uniform_(-init_w, init_w)
111-
112-
self.log_std_linear = nn.Linear(hidden_dim, num_actions)
113-
# self.log_std_linear.weight.data.uniform_(-init_w, init_w)
114-
# self.log_std_linear.bias.data.uniform_(-init_w, init_w)
123+
# implementation 1
124+
# self.log_std_linear = nn.Linear(hidden_dim, num_actions)
125+
# implementation 2: log_std is a learned parameter that does not depend on the latent features; reference: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/distributions.py
126+
self.log_std = AddBias(torch.zeros(num_actions))
115127

116128
self.num_actions = num_actions
117129
self.action_range = action_range
@@ -123,9 +135,17 @@ def forward(self, state):
123135
# x = F.relu(self.linear4(x))
124136

125137
mean = self.action_range * F.tanh(self.mean_linear(x))
126-
log_std = self.log_std_linear(x)
127-
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
128-
138+
139+
# implementation 1
140+
# log_std = self.log_std_linear(x)
141+
# log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
142+
143+
# implementation 2
144+
zeros = torch.zeros(mean.size())
145+
if state.is_cuda:
146+
zeros = zeros.cuda()
147+
log_std = self.log_std(zeros)
148+
129149
return mean, log_std
130150

131151
def get_action(self, state, deterministic=False):

ppo_continuous_multiprocess.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,29 @@
6464
EPS = 1e-8 # numerical residual
6565
MODEL_PATH = 'model/ppo_multi'
6666
NUM_WORKERS=2 # or: mp.cpu_count()
67-
ACTION_RANGE = 2. # if unnormalized, normalized action range should be 1.
67+
ACTION_RANGE = 1. # if unnormalized, normalized action range should be 1.
6868
METHOD = [
6969
dict(name='kl_pen', kl_target=0.01, lam=0.5), # KL penalty
7070
dict(name='clip', epsilon=0.2), # Clipped surrogate objective; found to work better
7171
][0] # choose the method for optimization
7272

7373
############################### PPO ####################################
7474

75+
76+
class AddBias(nn.Module):
    """Add a learnable bias along dimension 1 of the input.

    The bias is stored as an ``nn.Parameter`` of shape ``(C, 1)``, built by
    unsqueezing the tensor passed to the constructor (expected to be 1-D of
    length ``C``).
    """

    def __init__(self, bias):
        """Register ``bias`` as the trainable parameter ``_bias``."""
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        """Return ``x + bias`` with the bias broadcast over all dims but 1.

        For a 2-D input ``(N, C)`` the bias is viewed as ``(1, C)``; for an
        input with more dimensions one singleton axis is appended per extra
        dimension, e.g. ``(1, C, 1, 1)`` for a 4-D input. The previous
        implementation hard-coded the 4-D layout for every non-2-D input, so
        3-D or 5-D inputs were broadcast incorrectly.
        """
        # Inputs with fewer than two dims fall back to the 2-D layout.
        trailing_dims = max(x.dim() - 2, 0)
        bias = self._bias.t().view(1, -1, *([1] * trailing_dims))
        return x + bias
88+
89+
7590
class ValueNetwork(nn.Module):
7691
def __init__(self, state_dim, hidden_dim, init_w=3e-3):
7792
super(ValueNetwork, self).__init__()
@@ -104,12 +119,10 @@ def __init__(self, num_inputs, num_actions, hidden_dim, action_range=1., init_w=
104119
# self.linear4 = nn.Linear(hidden_dim, hidden_dim)
105120

106121
self.mean_linear = nn.Linear(hidden_dim, num_actions)
107-
self.mean_linear.weight.data.uniform_(-init_w, init_w)
108-
self.mean_linear.bias.data.uniform_(-init_w, init_w)
109-
110-
self.log_std_linear = nn.Linear(hidden_dim, num_actions)
111-
self.log_std_linear.weight.data.uniform_(-init_w, init_w)
112-
self.log_std_linear.bias.data.uniform_(-init_w, init_w)
122+
# implementation 1
123+
# self.log_std_linear = nn.Linear(hidden_dim, num_actions)
124+
# implementation 2: log_std is a learned parameter that does not depend on the latent features; reference: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/distributions.py
125+
self.log_std = AddBias(torch.zeros(num_actions))
113126

114127
self.num_actions = num_actions
115128
self.action_range = action_range
@@ -122,8 +135,15 @@ def forward(self, state):
122135
# x = F.relu(self.linear4(x))
123136

124137
mean = self.action_range * F.tanh(self.mean_linear(x))
125-
log_std = self.log_std_linear(x)
126-
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
138+
# implementation 1
139+
# log_std = self.log_std_linear(x)
140+
# log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
141+
142+
# implementation 2
143+
zeros = torch.zeros(mean.size())
144+
if state.is_cuda:
145+
zeros = zeros.cuda()
146+
log_std = self.log_std(zeros)
127147

128148
return mean, log_std
129149

@@ -396,7 +416,7 @@ def main():
396416
np.random.seed(RANDOMSEED)
397417
torch.manual_seed(RANDOMSEED)
398418

399-
env = gym.make(ENV_NAME).unwrapped
419+
env = NormalizedActions(gym.make(ENV_NAME).unwrapped)
400420
state_dim = env.observation_space.shape[0]
401421
action_dim = env.action_space.shape[0]
402422

0 commit comments

Comments
 (0)