
Commit c56c617

Add additional logic to avoid load being called on every advance (#4934)
1 parent aeedd0b commit c56c617

File tree

3 files changed: +30 −20 lines


com.unity.ml-agents/CHANGELOG.md (+1)
@@ -71,6 +71,7 @@ and this project adheres to
   while waiting for a connection, and raises a better error message if it crashes. (#4880)
 - Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
   no longer overwritten. (#4880)
+- The `load_weights` function was being called unnecessarily often in the Ghost Trainer leading to training slowdowns. (#4934)
 
 
 ## [1.7.2-preview] - 2020-12-22
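
The changelog entry above describes the root cause: the Ghost Trainer was reloading policy weights on every `advance()` call, even when no new weights had arrived and the learning team had not changed. A minimal sketch of the before/after behaviour (the names `Policy`, `advance_before`, and `advance_after` are illustrative stand-ins, not the ml-agents API):

```python
# Illustrative sketch only; Policy, advance_before and advance_after are
# stand-ins, not part of the ml-agents API.

class Policy:
    def __init__(self) -> None:
        self.load_calls = 0

    def load_weights(self, weights) -> None:
        # Copying weights into a network is comparatively expensive,
        # so we count how often it happens.
        self.load_calls += 1


def advance_before(policy: Policy, snapshot) -> None:
    # Pre-fix behaviour: weights are pushed on every advance,
    # whether or not anything changed.
    policy.load_weights(snapshot)


def advance_after(policy: Policy, snapshot, has_new_weights: bool, team_changed: bool) -> None:
    # Post-fix behaviour: weights are only pushed when fresh weights
    # arrived or the learning team changed.
    if has_new_weights or team_changed:
        policy.load_weights(snapshot)


if __name__ == "__main__":
    before, after = Policy(), Policy()
    for step in range(1000):
        advance_before(before, snapshot=[0.0])
        advance_after(after, snapshot=[0.0], has_new_weights=(step % 100 == 0), team_changed=False)
    print(before.load_calls, after.load_calls)  # 1000 vs. 10
```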

ml-agents/mlagents/trainers/ghost/trainer.py (+28 −12)
@@ -247,25 +247,19 @@ def advance(self) -> None:
 
         next_learning_team = self.controller.get_learning_team
 
-        # CASE 1: Current learning team is managed by this GhostTrainer.
-        # If the learning team changes, the following loop over queues will push the
-        # new policy into the policy queue for the new learning agent if
-        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
-        # CASE 2: Current learning team is managed by a different GhostTrainer.
-        # If the learning team changes to a team managed by this GhostTrainer, this loop
-        # will push the current_snapshot into the correct queue. Otherwise,
-        # it will continue skipping and swap_snapshot will continue to handle
-        # pushing fixed snapshots
-        # Case 3: No team change. The if statement just continues to push the policy
+        # Case 1: No team change. The if statement just continues to push the policy
         # into the correct queue (or not if not learning team).
         for brain_name in self._internal_policy_queues:
             internal_policy_queue = self._internal_policy_queues[brain_name]
             try:
                 policy = internal_policy_queue.get_nowait()
                 self.current_policy_snapshot[brain_name] = policy.get_weights()
             except AgentManagerQueue.Empty:
-                pass
-            if next_learning_team in self._team_to_name_to_policy_queue:
+                continue
+            if (
+                self._learning_team == next_learning_team
+                and next_learning_team in self._team_to_name_to_policy_queue
+            ):
                 name_to_policy_queue = self._team_to_name_to_policy_queue[
                     next_learning_team
                 ]
@@ -277,6 +271,28 @@ def advance(self) -> None:
                     policy.load_weights(self.current_policy_snapshot[brain_name])
                     name_to_policy_queue[brain_name].put(policy)
 
+        # CASE 2: Current learning team is managed by this GhostTrainer.
+        # If the learning team changes, the following loop over queues will push the
+        # new policy into the policy queue for the new learning agent if
+        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
+        # CASE 3: Current learning team is managed by a different GhostTrainer.
+        # If the learning team changes to a team managed by this GhostTrainer, this loop
+        # will push the current_snapshot into the correct queue. Otherwise,
+        # it will continue skipping and swap_snapshot will continue to handle
+        # pushing fixed snapshots
+        if (
+            self._learning_team != next_learning_team
+            and next_learning_team in self._team_to_name_to_policy_queue
+        ):
+            name_to_policy_queue = self._team_to_name_to_policy_queue[
+                next_learning_team
+            ]
+            for brain_name in name_to_policy_queue:
+                behavior_id = create_name_behavior_id(brain_name, next_learning_team)
+                policy = self.get_policy(behavior_id)
+                policy.load_weights(self.current_policy_snapshot[brain_name])
+                name_to_policy_queue[brain_name].put(policy)
+
         # Note save and swap should be on different step counters.
         # We don't want to save unless the policy is learning.
         if self.get_step - self.last_save > self.steps_between_save:
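
Taken together, the restructured `advance()` splits the work into two blocks: the first loop only pushes an updated policy when fresh weights were actually pulled from the internal queue and the learning team is unchanged, while the second block handles a team change by pushing the current snapshot. Below is a self-contained sketch of that control flow, with `queue.Queue` standing in for `AgentManagerQueue` and plain dicts/lists standing in for the trainer's bookkeeping; it is a simplification, not the real GhostTrainer code.

```python
# Simplified sketch of the two-phase control flow; not the actual GhostTrainer.
import queue
from typing import Dict, List


def advance_sketch(
    internal_queues: Dict[str, queue.Queue],   # fresh weights per brain, if any
    snapshot: Dict[str, list],                 # current_policy_snapshot stand-in
    learning_team: int,
    next_learning_team: int,
    managed_teams: List[int],
) -> List[str]:
    pushed: List[str] = []

    # Case 1: no team change. Only push when new weights actually arrived;
    # an empty queue means `continue`, so the load/push is skipped entirely.
    for brain_name, q in internal_queues.items():
        try:
            snapshot[brain_name] = q.get_nowait()
        except queue.Empty:
            continue
        if learning_team == next_learning_team and next_learning_team in managed_teams:
            pushed.append(f"{brain_name}?team={next_learning_team}")

    # Cases 2 and 3: the learning team changed. Push the current snapshot to
    # the new team's queue if this trainer manages it; otherwise do nothing
    # and let snapshot swapping handle the fixed (non-learning) side.
    if learning_team != next_learning_team and next_learning_team in managed_teams:
        for brain_name in snapshot:
            pushed.append(f"{brain_name}?team={next_learning_team}")

    return pushed


if __name__ == "__main__":
    q: queue.Queue = queue.Queue()
    q.put([0.1, 0.2])
    # Fresh weights and no team change: exactly one push.
    print(advance_sketch({"striker": q}, {"striker": []}, 0, 0, [0, 1]))
    # Empty queue and no team change: nothing is pushed or reloaded.
    print(advance_sketch({"striker": q}, {"striker": [0.1, 0.2]}, 0, 0, [0, 1]))
```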

ml-agents/mlagents/trainers/tests/torch/test_ghost.py (+1 −8)
@@ -23,7 +23,7 @@ def dummy_config():
 VECTOR_ACTION_SPACE = 1
 VECTOR_OBS_SPACE = 8
 DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
-BUFFER_INIT_SAMPLES = 513
+BUFFER_INIT_SAMPLES = 10241
 NUM_AGENTS = 12
 
 
@@ -193,13 +193,6 @@ def test_publish_queue(dummy_config):
     # clear
     policy_queue1.get_nowait()
 
-    mock_specs = mb.setup_test_behavior_specs(
-        False,
-        False,
-        vector_action_space=VECTOR_ACTION_SPACE,
-        vector_obs_space=VECTOR_OBS_SPACE,
-    )
-
     buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
     # Mock out reward signal eval
     copy_buffer_fields(
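
The test changes above enlarge the rollout buffer and drop a redundant `mock_specs` setup. For the behavioural guarantee the trainer change is after (no weight reload when nothing new arrived), a hypothetical mock-based check could look like the following; it mirrors the intent of the fix but is not one of the repository's tests.

```python
# Hypothetical check, not part of the ml-agents test suite: with an empty
# internal queue and an unchanged learning team, load_weights must not be
# called again.
import queue
from unittest.mock import MagicMock


def test_no_reload_when_queue_empty_and_team_unchanged():
    policy = MagicMock()
    internal_queue: queue.Queue = queue.Queue()  # empty: no fresh weights this step
    learning_team = next_learning_team = 0

    # Stand-in for the Case 1 branch of advance().
    try:
        weights = internal_queue.get_nowait()
    except queue.Empty:
        weights = None
    if weights is not None and learning_team == next_learning_team:
        policy.load_weights(weights)

    policy.load_weights.assert_not_called()


if __name__ == "__main__":
    test_no_reload_when_queue_empty_and_team_unchanged()
    print("ok")
```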
