@@ -152,13 +152,14 @@ def _mock_list_hyperparameters(
152
152
):
153
153
return self ._hyperparameters
154
154
155
- def _experiment_from_metadata (self ):
155
+ def _experiment_from_metadata (self , * , include_metrics = True ):
156
156
"""Calls the expected operations for generating an Experiment proto."""
157
157
ctxt = backend_context .Context (self ._mock_tb_context )
158
158
request_ctx = context .RequestContext ()
159
159
return ctxt .experiment_from_metadata (
160
160
request_ctx ,
161
161
"123" ,
162
+ include_metrics ,
162
163
ctxt .hparams_metadata (request_ctx , "123" ),
163
164
ctxt .hparams_from_data_provider (request_ctx , "123" ),
164
165
)
@@ -187,7 +188,39 @@ def test_experiment_with_experiment_tag(self):
187
188
}
188
189
self .assertProtoEquals (experiment , self ._experiment_from_metadata ())
189
190
190
- def test_experiment_without_experiment_tag (self ):
191
+ def test_experiment_with_experiment_tag_include_metrics (self ):
192
+ experiment = """
193
+ description: 'Test experiment'
194
+ metric_infos: [
195
+ { name: { tag: 'current_temp' } },
196
+ { name: { tag: 'delta_temp' } }
197
+ ]
198
+ """
199
+ run = "exp"
200
+ tag = metadata .EXPERIMENT_TAG
201
+ t = provider .TensorTimeSeries (
202
+ max_step = 0 ,
203
+ max_wall_time = 0 ,
204
+ plugin_content = self ._serialized_plugin_data (
205
+ DATA_TYPE_EXPERIMENT , experiment
206
+ ),
207
+ description = "" ,
208
+ display_name = "" ,
209
+ )
210
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
211
+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
212
+ run : {tag : t }
213
+ }
214
+
215
+ with self .subTest ("False" ):
216
+ response = self ._experiment_from_metadata (include_metrics = False )
217
+ self .assertEmpty (response .metric_infos )
218
+
219
+ with self .subTest ("True" ):
220
+ response = self ._experiment_from_metadata (include_metrics = True )
221
+ self .assertLen (response .metric_infos , 2 )
222
+
223
+ def test_experiment_with_session_tags (self ):
191
224
self .session_1_start_info_ = """
192
225
hparams: [
193
226
{key: 'batch_size' value: {number_value: 100}},
@@ -243,7 +276,7 @@ def test_experiment_without_experiment_tag(self):
243
276
_canonicalize_experiment (actual_exp )
244
277
self .assertProtoEquals (expected_exp , actual_exp )
245
278
246
- def test_experiment_without_experiment_tag_different_hparam_types (self ):
279
+ def test_experiment_with_session_tags_different_hparam_types (self ):
247
280
self .session_1_start_info_ = """
248
281
hparams:[
249
282
{key: 'batch_size' value: {number_value: 100}},
@@ -304,7 +337,7 @@ def test_experiment_without_experiment_tag_different_hparam_types(self):
304
337
_canonicalize_experiment (actual_exp )
305
338
self .assertProtoEquals (expected_exp , actual_exp )
306
339
307
- def test_experiment_with_bool_types (self ):
340
+ def test_experiment_with_session_tags_bool_types (self ):
308
341
self .session_1_start_info_ = """
309
342
hparams:[
310
343
{key: 'batch_size' value: {bool_value: true}}
@@ -344,7 +377,9 @@ def test_experiment_with_bool_types(self):
344
377
_canonicalize_experiment (actual_exp )
345
378
self .assertProtoEquals (expected_exp , actual_exp )
346
379
347
- def test_experiment_with_string_domain_and_invalid_number_values (self ):
380
+ def test_experiment_with_session_tags_string_domain_and_invalid_number_values (
381
+ self ,
382
+ ):
348
383
self .session_1_start_info_ = """
349
384
hparams:[
350
385
{key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}}
@@ -371,8 +406,21 @@ def test_experiment_with_string_domain_and_invalid_number_values(self):
371
406
self .assertLen (actual_exp .hparam_infos , 1 )
372
407
self .assertProtoEquals (expected_hparam_info , actual_exp .hparam_infos [0 ])
373
408
409
+ def test_experiment_with_session_tags_include_metrics (self ):
410
+ self .session_1_start_info_ = """
411
+ hparams: [
412
+ {key: 'batch_size' value: {number_value: 100}}
413
+ ]
414
+ """
415
+ with self .subTest ("False" ):
416
+ response = self ._experiment_from_metadata (include_metrics = False )
417
+ self .assertEmpty (response .metric_infos )
418
+
419
+ with self .subTest ("True" ):
420
+ response = self ._experiment_from_metadata (include_metrics = True )
421
+ self .assertLen (response .metric_infos , 4 )
422
+
374
423
def test_experiment_without_any_hparams (self ):
375
- request_ctx = context .RequestContext ()
376
424
actual_exp = self ._experiment_from_metadata ()
377
425
self .assertIsInstance (actual_exp , api_pb2 .Experiment )
378
426
self .assertProtoEquals ("" , actual_exp )
@@ -789,6 +837,33 @@ def test_experiment_from_data_provider_session_group_without_session_names(
789
837
"""
790
838
self .assertProtoEquals (expected_exp , actual_exp )
791
839
840
+ def test_experiment_from_data_provider_include_metrics (self ):
841
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
842
+ self ._hyperparameters = provider .ListHyperparametersResult (
843
+ hyperparameters = [],
844
+ session_groups = [
845
+ provider .HyperparameterSessionGroup (
846
+ root = provider .HyperparameterSessionRun (
847
+ experiment_id = "exp" , run = ""
848
+ ),
849
+ sessions = [
850
+ provider .HyperparameterSessionRun (
851
+ experiment_id = "exp" , run = "session_1"
852
+ ),
853
+ ],
854
+ hyperparameter_values = [],
855
+ ),
856
+ ],
857
+ )
858
+
859
+ with self .subTest ("False" ):
860
+ response = self ._experiment_from_metadata (include_metrics = False )
861
+ self .assertEmpty (response .metric_infos )
862
+
863
+ with self .subTest ("True" ):
864
+ response = self ._experiment_from_metadata (include_metrics = True )
865
+ self .assertLen (response .metric_infos , 4 )
866
+
792
867
def test_experiment_from_data_provider_old_response_type (self ):
793
868
self ._hyperparameters = [
794
869
provider .Hyperparameter (
0 commit comments