Skip to content

Commit 2a91acc

Browse files
authored
Hparams: Support excluding metric information in HTTP requests. (#6556)
There are some clients of the Hparams HTTP API that do not require the metric information. This includes the metric_infos usually returned in the /experiments request and the metric_values usually returned in the /session_groups request. Since these can be expensive to calculate, we want the option to not calculate and return them in the response. Add option `include_metrics` to both GetExperimentRequest and ListSessionGroupsRequest. If unspecified we treat `include_metrics` as True, for backward compatibility. Honor the `include_metrics` property in all three major cases: When experiment metadata is defined by Experiment tags, by Session tags, or by the DataProvider.
1 parent 91a637e commit 2a91acc

7 files changed

+274
-41
lines changed

tensorboard/plugins/hparams/api.proto

+8-2
Original file line numberDiff line numberDiff line change
@@ -254,17 +254,20 @@ enum Status {
254254

255255
// Parameters for a GetExperiment API call.
256256
// Each experiment is scoped by a unique global id.
257-
// NEXT_TAG: 2
257+
// NEXT_TAG: 3
258258
message GetExperimentRequest {
259259
// REQUIRED
260260
string experiment_name = 1;
261+
262+
// Whether to fetch metrics and include them in the results. Defaults to true.
263+
optional bool include_metrics = 2;
261264
}
262265

263266
// Parameters for a ListSessionGroups API call.
264267
// Computes a list of the current session groups allowing for filtering and
265268
// sorting by metrics and hyperparameter values. Returns a "slice" of
266269
// that list specified by start_index and slice_size.
267-
// NEXT_TAG: 8
270+
// NEXT_TAG: 9
268271
message ListSessionGroupsRequest {
269272
string experiment_name = 6;
270273

@@ -314,6 +317,9 @@ message ListSessionGroupsRequest {
314317
// sorted and filtered by the parameters above (if start_index > total_size
315318
// no session groups are returned).
316319
int32 slice_size = 5;
320+
321+
// Whether to fetch metrics and include them in the results. Defaults to true.
322+
optional bool include_metrics = 8;
317323
}
318324

319325
// Defines parmeters for a ListSessionGroupsRequest for a specific column.

tensorboard/plugins/hparams/backend_context.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def experiment_from_metadata(
5757
self,
5858
ctx,
5959
experiment_id,
60+
include_metrics,
6061
hparams_run_to_tag_to_content,
6162
data_provider_hparams,
6263
):
@@ -76,6 +77,8 @@ def experiment_from_metadata(
7677
7778
Args:
7879
experiment_id: String, from `plugin_util.experiment_id`.
80+
include_metrics: Whether to determine metrics_infos and include them
81+
in the result.
7982
hparams_run_to_tag_to_content: The output from an hparams_metadata()
8083
call. A dict `d` such that `d[run][tag]` is a `bytes` value with the
8184
summary metadata content for the keyed time series.
@@ -87,19 +90,21 @@ def experiment_from_metadata(
8790
The experiment proto. If no data is found for an experiment proto to
8891
be built, returns an entirely empty experiment.
8992
"""
90-
experiment = self._find_experiment_tag(hparams_run_to_tag_to_content)
93+
experiment = self._find_experiment_tag(
94+
hparams_run_to_tag_to_content, include_metrics
95+
)
9196
if experiment:
9297
return experiment
9398

9499
experiment_from_runs = self._compute_experiment_from_runs(
95-
ctx, experiment_id, hparams_run_to_tag_to_content
100+
ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
96101
)
97102
if experiment_from_runs:
98103
return experiment_from_runs
99104

100105
experiment_from_data_provider_hparams = (
101106
self._experiment_from_data_provider_hparams(
102-
ctx, experiment_id, data_provider_hparams
107+
ctx, experiment_id, include_metrics, data_provider_hparams
103108
)
104109
)
105110
return (
@@ -202,7 +207,9 @@ def session_groups_from_data_provider(
202207
ctx, experiment_ids=[experiment_id], filters=filters, sort=sort
203208
)
204209

205-
def _find_experiment_tag(self, hparams_run_to_tag_to_content):
210+
def _find_experiment_tag(
211+
self, hparams_run_to_tag_to_content, include_metrics
212+
):
206213
"""Finds the experiment associcated with the metadata.EXPERIMENT_TAG
207214
tag.
208215
@@ -214,23 +221,34 @@ def _find_experiment_tag(self, hparams_run_to_tag_to_content):
214221
for tags in hparams_run_to_tag_to_content.values():
215222
maybe_content = tags.get(metadata.EXPERIMENT_TAG)
216223
if maybe_content is not None:
217-
return metadata.parse_experiment_plugin_data(maybe_content)
224+
experiment = metadata.parse_experiment_plugin_data(
225+
maybe_content
226+
)
227+
if not include_metrics:
228+
# metric_infos haven't technically been "calculated" in this
229+
# case. They have been read directly from the Experiment
230+
# proto.
231+
# Delete them from the result so that they are not returned
232+
# to the client.
233+
experiment.ClearField("metric_infos")
234+
return experiment
218235
return None
219236

220237
def _compute_experiment_from_runs(
221-
self, ctx, experiment_id, hparams_run_to_tag_to_content
238+
self, ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content
222239
):
223240
"""Computes a minimal Experiment protocol buffer by scanning the runs.
224241
225242
Returns None if there are no hparam infos logged.
226243
"""
227244
hparam_infos = self._compute_hparam_infos(hparams_run_to_tag_to_content)
228-
if hparam_infos:
229-
metric_infos = self._compute_metric_infos_from_runs(
245+
metric_infos = (
246+
self._compute_metric_infos_from_runs(
230247
ctx, experiment_id, hparams_run_to_tag_to_content
231248
)
232-
else:
233-
metric_infos = []
249+
if hparam_infos and include_metrics
250+
else []
251+
)
234252
if not hparam_infos and not metric_infos:
235253
return None
236254

@@ -320,11 +338,15 @@ def _experiment_from_data_provider_hparams(
320338
self,
321339
ctx,
322340
experiment_id,
341+
include_metrics,
323342
data_provider_hparams,
324343
):
325344
"""Returns an experiment protobuffer based on data provider hparams.
326345
327346
Args:
347+
experiment_id: String, from `plugin_util.experiment_id`.
348+
include_metrics: Whether to determine metrics_infos and include them
349+
in the result.
328350
data_provider_hparams: The ouput from an hparams_from_data_provider()
329351
call, corresponding to DataProvider.list_hyperparameters().
330352
A provider.ListHyperparametersResult.
@@ -352,6 +374,8 @@ def _experiment_from_data_provider_hparams(
352374
self.compute_metric_infos_from_data_provider_session_groups(
353375
ctx, experiment_id, session_groups
354376
)
377+
if include_metrics
378+
else []
355379
)
356380
return api_pb2.Experiment(
357381
hparam_infos=hparam_infos, metric_infos=metric_infos

tensorboard/plugins/hparams/backend_context_test.py

+81-6
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,14 @@ def _mock_list_hyperparameters(
152152
):
153153
return self._hyperparameters
154154

155-
def _experiment_from_metadata(self):
155+
def _experiment_from_metadata(self, *, include_metrics=True):
156156
"""Calls the expected operations for generating an Experiment proto."""
157157
ctxt = backend_context.Context(self._mock_tb_context)
158158
request_ctx = context.RequestContext()
159159
return ctxt.experiment_from_metadata(
160160
request_ctx,
161161
"123",
162+
include_metrics,
162163
ctxt.hparams_metadata(request_ctx, "123"),
163164
ctxt.hparams_from_data_provider(request_ctx, "123"),
164165
)
@@ -187,7 +188,39 @@ def test_experiment_with_experiment_tag(self):
187188
}
188189
self.assertProtoEquals(experiment, self._experiment_from_metadata())
189190

190-
def test_experiment_without_experiment_tag(self):
191+
def test_experiment_with_experiment_tag_include_metrics(self):
192+
experiment = """
193+
description: 'Test experiment'
194+
metric_infos: [
195+
{ name: { tag: 'current_temp' } },
196+
{ name: { tag: 'delta_temp' } }
197+
]
198+
"""
199+
run = "exp"
200+
tag = metadata.EXPERIMENT_TAG
201+
t = provider.TensorTimeSeries(
202+
max_step=0,
203+
max_wall_time=0,
204+
plugin_content=self._serialized_plugin_data(
205+
DATA_TYPE_EXPERIMENT, experiment
206+
),
207+
description="",
208+
display_name="",
209+
)
210+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
211+
self._mock_tb_context.data_provider.list_tensors.return_value = {
212+
run: {tag: t}
213+
}
214+
215+
with self.subTest("False"):
216+
response = self._experiment_from_metadata(include_metrics=False)
217+
self.assertEmpty(response.metric_infos)
218+
219+
with self.subTest("True"):
220+
response = self._experiment_from_metadata(include_metrics=True)
221+
self.assertLen(response.metric_infos, 2)
222+
223+
def test_experiment_with_session_tags(self):
191224
self.session_1_start_info_ = """
192225
hparams: [
193226
{key: 'batch_size' value: {number_value: 100}},
@@ -243,7 +276,7 @@ def test_experiment_without_experiment_tag(self):
243276
_canonicalize_experiment(actual_exp)
244277
self.assertProtoEquals(expected_exp, actual_exp)
245278

246-
def test_experiment_without_experiment_tag_different_hparam_types(self):
279+
def test_experiment_with_session_tags_different_hparam_types(self):
247280
self.session_1_start_info_ = """
248281
hparams:[
249282
{key: 'batch_size' value: {number_value: 100}},
@@ -304,7 +337,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self):
304337
_canonicalize_experiment(actual_exp)
305338
self.assertProtoEquals(expected_exp, actual_exp)
306339

307-
def test_experiment_with_bool_types(self):
340+
def test_experiment_with_session_tags_bool_types(self):
308341
self.session_1_start_info_ = """
309342
hparams:[
310343
{key: 'batch_size' value: {bool_value: true}}
@@ -344,7 +377,9 @@ def test_experiment_with_bool_types(self):
344377
_canonicalize_experiment(actual_exp)
345378
self.assertProtoEquals(expected_exp, actual_exp)
346379

347-
def test_experiment_with_string_domain_and_invalid_number_values(self):
380+
def test_experiment_with_session_tags_string_domain_and_invalid_number_values(
381+
self,
382+
):
348383
self.session_1_start_info_ = """
349384
hparams:[
350385
{key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}}
@@ -371,8 +406,21 @@ def test_experiment_with_string_domain_and_invalid_number_values(self):
371406
self.assertLen(actual_exp.hparam_infos, 1)
372407
self.assertProtoEquals(expected_hparam_info, actual_exp.hparam_infos[0])
373408

409+
def test_experiment_with_session_tags_include_metrics(self):
410+
self.session_1_start_info_ = """
411+
hparams: [
412+
{key: 'batch_size' value: {number_value: 100}}
413+
]
414+
"""
415+
with self.subTest("False"):
416+
response = self._experiment_from_metadata(include_metrics=False)
417+
self.assertEmpty(response.metric_infos)
418+
419+
with self.subTest("True"):
420+
response = self._experiment_from_metadata(include_metrics=True)
421+
self.assertLen(response.metric_infos, 4)
422+
374423
def test_experiment_without_any_hparams(self):
375-
request_ctx = context.RequestContext()
376424
actual_exp = self._experiment_from_metadata()
377425
self.assertIsInstance(actual_exp, api_pb2.Experiment)
378426
self.assertProtoEquals("", actual_exp)
@@ -789,6 +837,33 @@ def test_experiment_from_data_provider_session_group_without_session_names(
789837
"""
790838
self.assertProtoEquals(expected_exp, actual_exp)
791839

840+
def test_experiment_from_data_provider_include_metrics(self):
841+
self._mock_tb_context.data_provider.list_tensors.side_effect = None
842+
self._hyperparameters = provider.ListHyperparametersResult(
843+
hyperparameters=[],
844+
session_groups=[
845+
provider.HyperparameterSessionGroup(
846+
root=provider.HyperparameterSessionRun(
847+
experiment_id="exp", run=""
848+
),
849+
sessions=[
850+
provider.HyperparameterSessionRun(
851+
experiment_id="exp", run="session_1"
852+
),
853+
],
854+
hyperparameter_values=[],
855+
),
856+
],
857+
)
858+
859+
with self.subTest("False"):
860+
response = self._experiment_from_metadata(include_metrics=False)
861+
self.assertEmpty(response.metric_infos)
862+
863+
with self.subTest("True"):
864+
response = self._experiment_from_metadata(include_metrics=True)
865+
self.assertLen(response.metric_infos, 4)
866+
792867
def test_experiment_from_data_provider_old_response_type(self):
793868
self._hyperparameters = [
794869
provider.Hyperparameter(

tensorboard/plugins/hparams/get_experiment.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,41 @@
1818
class Handler:
1919
"""Handles a GetExperiment request."""
2020

21-
def __init__(self, request_context, backend_context, experiment_id):
21+
def __init__(
22+
self, request_context, backend_context, experiment_id, request
23+
):
2224
"""Constructor.
2325
2426
Args:
2527
request_context: A tensorboard.context.RequestContext.
2628
backend_context: A backend_context.Context instance.
2729
experiment_id: A string, as from `plugin_util.experiment_id`.
30+
request: A api_pb2.GetExperimentRequest instance.
2831
"""
2932
self._request_context = request_context
3033
self._backend_context = backend_context
3134
self._experiment_id = experiment_id
35+
self._include_metrics = (
36+
# Metrics are included by default if include_metrics is not
37+
# specified in the request.
38+
not request.HasField("include_metrics")
39+
or request.include_metrics
40+
)
3241

3342
def run(self):
3443
"""Handles the request specified on construction.
3544
3645
Returns:
3746
An Experiment object.
3847
"""
39-
experiment_id = self._experiment_id
4048
return self._backend_context.experiment_from_metadata(
4149
self._request_context,
42-
experiment_id,
50+
self._experiment_id,
51+
self._include_metrics,
4352
self._backend_context.hparams_metadata(
44-
self._request_context, experiment_id
53+
self._request_context, self._experiment_id
4554
),
4655
self._backend_context.hparams_from_data_provider(
47-
self._request_context, experiment_id
56+
self._request_context, self._experiment_id
4857
),
4958
)

tensorboard/plugins/hparams/hparams_plugin.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,14 @@ def get_experiment_route(self, request):
113113
ctx = plugin_util.context(request.environ)
114114
experiment_id = plugin_util.experiment_id(request.environ)
115115
try:
116-
# This backend currently ignores the request parameters, but (for a POST)
117-
# we must advance the input stream to skip them -- otherwise the next HTTP
118-
# request will be parsed incorrectly.
119-
_ = _parse_request_argument(request, api_pb2.GetExperimentRequest)
116+
request_proto = _parse_request_argument(
117+
request, api_pb2.GetExperimentRequest
118+
)
120119
return http_util.Respond(
121120
request,
122121
json_format.MessageToJson(
123122
get_experiment.Handler(
124-
ctx, self._context, experiment_id
123+
ctx, self._context, experiment_id, request_proto
125124
).run(),
126125
including_default_value_fields=True,
127126
),

0 commit comments

Comments
 (0)