Skip to content

Commit 11e188a

Browse files
authored
Hparams: Generate metric values for data provider-based session groups. (#6543)
Generate metric values for hparams plugin `/session_groups` requests when the session groups are generated from DataProvider.read_hyperparameters(). We need to reuse the logic introduced in #6539 to generate metric_infos for each session group and also query for scalar values. We reuse the existing logic to join the two collections of data into metric_values for the `/session_groups` request. We also continue the work begun in #6541 to improve how we generate sessions - in this case also handling cases where experiment_id is not specified for the session. This became urgently necessary to address in order to get new tests in list_session_groups_test.py to work with existing test data.
1 parent d4df603 commit 11e188a

File tree

4 files changed

+324
-18
lines changed

4 files changed

+324
-18
lines changed

tensorboard/plugins/hparams/backend_context.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def compute_metric_infos_from_data_provider_session_groups(
403403
self, ctx, experiment_id, session_groups
404404
):
405405
session_runs = set(
406-
f"{s.experiment_id}/{s.run}" if s.run else s.experiment_id
406+
generate_data_provider_session_name(experiment_id, s)
407407
for sg in session_groups
408408
for s in sg.sessions
409409
)
@@ -460,6 +460,22 @@ def _compute_metric_names(self, ctx, experiment_id, session_runs):
460460
return metric_names_list
461461

462462

463+
def generate_data_provider_session_name(experiment_id, session):
    """Builds the name for a HyperparameterSessionRun.

    The name is assembled from whichever of the session's `experiment_id`
    and `run` fields are non-empty:

      * both set: "<session.experiment_id>/<session.run>"
      * only one set: that field on its own
      * neither set: the `experiment_id` argument

    Args:
      experiment_id: Fallback name, used only when the session carries
        neither an experiment id nor a run.
      session: Object exposing `experiment_id` and `run` attributes.

    Returns:
      The generated session name.
    """
    session_experiment = session.experiment_id
    session_run = session.run
    if session_experiment and session_run:
        return f"{session_experiment}/{session_run}"
    # At most one of the two fields is populated here; `or` selects it and
    # falls back to the caller-supplied experiment id when both are empty.
    return session_run or session_experiment or experiment_id
477+
478+
463479
def _find_longest_parent_path(path_set, path):
464480
"""Finds the longest "parent-path" of 'path' in 'path_set'.
465481

tensorboard/plugins/hparams/backend_context_test.py

+37
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,43 @@ def test_experiment_from_data_provider_session_group_without_run_name(self):
652652
"""
653653
self.assertProtoEquals(expected_exp, actual_exp)
654654

655+
def test_experiment_from_data_provider_session_group_without_experiment_name(
    self,
):
    """Checks metric_infos when sessions carry a run but no experiment id."""
    # Drop any side_effect set on the mocked list_tensors call.
    self._mock_tb_context.data_provider.list_tensors.side_effect = None

    def run_only_session():
        # A session identified solely by its run; experiment_id is empty.
        return provider.HyperparameterSessionRun(
            experiment_id="", run="exp/session_1"
        )

    session_group = provider.HyperparameterSessionGroup(
        root=run_only_session(),
        sessions=[run_only_session()],
        hyperparameter_values=[],
    )
    self._hyperparameters = provider.ListHyperparametersResult(
        hyperparameters=[],
        session_groups=[session_group],
    )

    actual_exp = self._experiment_from_metadata()
    expected_exp = """
        metric_infos: {
          name: {group: '', tag: 'accuracy'}
        }
        metric_infos: {
          name: {group: '', tag: 'loss'}
        }
        metric_infos: {
          name: {group: 'eval', tag: 'loss'}
        }
        metric_infos: {
          name: {group: 'train', tag: 'loss'}
        }
    """
    self.assertProtoEquals(expected_exp, actual_exp)
691+
655692
def test_experiment_from_data_provider_old_response_type(self):
656693
self._hyperparameters = [
657694
provider.Hyperparameter(

tensorboard/plugins/hparams/list_session_groups.py

+53-17
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525

2626
from tensorboard.data import provider
2727
from tensorboard.plugins.hparams import api_pb2
28+
from tensorboard.plugins.hparams import backend_context as backend_context_lib
2829
from tensorboard.plugins.hparams import error
2930
from tensorboard.plugins.hparams import json_format_compat
3031
from tensorboard.plugins.hparams import metadata
3132
from tensorboard.plugins.hparams import metrics
33+
from tensorboard.plugins.hparams import plugin_data_pb2
3234

3335

3436
class Handler:
@@ -93,13 +95,15 @@ def _session_groups_from_tags(self):
9395
hparams_run_to_tag_to_content,
9496
# Don't pass any information from the DataProvider since we are only
9597
# examining session groups based on tag metadata
96-
[],
98+
provider.ListHyperparametersResult(
99+
hyperparameters=[], session_groups=[]
100+
),
97101
)
98102
extractors = _create_extractors(self._request.col_params)
99103
filters = _create_filters(self._request.col_params, extractors)
100104

101105
session_groups = self._build_session_groups(
102-
hparams_run_to_tag_to_content, experiment
106+
hparams_run_to_tag_to_content, experiment.metric_infos
103107
)
104108
session_groups = self._filter(session_groups, filters)
105109
self._sort(session_groups, extractors)
@@ -116,16 +120,37 @@ def _session_groups_from_data_provider(self):
116120
sort,
117121
)
118122

123+
metric_infos = self._backend_context.compute_metric_infos_from_data_provider_session_groups(
124+
self._request_context, self._experiment_id, response
125+
)
126+
127+
all_metric_evals = self._backend_context.read_last_scalars(
128+
self._request_context,
129+
self._experiment_id,
130+
run_tag_filter=None,
131+
)
132+
119133
session_groups = []
120134
for provider_group in response:
121-
sessions = [
122-
api_pb2.Session(name=f"{s.experiment_id}/{s.run}")
123-
for s in provider_group.sessions
124-
]
125-
name = (
126-
f"{provider_group.root.experiment_id}/{provider_group.root.run}"
127-
if provider_group.root.run
128-
else provider_group.root.experiment_id
135+
sessions = []
136+
for session in provider_group.sessions:
137+
session_name = (
138+
backend_context_lib.generate_data_provider_session_name(
139+
self._experiment_id, session
140+
)
141+
)
142+
sessions.append(
143+
self._build_session(
144+
metric_infos,
145+
session_name,
146+
plugin_data_pb2.SessionStartInfo(),
147+
plugin_data_pb2.SessionEndInfo(),
148+
all_metric_evals,
149+
)
150+
)
151+
152+
name = backend_context_lib.generate_data_provider_session_name(
153+
self._experiment_id, provider_group.root
129154
)
130155
session_group = api_pb2.SessionGroup(
131156
name=name,
@@ -154,9 +179,16 @@ def _session_groups_from_data_provider(self):
154179

155180
session_groups.append(session_group)
156181

182+
# Compute the session group's aggregated metrics for each group.
183+
for group in session_groups:
184+
if group.sessions:
185+
self._aggregate_metrics(group)
186+
157187
return session_groups
158188

159-
def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
189+
def _build_session_groups(
190+
self, hparams_run_to_tag_to_content, metric_infos
191+
):
160192
"""Returns a list of SessionGroups protobuffers from the summary
161193
data."""
162194

@@ -178,7 +210,7 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
178210
metric_runs = set()
179211
metric_tags = set()
180212
for session_name in session_names:
181-
for metric in experiment.metric_infos:
213+
for metric in metric_infos:
182214
metric_name = metric.name
183215
(run, tag) = metrics.run_tag_from_session_and_metric(
184216
session_name, metric_name
@@ -207,7 +239,11 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
207239
tag_to_content[metadata.SESSION_END_INFO_TAG]
208240
)
209241
session = self._build_session(
210-
experiment, session_name, start_info, end_info, all_metric_evals
242+
metric_infos,
243+
session_name,
244+
start_info,
245+
end_info,
246+
all_metric_evals,
211247
)
212248
if session.status in self._request.allowed_statuses:
213249
self._add_session(session, start_info, groups_by_name)
@@ -263,7 +299,7 @@ def _add_session(self, session, start_info, groups_by_name):
263299
groups_by_name[group_name] = group
264300

265301
def _build_session(
266-
self, experiment, name, start_info, end_info, all_metric_evals
302+
self, metric_infos, name, start_info, end_info, all_metric_evals
267303
):
268304
"""Builds a session object."""
269305

@@ -273,7 +309,7 @@ def _build_session(
273309
start_time_secs=start_info.start_time_secs,
274310
model_uri=start_info.model_uri,
275311
metric_values=self._build_session_metric_values(
276-
experiment, name, all_metric_evals
312+
metric_infos, name, all_metric_evals
277313
),
278314
monitor_url=start_info.monitor_url,
279315
)
@@ -283,13 +319,13 @@ def _build_session(
283319
return result
284320

285321
def _build_session_metric_values(
286-
self, experiment, session_name, all_metric_evals
322+
self, metric_infos, session_name, all_metric_evals
287323
):
288324
"""Builds the session metric values."""
289325

290326
# result is a list of api_pb2.MetricValue instances.
291327
result = []
292-
for metric_info in experiment.metric_infos:
328+
for metric_info in metric_infos:
293329
metric_name = metric_info.name
294330
(run, tag) = metrics.run_tag_from_session_and_metric(
295331
session_name, metric_name

0 commit comments

Comments
 (0)