25
25
26
26
from tensorboard .data import provider
27
27
from tensorboard .plugins .hparams import api_pb2
28
+ from tensorboard .plugins .hparams import backend_context as backend_context_lib
28
29
from tensorboard .plugins .hparams import error
29
30
from tensorboard .plugins .hparams import json_format_compat
30
31
from tensorboard .plugins .hparams import metadata
31
32
from tensorboard .plugins .hparams import metrics
33
+ from tensorboard .plugins .hparams import plugin_data_pb2
32
34
33
35
34
36
class Handler :
@@ -93,13 +95,15 @@ def _session_groups_from_tags(self):
93
95
hparams_run_to_tag_to_content ,
94
96
# Don't pass any information from the DataProvider since we are only
95
97
# examining session groups based on tag metadata
96
- [],
98
+ provider .ListHyperparametersResult (
99
+ hyperparameters = [], session_groups = []
100
+ ),
97
101
)
98
102
extractors = _create_extractors (self ._request .col_params )
99
103
filters = _create_filters (self ._request .col_params , extractors )
100
104
101
105
session_groups = self ._build_session_groups (
102
- hparams_run_to_tag_to_content , experiment
106
+ hparams_run_to_tag_to_content , experiment . metric_infos
103
107
)
104
108
session_groups = self ._filter (session_groups , filters )
105
109
self ._sort (session_groups , extractors )
@@ -116,16 +120,37 @@ def _session_groups_from_data_provider(self):
116
120
sort ,
117
121
)
118
122
123
+ metric_infos = self ._backend_context .compute_metric_infos_from_data_provider_session_groups (
124
+ self ._request_context , self ._experiment_id , response
125
+ )
126
+
127
+ all_metric_evals = self ._backend_context .read_last_scalars (
128
+ self ._request_context ,
129
+ self ._experiment_id ,
130
+ run_tag_filter = None ,
131
+ )
132
+
119
133
session_groups = []
120
134
for provider_group in response :
121
- sessions = [
122
- api_pb2 .Session (name = f"{ s .experiment_id } /{ s .run } " )
123
- for s in provider_group .sessions
124
- ]
125
- name = (
126
- f"{ provider_group .root .experiment_id } /{ provider_group .root .run } "
127
- if provider_group .root .run
128
- else provider_group .root .experiment_id
135
+ sessions = []
136
+ for session in provider_group .sessions :
137
+ session_name = (
138
+ backend_context_lib .generate_data_provider_session_name (
139
+ self ._experiment_id , session
140
+ )
141
+ )
142
+ sessions .append (
143
+ self ._build_session (
144
+ metric_infos ,
145
+ session_name ,
146
+ plugin_data_pb2 .SessionStartInfo (),
147
+ plugin_data_pb2 .SessionEndInfo (),
148
+ all_metric_evals ,
149
+ )
150
+ )
151
+
152
+ name = backend_context_lib .generate_data_provider_session_name (
153
+ self ._experiment_id , provider_group .root
129
154
)
130
155
session_group = api_pb2 .SessionGroup (
131
156
name = name ,
@@ -154,9 +179,16 @@ def _session_groups_from_data_provider(self):
154
179
155
180
session_groups .append (session_group )
156
181
182
+ # Compute the session group's aggregated metrics for each group.
183
+ for group in session_groups :
184
+ if group .sessions :
185
+ self ._aggregate_metrics (group )
186
+
157
187
return session_groups
158
188
159
- def _build_session_groups (self , hparams_run_to_tag_to_content , experiment ):
189
+ def _build_session_groups (
190
+ self , hparams_run_to_tag_to_content , metric_infos
191
+ ):
160
192
"""Returns a list of SessionGroups protobuffers from the summary
161
193
data."""
162
194
@@ -178,7 +210,7 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
178
210
metric_runs = set ()
179
211
metric_tags = set ()
180
212
for session_name in session_names :
181
- for metric in experiment . metric_infos :
213
+ for metric in metric_infos :
182
214
metric_name = metric .name
183
215
(run , tag ) = metrics .run_tag_from_session_and_metric (
184
216
session_name , metric_name
@@ -207,7 +239,11 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
207
239
tag_to_content [metadata .SESSION_END_INFO_TAG ]
208
240
)
209
241
session = self ._build_session (
210
- experiment , session_name , start_info , end_info , all_metric_evals
242
+ metric_infos ,
243
+ session_name ,
244
+ start_info ,
245
+ end_info ,
246
+ all_metric_evals ,
211
247
)
212
248
if session .status in self ._request .allowed_statuses :
213
249
self ._add_session (session , start_info , groups_by_name )
@@ -263,7 +299,7 @@ def _add_session(self, session, start_info, groups_by_name):
263
299
groups_by_name [group_name ] = group
264
300
265
301
def _build_session (
266
- self , experiment , name , start_info , end_info , all_metric_evals
302
+ self , metric_infos , name , start_info , end_info , all_metric_evals
267
303
):
268
304
"""Builds a session object."""
269
305
@@ -273,7 +309,7 @@ def _build_session(
273
309
start_time_secs = start_info .start_time_secs ,
274
310
model_uri = start_info .model_uri ,
275
311
metric_values = self ._build_session_metric_values (
276
- experiment , name , all_metric_evals
312
+ metric_infos , name , all_metric_evals
277
313
),
278
314
monitor_url = start_info .monitor_url ,
279
315
)
@@ -283,13 +319,13 @@ def _build_session(
283
319
return result
284
320
285
321
def _build_session_metric_values (
286
- self , experiment , session_name , all_metric_evals
322
+ self , metric_infos , session_name , all_metric_evals
287
323
):
288
324
"""Builds the session metric values."""
289
325
290
326
# result is a list of api_pb2.MetricValue instances.
291
327
result = []
292
- for metric_info in experiment . metric_infos :
328
+ for metric_info in metric_infos :
293
329
metric_name = metric_info .name
294
330
(run , tag ) = metrics .run_tag_from_session_and_metric (
295
331
session_name , metric_name
0 commit comments