
Commit d697fcc

andrewcoh authored and GitHub Enterprise committed
Add shared critic configurability for PPO (#45)
1 parent d6666d5 commit d697fcc

6 files changed: +11 -4 lines changed


docs/Training-Configuration-File.md (+1)

@@ -63,6 +63,7 @@ the `trainer` setting above).
 | `hyperparameters -> epsilon_schedule` | (default = `learning_rate_schedule`) Determines how epsilon changes over time (PPO only). <br><br>`linear` decays epsilon linearly, reaching 0 at max_steps, while `constant` keeps the epsilon constant for the entire training run. If not explicitly set, the default epsilon schedule will be set to `hyperparameters -> learning_rate_schedule`. |
 | `hyperparameters -> lambd` | (default = `0.95`) Regularization parameter (lambda) used when calculating the Generalized Advantage Estimate ([GAE](https://arxiv.org/abs/1506.02438)). This can be thought of as how much the agent relies on its current value estimate when calculating an updated value estimate. Low values correspond to relying more on the current value estimate (which can be high bias), and high values correspond to relying more on the actual rewards received in the environment (which can be high variance). The parameter provides a trade-off between the two, and the right value can lead to a more stable training process. <br><br>Typical range: `0.9` - `0.95` |
 | `hyperparameters -> num_epoch` | (default = `3`) Number of passes to make through the experience buffer when performing gradient descent optimization. The larger the batch_size, the larger it is acceptable to make this. Decreasing this will ensure more stable updates, at the cost of slower learning. <br><br>Typical range: `3` - `10` |
+| `hyperparameters -> shared_critic` | (default = `False`) Whether or not the policy and value function networks share a backbone. It may be useful to use a shared backbone when learning from image observations. |
 
 ### SAC-specific Configurations
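
For intuition on what "share a backbone" means here, the sketch below is a minimal PyTorch module in which one encoder feeds both a policy head and a value head. It is an illustrative stand-in with hypothetical names, not the actual ml-agents network classes.

import torch
from torch import nn

class SharedBackboneActorCritic(nn.Module):
    # Hypothetical illustration: a single encoder ("backbone") is reused by
    # both the policy head and the value head, so the two heads train on the
    # same learned features (the shared_critic=True case).
    def __init__(self, obs_size: int, num_actions: int, hidden: int = 128):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU())
        self.policy_head = nn.Linear(hidden, num_actions)  # action logits
        self.value_head = nn.Linear(hidden, 1)  # state-value estimate

    def forward(self, obs: torch.Tensor):
        features = self.encoder(obs)
        return self.policy_head(features), self.value_head(features)

# With shared_critic=False (the default), the value estimate would instead come
# from a separate network with its own encoder, sharing no parameters.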

docs/Training-ML-Agents.md (+1)

@@ -284,6 +284,7 @@ behaviors:
       epsilon_schedule: linear
       lambd: 0.95
       num_epoch: 3
+      shared_critic: False
 
     # Configuration of the neural network (common to PPO/SAC)
     network_settings:
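
As a quick sanity check of where the new key lives in a run configuration, the hedged sketch below parses a YAML fragment like the one above and reads the flag back out. The behavior name and the use of PyYAML are assumptions for illustration only; mlagents-learn performs its own configuration parsing.

import yaml  # assumes PyYAML is available

CONFIG_TEXT = """
behaviors:
  BehaviorPPO:                  # hypothetical behavior name
    trainer_type: ppo
    hyperparameters:
      num_epoch: 3
      shared_critic: False      # new PPO option; False is also the default
"""

hyperparameters = yaml.safe_load(CONFIG_TEXT)["behaviors"]["BehaviorPPO"]["hyperparameters"]
print(hyperparameters["shared_critic"])  # -> False; set to True to share a backbone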

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (+2 -1)

@@ -29,6 +29,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
         reward_signal_configs = trainer_settings.reward_signals
         reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
 
+        params = list(self.policy.actor.parameters())
         if policy.shared_critic:
             self._critic = policy.actor
         else:
@@ -38,8 +39,8 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
                 network_settings=trainer_settings.network_settings,
             )
             self._critic.to(default_device())
+            params += list(self._critic.parameters())
 
-        params = list(self.policy.actor.parameters()) + list(self._critic.parameters())
         self.hyperparameters: PPOSettings = cast(
             PPOSettings, trainer_settings.hyperparameters
         )
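
The substance of this change is that when the critic is the actor itself, its parameters must not be handed to the optimizer a second time. Below is a simplified, hypothetical sketch of that collection logic using plain torch modules in place of the real TorchPolicy and ValueNetwork classes.

import torch
from torch import nn

def collect_trainable_params(actor: nn.Module, shared_critic: bool):
    # Mirrors the diff above: always start from the actor's parameters, and
    # only append critic parameters when the critic is a separate network.
    params = list(actor.parameters())
    if shared_critic:
        critic = actor  # same module; adds no new parameters
    else:
        critic = nn.Linear(4, 1)  # stand-in for a separate ValueNetwork(...)
        params += list(critic.parameters())
    return critic, params

actor = nn.Linear(4, 2)
_, shared_params = collect_trainable_params(actor, shared_critic=True)
_, separate_params = collect_trainable_params(actor, shared_critic=False)
assert len(shared_params) < len(separate_params)  # shared critic adds nothing
optimizer = torch.optim.Adam(separate_params, lr=3.0e-4)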

ml-agents/mlagents/trainers/ppo/trainer.py (+1 -1)

@@ -229,7 +229,7 @@ def create_torch_policy(
             behavior_spec,
             self.trainer_settings,
             condition_sigma_on_obs=False,  # Faster training for PPO
-            separate_critic=True,  # Match network architecture with TF
+            separate_critic=not self.hyperparameters.shared_critic,  # Only PPO currently allows shared critic
         )
         return policy
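
On the trainer side the change is a single boolean inversion: the policy is built with a separate critic unless shared_critic was requested. A minimal sketch with a hypothetical helper name:

def policy_architecture_kwargs(shared_critic: bool) -> dict:
    # Hypothetical helper: shared_critic=True means the policy should NOT
    # build its own separate value network.
    return {
        "condition_sigma_on_obs": False,  # kept False for faster PPO training
        "separate_critic": not shared_critic,
    }

assert policy_architecture_kwargs(shared_critic=True)["separate_critic"] is False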

ml-agents/mlagents/trainers/settings.py (+1)

@@ -180,6 +180,7 @@ class PPOSettings(HyperparamSettings):
     epsilon: float = 0.2
     lambd: float = 0.95
     num_epoch: int = 3
+    shared_critic: bool = False
     learning_rate_schedule: ScheduleType = ScheduleType.LINEAR
     beta_schedule: ScheduleType = ScheduleType.LINEAR
     epsilon_schedule: ScheduleType = ScheduleType.LINEAR
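
Since PPOSettings is an attrs class, the new field behaves like the other hyperparameters: it defaults to False and can be overridden per run. The standalone sketch below uses a simplified stand-in class, not the real PPOSettings.

import attr

@attr.s(auto_attribs=True)
class PPOSettingsSketch:
    # Simplified stand-in for mlagents.trainers.settings.PPOSettings
    epsilon: float = 0.2
    lambd: float = 0.95
    num_epoch: int = 3
    shared_critic: bool = False  # the field added by this commit

defaults = PPOSettingsSketch()
tuned = attr.evolve(defaults, shared_critic=True)
print(defaults.shared_critic, tuned.shared_critic)  # -> False True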

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (+5 -2)

@@ -151,7 +151,8 @@ def test_2d_ppo(action_sizes):
 
 @pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
 @pytest.mark.parametrize("num_visual", [1, 2])
-def test_visual_ppo(num_visual, action_sizes):
+@pytest.mark.parametrize("shared_critic", [True, False])
+def test_visual_ppo(shared_critic, num_visual, action_sizes):
     env = SimpleEnvironment(
         [BRAIN_NAME],
         action_sizes=action_sizes,
@@ -160,7 +161,9 @@ def test_visual_ppo(num_visual, action_sizes):
         step_size=0.2,
     )
     new_hyperparams = attr.evolve(
-        PPO_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
+        PPO_TORCH_CONFIG.hyperparameters,
+        learning_rate=3.0e-4,
+        shared_critic=shared_critic,
     )
     config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams)
     check_environment_trains(env, {BRAIN_NAME: config})
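
Stacking a third pytest.mark.parametrize decorator, as the test above now does, multiplies the test matrix: every combination of action_sizes, num_visual, and shared_critic gets its own test case. Here is a minimal, self-contained example of the same stacking pattern (a hypothetical test, not part of the ml-agents suite):

import pytest

@pytest.mark.parametrize("num_visual", [1, 2])
@pytest.mark.parametrize("shared_critic", [True, False])
def test_parametrize_stacking(shared_critic, num_visual):
    # The stacked decorators generate 2 x 2 = 4 independent test cases.
    assert isinstance(shared_critic, bool)
    assert num_visual in (1, 2)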
