Skip to content

Commit 2c03d2b

Browse files
author
Ervin Teng
committed
Buffer fixes
1 parent fce4ad3 commit 2c03d2b

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

ml-agents/mlagents/trainers/buffer.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,10 @@ def get_batch(
160160
)
161161
if batch_size * training_length > len(self):
162162
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
163-
return np.array(
164-
[padding] * (training_length - leftover) + self[:], dtype=np.float32
165-
)
163+
return [padding] * (training_length - leftover) + self[:]
164+
166165
else:
167-
return np.array(
168-
self[len(self) - batch_size * training_length :], dtype=np.float32
169-
)
166+
return self[len(self) - batch_size * training_length :]
170167
else:
171168
# The sequences will have overlapping elements
172169
if batch_size is None:
@@ -182,7 +179,7 @@ def get_batch(
182179
tmp_list: List[np.ndarray] = []
183180
for end in range(len(self) - batch_size + 1, len(self) + 1):
184181
tmp_list += self[end - training_length : end]
185-
return np.array(tmp_list, dtype=np.float32)
182+
return tmp_list
186183

187184
def reset_field(self) -> None:
188185
"""

ml-agents/mlagents/trainers/coma/trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def _update_policy(self):
222222
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
223223
)
224224

225-
advantages = self.update_buffer[BufferKey.ADVANTAGES].get_batch()
225+
advantages = np.array(
226+
self.update_buffer[BufferKey.ADVANTAGES].get_batch(), dtype=np.float32
227+
)
226228
self.update_buffer[BufferKey.ADVANTAGES].set(
227229
(advantages - advantages.mean()) / (advantages.std() + 1e-10)
228230
)

0 commit comments

Comments (0)