Skip to content

Commit cacfc8c

Browse files
maryamhonariErvin T.
and
Ervin T.
authored
collecting latest step as a stat (#5264) (#5295)
* collecting latest step as a stat * adding a list of hidden_keys to TB summarywriter to hide unnecessary stats from user * fixing precommit * formating * defined the property types * moving custom defaults to get_default_stats_writers * new test for TensorboardWriter.hidden_keys * improved testing * explicit None evaluation Co-authored-by: Ervin T. <[email protected]> * make hidden_keys optional Co-authored-by: Ervin T. <[email protected]> * adding optional argument * lowering the training threshold to 0.8 on test_var_len_obs_and_goal_poca * Update pytest.yml * Do not merge! droping pytest 3.9 job * -add back pytest -format imports and comments * back to default threshold for test_var_len_obs_and_goal_poca Co-authored-by: mahon94 <[email protected]> Co-authored-by: Ervin T. <[email protected]> Co-authored-by: mahon94 <[email protected]> Co-authored-by: Ervin T. <[email protected]>
1 parent 2d39303 commit cacfc8c

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed

ml-agents/mlagents/plugins/stats_writer.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
3131
TensorboardWriter(
3232
checkpoint_settings.write_path,
3333
clear_past_data=not checkpoint_settings.resume,
34+
hidden_keys=["Is Training", "Step"],
3435
),
3536
GaugeWriter(),
3637
ConsoleWriter(),

ml-agents/mlagents/trainers/stats.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import defaultdict
22
from enum import Enum
3-
from typing import List, Dict, NamedTuple, Any
3+
from typing import List, Dict, NamedTuple, Any, Optional
44
import numpy as np
55
import abc
66
import os
@@ -14,7 +14,6 @@
1414
from torch.utils.tensorboard import SummaryWriter
1515
from mlagents.torch_utils.globals import get_rank
1616

17-
1817
logger = get_logger(__name__)
1918

2019

@@ -212,24 +211,34 @@ def add_property(
212211

213212

214213
class TensorboardWriter(StatsWriter):
215-
def __init__(self, base_dir: str, clear_past_data: bool = False):
214+
def __init__(
215+
self,
216+
base_dir: str,
217+
clear_past_data: bool = False,
218+
hidden_keys: Optional[List[str]] = None,
219+
):
216220
"""
217221
A StatsWriter that writes to a Tensorboard summary.
218222
219223
:param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a
220224
{base_dir}/{category} directory.
221225
:param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and
222226
category.
227+
:param hidden_keys: If provided, Tensorboard Writer won't write statistics identified with these Keys in
228+
Tensorboard summary.
223229
"""
224230
self.summary_writers: Dict[str, SummaryWriter] = {}
225231
self.base_dir: str = base_dir
226232
self._clear_past_data = clear_past_data
233+
self.hidden_keys: List[str] = hidden_keys if hidden_keys is not None else []
227234

228235
def write_stats(
229236
self, category: str, values: Dict[str, StatsSummary], step: int
230237
) -> None:
231238
self._maybe_create_summary_writer(category)
232239
for key, value in values.items():
240+
if key in self.hidden_keys:
241+
continue
233242
self.summary_writers[category].add_scalar(
234243
f"{key}", value.aggregated_value, step
235244
)

ml-agents/mlagents/trainers/tests/test_stats.py

+25
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,31 @@ def test_tensorboard_writer_clear(tmp_path):
129129
assert len(os.listdir(os.path.join(tmp_path, "category1"))) == 1
130130

131131

132+
@mock.patch("mlagents.trainers.stats.SummaryWriter")
133+
def test_tensorboard_writer_hidden_keys(mock_summary):
134+
# Test write_stats
135+
category = "category1"
136+
with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
137+
tb_writer = TensorboardWriter(
138+
base_dir, clear_past_data=False, hidden_keys="hiddenKey"
139+
)
140+
statssummary1 = StatsSummary(
141+
full_dist=[1.0], aggregation_method=StatsAggregationMethod.AVERAGE
142+
)
143+
tb_writer.write_stats("category1", {"hiddenKey": statssummary1}, 10)
144+
145+
# Test that the filewriter has been created and the directory has been created.
146+
filewriter_dir = "{basedir}/{category}".format(
147+
basedir=base_dir, category=category
148+
)
149+
assert os.path.exists(filewriter_dir)
150+
mock_summary.assert_called_once_with(filewriter_dir)
151+
152+
# Test that the filewriter was not written to since we used the hidden key.
153+
mock_summary.return_value.add_scalar.assert_not_called()
154+
mock_summary.return_value.flush.assert_not_called()
155+
156+
132157
def test_gauge_stat_writer_sanitize():
133158
assert GaugeWriter.sanitize_string("Policy/Learning Rate") == "Policy.LearningRate"
134159
assert (

ml-agents/mlagents/trainers/trainer/rl_trainer.py

+1
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
211211
p = self.get_policy(name_behavior_id)
212212
if p:
213213
p.increment_step(n_steps)
214+
self.stats_reporter.set_stat("Step", float(self.get_step))
214215

215216
def _get_next_interval_step(self, interval: int) -> int:
216217
"""

0 commit comments

Comments
 (0)