@@ -67,7 +67,7 @@ def __init__(
         self.hyperparameters: SACSettings = cast(
             SACSettings, trainer_settings.hyperparameters
         )
-        self.step = 0
+        self._step = 0

         # Don't divide by zero
         self.update_steps = 1
@@ -188,7 +188,7 @@ def _is_ready_update(self) -> bool:
         """
         return (
             self.update_buffer.num_experiences >= self.hyperparameters.batch_size
-            and self.step >= self.hyperparameters.buffer_init_steps
+            and self._step >= self.hyperparameters.buffer_init_steps
         )

     @timed
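For context on the two `while` conditions patched below: SAC gates gradient updates on the ratio of environment steps collected past `buffer_init_steps` to updates already performed, so training hovers around one update per `steps_per_update` environment steps (the `self.update_steps = 1` initialization above, "Don't divide by zero", keeps the division well-defined). A minimal standalone sketch of that bookkeeping, where `catch_up` and `do_one_update` are hypothetical names for illustration:

# Sketch of the steps_per_update gating used by both while-loops in this
# diff (policy updates in the next hunk, reward-signal updates after it).
# `do_one_update` is a hypothetical stand-in for the real minibatch update.
def catch_up(step: int, update_steps: int, buffer_init_steps: int,
             steps_per_update: float, do_one_update) -> int:
    # Update until (env steps collected) / (updates performed) drops back
    # to the configured steps_per_update ratio.
    while (step - buffer_init_steps) / update_steps > steps_per_update:
        do_one_update()
        update_steps += 1
    return update_steps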
@@ -251,9 +251,9 @@ def _update_sac_policy(self) -> bool:

         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step - self.hyperparameters.buffer_init_steps
+            self._step - self.hyperparameters.buffer_init_steps
         ) / self.update_steps > self.steps_per_update:
-            logger.debug(f"Updating SAC policy at step {self.step}")
+            logger.debug(f"Updating SAC policy at step {self._step}")
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
                 sampled_minibatch = buffer.sample_mini_batch(
@@ -305,12 +305,12 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step - self.hyperparameters.buffer_init_steps
+            self._step - self.hyperparameters.buffer_init_steps
         ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name in self.optimizer.reward_signals.keys():
-                logger.debug(f"Updating {name} at step {self.step}")
+                logger.debug(f"Updating {name} at step {self._step}")
                 if name != "extrinsic":
                     reward_signal_minibatches[name] = buffer.sample_mini_batch(
                         self.hyperparameters.batch_size,
@@ -355,11 +355,11 @@ def add_policy(
         self.model_saver.initialize_or_load()

         # Needed to resume loads properly
-        self.step = policy.get_current_step()
+        self._step = policy.get_current_step()
         # Assume steps were updated at the correct ratio before
-        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.update_steps = int(max(1, self._step / self.steps_per_update))
         self.reward_signal_update_steps = int(
-            max(1, self.step / self.reward_signal_steps_per_update)
+            max(1, self._step / self.reward_signal_steps_per_update)
         )

     def get_policy(self, name_behavior_id: str) -> Policy:
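The rename itself follows Python's leading-underscore convention: `_step` marks the counter as trainer-internal state, read from outside through a property rather than written directly. A minimal sketch of that pattern, where the class and property shape are assumptions for illustration rather than the exact ML-Agents base class:

class TrainerSketch:
    # Hypothetical shape illustrating the _step convention; not the
    # exact ML-Agents trainer base class.
    def __init__(self) -> None:
        self._step = 0  # internal: only the trainer advances it

    @property
    def get_step(self) -> int:
        # Read-only view of how many steps the trainer has processed.
        return self._step

    def _increment_step(self, n_steps: int) -> None:
        self._step += n_steps

Note also the resume arithmetic in the `add_policy` hunk above: back-filling `update_steps` as `int(max(1, _step / steps_per_update))` means a run resumed at, say, step 50,000 with `steps_per_update = 10` restarts with `update_steps = 5000`, so the catch-up loops do not fire a burst of stale updates on load.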