@@ -247,25 +247,19 @@ def advance(self) -> None:
 
         next_learning_team = self.controller.get_learning_team
 
-        # CASE 1: Current learning team is managed by this GhostTrainer.
-        # If the learning team changes, the following loop over queues will push the
-        # new policy into the policy queue for the new learning agent if
-        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
-        # CASE 2: Current learning team is managed by a different GhostTrainer.
-        # If the learning team changes to a team managed by this GhostTrainer, this loop
-        # will push the current_snapshot into the correct queue. Otherwise,
-        # it will continue skipping and swap_snapshot will continue to handle
-        # pushing fixed snapshots
-        # Case 3: No team change. The if statement just continues to push the policy
+        # Case 1: No team change. The if statement just continues to push the policy
         # into the correct queue (or not if not learning team).
         for brain_name in self._internal_policy_queues:
             internal_policy_queue = self._internal_policy_queues[brain_name]
             try:
                 policy = internal_policy_queue.get_nowait()
                 self.current_policy_snapshot[brain_name] = policy.get_weights()
             except AgentManagerQueue.Empty:
-                pass
-            if next_learning_team in self._team_to_name_to_policy_queue:
+                continue
+            if (
+                self._learning_team == next_learning_team
+                and next_learning_team in self._team_to_name_to_policy_queue
+            ):
                 name_to_policy_queue = self._team_to_name_to_policy_queue[
                     next_learning_team
                 ]
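
The switch from pass to continue is the subtle part of this hunk: with pass, an empty internal queue still fell through to the push below and re-sent whatever was already in current_policy_snapshot, while continue skips the rest of the loop body so weights are only pushed when a freshly trained policy was actually dequeued. A minimal standalone sketch of that pattern, using the standard-library queue module and made-up names in place of AgentManagerQueue and the trainer's internals:

    import queue

    def forward_fresh_policies(internal_queues, external_queues, snapshot):
        # Push weights downstream only when the trainer actually produced a new policy.
        for brain_name, internal_q in internal_queues.items():
            try:
                weights = internal_q.get_nowait()  # newest trained weights, if any
                snapshot[brain_name] = weights     # remember them as the current snapshot
            except queue.Empty:
                continue                           # nothing new: skip the push entirely
            if brain_name in external_queues:      # only behaviors this trainer manages
                external_queues[brain_name].put(snapshot[brain_name])

    internal = {"striker": queue.Queue()}
    external = {"striker": queue.Queue()}
    snapshot = {}
    internal["striker"].put({"w": 1.0})
    forward_fresh_policies(internal, external, snapshot)  # pushes the new weights once
    forward_fresh_policies(internal, external, snapshot)  # empty queue: nothing is pushed
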
@@ -277,6 +271,28 @@ def advance(self) -> None:
                     policy.load_weights(self.current_policy_snapshot[brain_name])
                     name_to_policy_queue[brain_name].put(policy)
 
+        # CASE 2: Current learning team is managed by this GhostTrainer.
+        # If the learning team changes, the following loop over queues will push the
+        # new policy into the policy queue for the new learning agent if
+        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
+        # CASE 3: Current learning team is managed by a different GhostTrainer.
+        # If the learning team changes to a team managed by this GhostTrainer, this loop
+        # will push the current_snapshot into the correct queue. Otherwise,
+        # it will continue skipping and swap_snapshot will continue to handle
+        # pushing fixed snapshots
+        if (
+            self._learning_team != next_learning_team
+            and next_learning_team in self._team_to_name_to_policy_queue
+        ):
+            name_to_policy_queue = self._team_to_name_to_policy_queue[
+                next_learning_team
+            ]
+            for brain_name in name_to_policy_queue:
+                behavior_id = create_name_behavior_id(brain_name, next_learning_team)
+                policy = self.get_policy(behavior_id)
+                policy.load_weights(self.current_policy_snapshot[brain_name])
+                name_to_policy_queue[brain_name].put(policy)
+
         # Note save and swap should be on different step counters.
         # We don't want to save unless the policy is learning.
         if self.get_step - self.last_save > self.steps_between_save:
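
Read together, the two guards split advance() into the three cases the comments describe: same learning team, team handed to a team this trainer manages, or team handed to a team managed by another GhostTrainer. A rough summary as a hypothetical helper (illustrative names only, not part of the ML-Agents API):

    def describe_queue_action(current_team: int, next_team: int, managed_teams: set) -> str:
        # Case 1: no team change - the first loop keeps forwarding freshly
        # trained policies for managed behaviors.
        if current_team == next_team:
            return "push updated policy when one is dequeued" if next_team in managed_teams else "no push"
        # Cases 2/3: the learning team changed. Push the current snapshot only if
        # this trainer manages the incoming team; otherwise swap_snapshot keeps
        # serving fixed opponent snapshots.
        if next_team in managed_teams:
            return "push current snapshot to the new learning team"
        return "skip; swap_snapshot keeps pushing fixed snapshots"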