File tree Expand file tree Collapse file tree 2 files changed +7
-8
lines changed
ml-agents/mlagents/trainers Expand file tree Collapse file tree 2 files changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -160,13 +160,10 @@ def get_batch(
160
160
)
161
161
if batch_size * training_length > len (self ):
162
162
padding = np .array (self [- 1 ], dtype = np .float32 ) * self .padding_value
163
- return np .array (
164
- [padding ] * (training_length - leftover ) + self [:], dtype = np .float32
165
- )
163
+ return [padding ] * (training_length - leftover ) + self [:]
164
+
166
165
else :
167
- return np .array (
168
- self [len (self ) - batch_size * training_length :], dtype = np .float32
169
- )
166
+ return self [len (self ) - batch_size * training_length :]
170
167
else :
171
168
# The sequences will have overlapping elements
172
169
if batch_size is None :
@@ -182,7 +179,7 @@ def get_batch(
182
179
tmp_list : List [np .ndarray ] = []
183
180
for end in range (len (self ) - batch_size + 1 , len (self ) + 1 ):
184
181
tmp_list += self [end - training_length : end ]
185
- return np . array ( tmp_list , dtype = np . float32 )
182
+ return tmp_list
186
183
187
184
def reset_field (self ) -> None :
188
185
"""
Original file line number Diff line number Diff line change @@ -222,7 +222,9 @@ def _update_policy(self):
222
222
int (self .hyperparameters .batch_size / self .policy .sequence_length ), 1
223
223
)
224
224
225
- advantages = self .update_buffer [BufferKey .ADVANTAGES ].get_batch ()
225
+ advantages = np .array (
226
+ self .update_buffer [BufferKey .ADVANTAGES ].get_batch (), dtype = np .float32
227
+ )
226
228
self .update_buffer [BufferKey .ADVANTAGES ].set (
227
229
(advantages - advantages .mean ()) / (advantages .std () + 1e-10 )
228
230
)
You can’t perform that action at this time.
0 commit comments