diff --git a/tensorboard/plugins/pr_curve/BUILD b/tensorboard/plugins/pr_curve/BUILD
index 0953f9acfc..3d2e920b81 100644
--- a/tensorboard/plugins/pr_curve/BUILD
+++ b/tensorboard/plugins/pr_curve/BUILD
@@ -59,6 +59,7 @@ py_library(
     visibility = ["//visibility:public"],
    deps = [
         ":metadata",
+        "//tensorboard:expect_numpy_installed",
         "//tensorboard:expect_tensorflow_installed",
     ],
 )
@@ -69,12 +70,9 @@ py_test(
     srcs = ["summary_test.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":pr_curve_demo",
         ":summary",
         "//tensorboard:expect_numpy_installed",
         "//tensorboard:expect_tensorflow_installed",
-        "//tensorboard/backend:application",
-        "//tensorboard/backend/event_processing:event_multiplexer",
         "//tensorboard/plugins:base_plugin",
         "@org_pocoo_werkzeug",
         "@org_pythonhosted_six",
diff --git a/tensorboard/plugins/pr_curve/pr_curve_demo.py b/tensorboard/plugins/pr_curve/pr_curve_demo.py
index 5f9794fc24..7f8a8dc8fa 100644
--- a/tensorboard/plugins/pr_curve/pr_curve_demo.py
+++ b/tensorboard/plugins/pr_curve/pr_curve_demo.py
@@ -161,7 +161,7 @@ def start_runs(
     weights = tf.cast(consecutive_indices % 2, dtype=tf.float32)
 
     summary.op(
-        tag=color,
+        name=color,
         labels=labels[:, i],
         predictions=predictions[i],
         num_thresholds=thresholds,
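Reviewer note: the `tag` argument is renamed to `name` throughout this change (the demo above, and `op`, `streaming_op`, and `raw_data_op` below), so call sites only need a keyword change. A minimal sketch of the new call signature, with hypothetical inputs that are not part of this diff:

```python
import tensorflow as tf
from tensorboard.plugins.pr_curve import summary

# Hypothetical inputs, for illustration only.
labels = tf.constant([True, False, True])
predictions = tf.constant([0.9, 0.3, 0.6], dtype=tf.float32)

# Previously: summary.op(tag='red', ...); after this change:
pr_summary_op = summary.op(
    name='red', labels=labels, predictions=predictions, num_thresholds=5)
```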
diff --git a/tensorboard/plugins/pr_curve/summary.py b/tensorboard/plugins/pr_curve/summary.py
index b9beec1414..099096e167 100644
--- a/tensorboard/plugins/pr_curve/summary.py
+++ b/tensorboard/plugins/pr_curve/summary.py
@@ -22,6 +22,7 @@ from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import tensorflow as tf
 
 from tensorboard.plugins.pr_curve import metadata
@@ -30,8 +31,11 @@
 # division by 0. 1 suffices because counts of course must be whole numbers.
 _MINIMUM_COUNT = 1.0
 
+# The default number of thresholds.
+_DEFAULT_NUM_THRESHOLDS = 200
+
 
 def op(
-    tag,
+    name,
     labels,
     predictions,
     num_thresholds=None,
@@ -51,7 +55,7 @@
   used to reweight certain values, or more commonly used for masking values.
 
   Args:
-    tag: A tag attached to the summary. Used by TensorBoard for organization.
+    name: A name for the summary. Used by TensorBoard for organization.
     labels: The ground truth values. A Tensor of `bool` values with arbitrary
       shape.
     predictions: A float32 `Tensor` whose values are in the range `[0, 1]`.
@@ -78,14 +82,14 @@
   """
   if num_thresholds is None:
-    num_thresholds = 200
+    num_thresholds = _DEFAULT_NUM_THRESHOLDS
 
   if weights is None:
     weights = 1.0
 
   dtype = predictions.dtype
 
-  with tf.name_scope(tag, values=[labels, predictions, weights]):
+  with tf.name_scope(name, values=[labels, predictions, weights]):
     tf.assert_type(labels, tf.bool)
     # We cast to float to ensure we have 0.0 or 1.0.
     f_labels = tf.cast(labels, dtype)
@@ -152,7 +156,7 @@
     recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)
 
     return _create_tensor_summary(
-        tag,
+        name,
         tp,
         fp,
         tn,
@@ -164,8 +168,76 @@
         description,
         collections)
 
 
-def streaming_op(tag,
+def pb(name,
+       labels,
+       predictions,
+       num_thresholds=None,
+       weights=None,
+       display_name=None,
+       description=None):
+  """Create a PR curve summary protobuf.
+
+  Args:
+    name: A name for the generated node. Will also serve as a series name in
+      TensorBoard.
+    labels: The ground truth values. A bool numpy array.
+    predictions: A float32 numpy array whose values are in the range `[0, 1]`.
+      Dimensions must match those of `labels`.
+    num_thresholds: Optional number of thresholds, evenly distributed in
+      `[0, 1]`, to compute PR metrics for. When provided, should be an int of
+      value at least 2. Defaults to 200.
+    weights: Optional float or float32 numpy array. Individual counts are
+      multiplied by this value. This array must be either the same shape as
+      `labels` or broadcastable to it.
+    display_name: Optional name for this summary in TensorBoard, as a `str`.
+      Defaults to `name`.
+    description: Optional long-form description for this summary, as a `str`.
+      Markdown is supported. Defaults to empty.
+
+  Returns:
+    A `tf.Summary` protocol buffer containing a single PR curve tensor value.
+  """
+  if num_thresholds is None:
+    num_thresholds = _DEFAULT_NUM_THRESHOLDS
+
+  if weights is None:
+    weights = 1.0
+
+  # Compute bins of true positives and false positives.
+  bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
+  float_labels = labels.astype(float)
+  histogram_range = (0, num_thresholds - 1)
+  tp_buckets, _ = np.histogram(
+      bucket_indices,
+      bins=num_thresholds,
+      range=histogram_range,
+      weights=float_labels * weights)
+  fp_buckets, _ = np.histogram(
+      bucket_indices,
+      bins=num_thresholds,
+      range=histogram_range,
+      weights=(1.0 - float_labels) * weights)
+
+  # Obtain the reverse cumulative sum.
+  tp = np.cumsum(tp_buckets[::-1])[::-1]
+  fp = np.cumsum(fp_buckets[::-1])[::-1]
+  tn = fp[0] - fp
+  fn = tp[0] - tp
+  precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
+  recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
+
+  if display_name is None:
+    display_name = name
+  summary_metadata = metadata.create_summary_metadata(
+      display_name=display_name,
+      description=description or '',
+      num_thresholds=num_thresholds)
+  summary = tf.Summary()
+  data = np.stack((tp, fp, tn, fn, precision, recall))
+  tensor = tf.make_tensor_proto(data, dtype=tf.float32)
+  summary.value.add(tag='%s/pr_curves' % name,
+                    metadata=summary_metadata,
+                    tensor=tensor)
+  return summary
+
+
+def streaming_op(name,
                  labels,
                  predictions,
                  num_thresholds=200,
@@ -186,7 +258,7 @@
   updated with the returned update_op.
 
   Args:
-    tag: A tag attached to the summary. Used by TensorBoard for organization.
+    name: A name for the summary. Used by TensorBoard for organization.
     labels: The ground truth values, a `Tensor` whose dimensions must match
       `predictions`. Will be cast to `bool`.
     predictions: A floating point `Tensor` of arbitrary shape and whose values
       are in the range `[0, 1]`.
@@ -216,7 +288,7 @@
   thresholds = [i / float(num_thresholds - 1) for i in range(num_thresholds)]
 
-  with tf.name_scope(tag, values=[labels, predictions, weights]):
+  with tf.name_scope(name, values=[labels, predictions, weights]):
     tp, update_tp = tf.metrics.true_positives_at_thresholds(
         labels=labels,
         predictions=predictions,
@@ -243,7 +315,7 @@ def compute_summary(tp, fp, tn, fn, collections):
     recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)
 
     return _create_tensor_summary(
-        tag,
+        name,
         tp,
         fp,
         tn,
         fn,
         precision,
         recall,
@@ -263,7 +335,7 @@
 
 
 def raw_data_op(
-    tag,
+    name,
     true_positive_counts,
     false_positive_counts,
     true_negative_counts,
@@ -285,7 +357,7 @@
   differently but still use the PR curves plugin.
 
   Args:
-    tag: A tag attached to the summary. Used by TensorBoard for organization.
+    name: A name for the summary. Used by TensorBoard for organization.
     true_positive_counts: A rank-1 tensor of true positive counts. Must contain
       `num_thresholds` elements and be castable to float32.
     false_positive_counts: A rank-1 tensor of false positive counts. Must
@@ -309,7 +381,7 @@
     A summary operation for use in a TensorFlow graph. See docs for the `op`
     method for details on the float32 tensor produced by this summary.
   """
-  with tf.name_scope(tag, values=[
+  with tf.name_scope(name, values=[
       true_positive_counts,
       false_positive_counts,
       true_negative_counts,
@@ -318,7 +390,7 @@
       recall,
   ]):
     return _create_tensor_summary(
-        tag,
+        name,
         true_positive_counts,
         false_positive_counts,
         true_negative_counts,
@@ -331,7 +403,7 @@
       collections)
 
 
 def _create_tensor_summary(
-    tag,
+    name,
     true_positive_counts,
     false_positive_counts,
     true_negative_counts,
@@ -355,7 +427,7 @@
   # Store the number of thresholds within the summary metadata because
   # that value is constant for all pr curve summaries with the same tag.
   summary_metadata = metadata.create_summary_metadata(
-      display_name=display_name if display_name is not None else tag,
+      display_name=display_name if display_name is not None else name,
       description=description or '',
       num_thresholds=num_thresholds)
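Reviewer note: the new `pb` path computes the PR data entirely in numpy. Predictions are bucketed into `num_thresholds` bins, per-bin true/false positive counts come from two weighted histograms, and a reverse cumulative sum converts per-bin counts into per-threshold counts (a prediction above a given threshold remains positive at every lower threshold). A standalone sketch of that pipeline, using the same toy inputs as `test_many_values` below; all names are local to this example:

```python
import numpy as np

labels = np.array([True, False, False, True, True, True])
predictions = np.array([0.2, 0.3, 0.4, 0.6, 0.7, 0.8], dtype=np.float32)
num_thresholds = 3

# Bucket each prediction into one of `num_thresholds` bins over [0, 1].
bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
float_labels = labels.astype(float)
tp_buckets, _ = np.histogram(
    bucket_indices, bins=num_thresholds, range=(0, num_thresholds - 1),
    weights=float_labels)
fp_buckets, _ = np.histogram(
    bucket_indices, bins=num_thresholds, range=(0, num_thresholds - 1),
    weights=1.0 - float_labels)

# Reverse cumulative sum: tp[i] counts positives predicted at or above
# threshold i; tn and fn follow from the totals at threshold 0.
tp = np.cumsum(tp_buckets[::-1])[::-1]
fp = np.cumsum(fp_buckets[::-1])[::-1]
tn = fp[0] - fp
fn = tp[0] - tp
precision = tp / np.maximum(1.0, tp + fp)
recall = tp / np.maximum(1.0, tp + fn)

print(tp)         # [4. 3. 0.]
print(fp)         # [2. 0. 0.]
print(precision)  # approximately [0.667 1. 0.]
print(recall)     # [1. 0.75 0.]
```

These outputs match the `expected` rows asserted in `test_many_values`, which is the point of the shared `compute_and_check_summary_pb` helper: `op` and `pb` must agree bit-for-bit after normalization.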
+ """ + result = tf.Summary() + result.MergeFrom(pb) + for value in result.value: + if value.HasField('tensor'): + new_tensor = tf.make_tensor_proto(tf.make_ndarray(value.tensor)) + value.ClearField('tensor') + value.tensor.MergeFrom(new_tensor) + return result + + def compute_and_check_summary_pb(self, + name, + labels, + predictions, + num_thresholds, + weights=None, + display_name=None, + description=None, + feed_dict=None): + """Use both `op` and `pb` to get a summary, asserting equality. + Returns: + a `Summary` protocol buffer """ - self.assertEqual(expected_step, tensor_event.step) - tensor_nd_array = tf.make_ndarray(tensor_event.tensor_proto) + labels_tensor = tf.constant(labels) + predictions_tensor = tf.constant(predictions) + weights_tensor = None if weights is None else tf.constant(weights) + op = summary.op( + name=name, + labels=labels_tensor, + predictions=predictions_tensor, + num_thresholds=num_thresholds, + weights=weights_tensor, + display_name=display_name, + description=description) + pb = self.normalize_summary_pb(summary.pb( + name=name, + labels=labels, + predictions=predictions, + num_thresholds=num_thresholds, + weights=weights, + display_name=display_name, + description=description)) + pb_via_op = self.normalize_summary_pb( + self.pb_via_op(op, feed_dict=feed_dict)) + self.assertProtoEquals(pb, pb_via_op) + return pb + + def verify_float_arrays_are_equal(self, expected, actual): # We use an absolute error instead of a relative one because the expected # values are small. The default relative error (trol) of 1e-7 yields many # undesired test failures. np.testing.assert_allclose( - expected_value, tensor_nd_array, rtol=0, atol=1e-7) - - def testWeight1(self): - self.generateDemoData() - multiplexer = self.createMultiplexer() - - # Verify that the metadata was correctly written. - accumulator = multiplexer.GetAccumulator('colors') - tag_content_dict = accumulator.PluginTagToContent('pr_curves') - - # Test the summary contents. - expected_tags = ['red/pr_curves', 'green/pr_curves', 'blue/pr_curves'] - self.assertItemsEqual(expected_tags, list(tag_content_dict.keys())) - - for tag in expected_tags: - # Parse the data within the JSON string and set the proto's fields. - plugin_data = metadata.parse_plugin_metadata(tag_content_dict[tag]) - self.assertEqual(5, plugin_data.num_thresholds) - - # Test the summary contents. - tensor_events = accumulator.Tensors(tag) - self.assertEqual(3, len(tensor_events)) - - # Test the output for the red classifier. The red classifier has the - # narrowest standard deviation. - tensor_events = accumulator.Tensors('red/pr_curves') - self.validateTensorEvent(0, [ - [100.0, 45.0, 11.0, 2.0, 0.0], # True positives. - [350.0, 50.0, 11.0, 2.0, 0.0], # False positives. - [0.0, 300.0, 339.0, 348.0, 350.0], # True negatives. - [0.0, 55.0, 89.0, 98.0, 100.0], # False negatives. - [0.2222222, 0.4736842, 0.5, 0.5, 0.0], # Precision. - [1.0, 0.45, 0.11, 0.02, 0.0], # Recall. - ], tensor_events[0]) - self.validateTensorEvent(1, [ - [100.0, 41.0, 11.0, 1.0, 0.0], # True positives. - [350.0, 48.0, 7.0, 1.0, 0.0], # False positives. - [0.0, 302.0, 343.0, 349.0, 350.0], # True negatives. - [0.0, 59.0, 89.0, 99.0, 100.0], # False negatives. - [0.2222222, 0.4606742, 0.6111111, 0.5, 0.0], # Precision. - [1.0, 0.41, 0.11, 0.01, 0.0], # Recall. - ], tensor_events[1]) - self.validateTensorEvent(2, [ - [100.0, 39.0, 11.0, 2.0, 0.0], # True positives. - [350.0, 54.0, 13.0, 1.0, 0.0], # False positives. - [0.0, 296.0, 337.0, 349.0, 350.0], # True negatives. 
- [0.0, 61.0, 89.0, 98.0, 100.0], # False negatives. - [0.2222222, 0.4193548, 0.4583333, 0.6666667, 0.0], # Precision. - [1.0, 0.39, 0.11, 0.02, 0.0], # Recall. - ], tensor_events[2]) - - # Test the output for the green classifier. - tensor_events = accumulator.Tensors('green/pr_curves') - self.validateTensorEvent(0, [ - [200.0, 125.0, 48.0, 7.0, 0.0], # True positives. - [250.0, 100.0, 13.0, 2.0, 0.0], # False positives. - [0.0, 150.0, 237.0, 248.0, 250.0], # True negatives. - [0.0, 75.0, 152.0, 193.0, 200.0], # False negatives. - [0.4444444, 0.5555556, 0.7868853, 0.7777778, 0.0], # Precision. - [1.0, 0.625, 0.24, 0.035, 0.0], # Recall. - ], tensor_events[0]) - self.validateTensorEvent(1, [ - [200.0, 123.0, 36.0, 7.0, 0.0], # True positives. - [250.0, 91.0, 18.0, 2.0, 0.0], # False positives. - [0.0, 159.0, 232.0, 248.0, 250.0], # True negatives. - [0.0, 77.0, 164.0, 193.0, 200.0], # False negatives. - [0.4444444, 0.5747663, 0.6666667, 0.7777778, 0.0], # Precision. - [1.0, 0.615, 0.18, 0.035, 0.0], # Recall. - ], tensor_events[1]) - self.validateTensorEvent(2, [ - [200.0, 116.0, 40.0, 5.0, 0.0], # True positives. - [250.0, 87.0, 18.0, 1.0, 0.0], # False positives. - [0.0, 163.0, 232.0, 249.0, 250.0], # True negatives. - [0.0, 84.0, 160.0, 195.0, 200.0], # False negatives. - [0.4444444, 0.5714286, 0.6896552, 0.8333333, 0.0], # Precision. - [1.0, 0.58, 0.2, 0.025, 0.0], # Recall. - ], tensor_events[2]) - - # Test the output for the blue classifier. The normal distribution that is - # the blue classifier has the widest standard deviation. - tensor_events = accumulator.Tensors('blue/pr_curves') - self.validateTensorEvent(0, [ - [150.0, 126.0, 45.0, 6.0, 0.0], # True positives. - [300.0, 201.0, 38.0, 2.0, 0.0], # False positives. - [0.0, 99.0, 262.0, 298.0, 300.0], # True negatives. - [0.0, 24.0, 105.0, 144.0, 150.0], # False negatives. - [0.3333333, 0.3853211, 0.5421687, 0.75, 0.0], # Precision. - [1.0, 0.84, 0.3, 0.04, 0.0], # Recall. - ], tensor_events[0]) - self.validateTensorEvent(1, [ - [150.0, 128.0, 45.0, 4.0, 0.0], # True positives. - [300.0, 204.0, 39.0, 6.0, 0.0], # False positives. - [0.0, 96.0, 261.0, 294.0, 300.0], # True negatives. - [0.0, 22.0, 105.0, 146.0, 150.0], # False negatives. - [0.3333333, 0.3855422, 0.5357143, 0.4, 0.0], # Precision. - [1.0, 0.8533334, 0.3, 0.0266667, 0.0], # Recall. - ], tensor_events[1]) - self.validateTensorEvent(2, [ - [150.0, 120.0, 39.0, 4.0, 0.0], # True positives. - [300.0, 185.0, 38.0, 2.0, 0.0], # False positives. - [0.0, 115.0, 262.0, 298.0, 300.0], # True negatives. - [0.0, 30.0, 111.0, 146.0, 150.0], # False negatives. - [0.3333333, 0.3934426, 0.5064935, 0.6666667, 0.0], # Precision. - [1.0, 0.8, 0.26, 0.0266667, 0.0], # Recall. - ], tensor_events[2]) - - def testExplicitWeights(self): - self.generateDemoData() - multiplexer = self.createMultiplexer() - - # Verify that the metadata was correctly written. - accumulator = multiplexer.GetAccumulator('mask_every_other_prediction') - tag_content_dict = accumulator.PluginTagToContent('pr_curves') - - # Test the summary contents. - expected_tags = ['red/pr_curves', 'green/pr_curves', 'blue/pr_curves'] - self.assertItemsEqual(expected_tags, list(tag_content_dict.keys())) - - for tag in expected_tags: - # Parse the data within the JSON string and set the proto's fields. - plugin_data = metadata.parse_plugin_metadata(tag_content_dict[tag]) - self.assertEqual(5, plugin_data.num_thresholds) - - # Test the summary contents. 
- tensor_events = accumulator.Tensors(tag) - self.assertEqual(3, len(tensor_events)) - - # Test the output for the red classifier. The red classifier has the - # narrowest standard deviation. - tensor_events = accumulator.Tensors('red/pr_curves') - self.validateTensorEvent(0, [ - [50.0, 22.0, 4.0, 0.0, 0.0], # True positives. - [175.0, 22.0, 6.0, 1.0, 0.0], # False positives. - [0.0, 153.0, 169.0, 174.0, 175.0], # True negatives. - [0.0, 28.0, 46.0, 50.0, 50.0], # False negatives. - [0.2222222, 0.5, 0.4, 0.0, 0.0], # Precision. - [1.0, 0.44, 0.08, 0.0, 0.0], # Recall. - ], tensor_events[0]) - self.validateTensorEvent(1, [ - [50.0, 17.0, 5.0, 1.0, 0.0], # True positives. - [175.0, 28.0, 1.0, 0.0, 0.0], # False positives. - [0.0, 147.0, 174.0, 175.0, 175.0], # True negatives. - [0.0, 33.0, 45.0, 49.0, 50.0], # False negatives. - [0.2222222, 0.3777778, 0.8333333, 1.0, 0.0], # Precision. - [1.0, 0.34, 0.1, 0.02, 0.0], # Recall. - ], tensor_events[1]) - self.validateTensorEvent(2, [ - [50.0, 18.0, 6.0, 1.0, 0.0], # True positives. - [175.0, 27.0, 6.0, 0.0, 0.0], # False positives. - [0.0, 148.0, 169.0, 175.0, 175.0], # True negatives. - [0.0, 32.0, 44.0, 49.0, 50.0], # False negatives. - [0.2222222, 0.4, 0.5, 1.0, 0.0], # Precision. - [1.0, 0.36, 0.12, 0.02, 0.0], # Recall. - ], tensor_events[2]) - - # Test the output for the green classifier. - tensor_events = accumulator.Tensors('green/pr_curves') - self.validateTensorEvent(0, [ - [100.0, 71.0, 24.0, 2.0, 0.0], # True positives. - [125.0, 51.0, 5.0, 2.0, 0.0], # False positives. - [0.0, 74.0, 120.0, 123.0, 125.0], # True negatives. - [0.0, 29.0, 76.0, 98.0, 100.0], # False negatives. - [0.4444444, 0.5819672, 0.8275862, 0.5, 0.0], # Precision. - [1.0, 0.71, 0.24, 0.02, 0.0], # Recall. - ], tensor_events[0]) - self.validateTensorEvent(1, [ - [100.0, 63.0, 20.0, 5.0, 0.0], # True positives. - [125.0, 42.0, 7.0, 1.0, 0.0], # False positives. - [0.0, 83.0, 118.0, 124.0, 125.0], # True negatives. - [0.0, 37.0, 80.0, 95.0, 100.0], # False negatives. - [0.4444444, 0.6, 0.7407407, 0.8333333, 0.0], # Precision. - [1.0, 0.63, 0.2, 0.05, 0.0], # Recall. - ], tensor_events[1]) - self.validateTensorEvent(2, [ - [100.0, 58.0, 19.0, 2.0, 0.0], # True positives. - [125.0, 40.0, 7.0, 0.0, 0.0], # False positives. - [0.0, 85.0, 118.0, 125.0, 125.0], # True negatives. - [0.0, 42.0, 81.0, 98.0, 100.0], # False negatives. - [0.4444444, 0.5918368, 0.7307692, 1.0, 0.0], # Precision. - [1.0, 0.58, 0.19, 0.02, 0.0], # Recall. - ], tensor_events[2]) - - # Test the output for the blue classifier. The normal distribution that is - # the blue classifier has the widest standard deviation. - tensor_events = accumulator.Tensors('blue/pr_curves') - self.validateTensorEvent(0, [ - [75.0, 64.0, 21.0, 5.0, 0.0], # True positives. - [150.0, 105.0, 18.0, 0.0, 0.0], # False positives. - [0.0, 45.0, 132.0, 150.0, 150.0], # True negatives. - [0.0, 11.0, 54.0, 70.0, 75.0], # False negatives. - [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0], # Precision. - [1.0, 0.8533334, 0.28, 0.0666667, 0.0], # Recall. - ], tensor_events[0]) - self.validateTensorEvent(1, [ - [75.0, 62.0, 21.0, 1.0, 0.0], # True positives. - [150.0, 99.0, 21.0, 3.0, 0.0], # False positives. - [0.0, 51.0, 129.0, 147.0, 150.0], # True negatives. - [0.0, 13.0, 54.0, 74.0, 75.0], # False negatives. - [0.3333333, 0.3850932, 0.5, 0.25, 0.0], # Precision. - [1.0, 0.8266667, 0.28, 0.0133333, 0.0], # Recall. - ], tensor_events[1]) - self.validateTensorEvent(2, [ - [75.0, 61.0, 16.0, 2.0, 0.0], # True positives. 
- [150.0, 92.0, 20.0, 1.0, 0.0], # False positives. - [0.0, 58.0, 130.0, 149.0, 150.0], # True negatives. - [0.0, 14.0, 59.0, 73.0, 75.0], # False negatives. - [0.3333333, 0.3986928, 0.4444444, 0.6666667, 0.0], # Precision. - [1.0, 0.8133333, 0.2133333, 0.0266667, 0.0], # Recall. - ], tensor_events[2]) - - def testRawDataOp(self): - with tf.summary.FileWriter(self.logdir) as writer, tf.Session() as sess: - # We pass raw counts and precision/recall values. - writer.add_summary(sess.run(summary.raw_data_op( - tag='foo', - true_positive_counts=tf.constant([75, 64, 21, 5, 0]), - false_positive_counts=tf.constant([150, 105, 18, 0, 0]), - true_negative_counts=tf.constant([0, 45, 132, 150, 150]), - false_negative_counts=tf.constant([0, 11, 54, 70, 75]), - precision=tf.constant( - [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]), - recall=tf.constant([1.0, 0.8533334, 0.28, 0.0666667, 0.0]), - num_thresholds=5, - display_name='some_raw_values', - description='We passed raw values into a summary op.'))) - - multiplexer = self.createMultiplexer() - accumulator = multiplexer.GetAccumulator('.') - tag_content_dict = accumulator.PluginTagToContent('pr_curves') - self.assertItemsEqual(['foo/pr_curves'], list(tag_content_dict.keys())) + expected, actual, rtol=0, atol=1e-7) + + def test_metadata(self): + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([True]), + predictions=np.float32([0.42]), + num_thresholds=3) + summary_metadata = pb.value[0].metadata + plugin_data = summary_metadata.plugin_data + self.assertEqual('foo', summary_metadata.display_name) + self.assertEqual('', summary_metadata.summary_description) + self.assertEqual(metadata.PLUGIN_NAME, plugin_data.plugin_name) + plugin_data = metadata.parse_plugin_metadata( + summary_metadata.plugin_data.content) + self.assertEqual(3, plugin_data.num_thresholds) + + def test_all_true_positives(self): + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([True]), + predictions=np.float32([1]), + num_thresholds=3) + expected = [ + [1.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_all_true_negatives(self): + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([False]), + predictions=np.float32([0]), + num_thresholds=3) + expected = [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_all_false_positives(self): + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([False]), + predictions=np.float32([1]), + num_thresholds=3) + expected = [ + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_all_false_negatives(self): + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([True]), + predictions=np.float32([0]), + num_thresholds=3) + expected = [ + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_many_values(self): + pb = 
self.compute_and_check_summary_pb( + name='foo', + labels=np.array([True, False, False, True, True, True]), + predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), + num_thresholds=3) + expected = [ + [4.0, 3.0, 0.0], + [2.0, 0.0, 0.0], + [0.0, 2.0, 2.0], + [0.0, 1.0, 4.0], + [2.0 / 3.0, 1.0, 0.0], + [1.0, 0.75, 0.0], + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_many_values_with_weights(self): + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([True, False, False, True, True, True]), + predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), + num_thresholds=3, + weights=np.float32([0.0, 0.5, 2.0, 0.0, 0.5, 1.0])) + expected = [ + [1.5, 1.5, 0.0], + [2.5, 0.0, 0.0], + [0.0, 2.5, 2.5], + [0.0, 0.0, 1.5], + [0.375, 1.0, 0.0], + [1.0, 1.0, 0.0] + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_exhaustive_random_values(self): + # Most other tests use small and crafted predictions and labels. + # This test exhaustively generates many data points. + data_points = 420 + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.random.uniform(size=(data_points,)) > 0.5, + predictions=np.float32(np.random.uniform(size=(data_points,))), + num_thresholds=5) + expected = [ + [218.0, 162.0, 111.0, 55.0, 0.0], + [202.0, 148.0, 98.0, 51.0, 0.0], + [0.0, 54.0, 104.0, 151.0, 202.0], + [0.0, 56.0, 107.0, 163.0, 218.0], + [0.5190476, 0.5225806, 0.5311005, 0.5188679, 0.0], + [1.0, 0.7431192, 0.5091743, 0.2522936, 0.0] + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_raw_data_op(self): + # We pass raw counts and precision/recall values. + op = summary.raw_data_op( + name='foo', + true_positive_counts=tf.constant([75, 64, 21, 5, 0]), + false_positive_counts=tf.constant([150, 105, 18, 0, 0]), + true_negative_counts=tf.constant([0, 45, 132, 150, 150]), + false_negative_counts=tf.constant([0, 11, 54, 70, 75]), + precision=tf.constant( + [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]), + recall=tf.constant([1.0, 0.8533334, 0.28, 0.0666667, 0.0]), + num_thresholds=5, + display_name='some_raw_values', + description='We passed raw values into a summary op.') + pb = self.pb_via_op(op) # Test the metadata. - summary_metadata = multiplexer.SummaryMetadata('.', 'foo/pr_curves') + summary_metadata = pb.value[0].metadata self.assertEqual('some_raw_values', summary_metadata.display_name) self.assertEqual( 'We passed raw values into a summary op.', summary_metadata.summary_description) + self.assertEqual( + metadata.PLUGIN_NAME, summary_metadata.plugin_data.plugin_name) - # Test the stored plugin data. plugin_data = metadata.parse_plugin_metadata( - tag_content_dict['foo/pr_curves']) + summary_metadata.plugin_data.content) self.assertEqual(5, plugin_data.num_thresholds) # Test the summary contents. - tensor_events = accumulator.Tensors('foo/pr_curves') - self.assertEqual(1, len(tensor_events)) - self.validateTensorEvent(0, [ + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal([ [75.0, 64.0, 21.0, 5.0, 0.0], # True positives. [150.0, 105.0, 18.0, 0.0, 0.0], # False positives. [0.0, 45.0, 132.0, 150.0, 150.0], # True negatives. [0.0, 11.0, 54.0, 70.0, 75.0], # False negatives. [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0], # Precision. [1.0, 0.8533334, 0.28, 0.0666667, 0.0], # Recall. 
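Reviewer note: the renamed `streaming_op` keeps its contract. It returns a `(pr_curve, update_op)` pair backed by `tf.metrics` accumulators, so the summary only reflects data after `update_op` runs, which is what `test_matches_op_with_updates` exercises. A minimal usage sketch, with hypothetical tensors, in graph-mode TF 1.x as used by these tests:

```python
import tensorflow as tf
from tensorboard.plugins.pr_curve import summary

labels = tf.constant([False, True, True, False, True])
predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32)

pr_curve, update_op = summary.streaming_op(
    name='pr_curve', labels=labels, predictions=predictions, num_thresholds=10)

with tf.Session() as sess:
  # The streaming counts live in local variables, so initialize them first.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)              # Accumulate one batch of counts.
  serialized = sess.run(pr_curve)  # Serialized Summary reflecting the batch.
```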
diff --git a/tensorboard/summary.py b/tensorboard/summary.py
index a20600d85c..96da031f00 100644
--- a/tensorboard/summary.py
+++ b/tensorboard/summary.py
@@ -39,6 +39,7 @@
 image_pb = _image_summary.pb
 
 pr_curve = _pr_curve_summary.op
+pr_curve_pb = _pr_curve_summary.pb
 pr_curve_streaming_op = _pr_curve_summary.streaming_op
 pr_curve_raw_data = _pr_curve_summary.raw_data_op
 
diff --git a/tensorboard/summary_test.py b/tensorboard/summary_test.py
index c042e4b0d4..f97bf20213 100644
--- a/tensorboard/summary_test.py
+++ b/tensorboard/summary_test.py
@@ -37,14 +37,6 @@
     'text',
 ])
 
-# The subset of `STANDARD_PLUGINS` for which we do not currently have
-# functions to generate a summary protobuf outside of a TensorFlow
-# graph. This set should ideally be empty; any entries here should be
-# considered temporary.
-PLUGINS_WITHOUT_PB_FUNCTIONS = frozenset([
-    'pr_curve',  # TODO(@chihuahua, #445): Fix this.
-])
-
 
 class SummaryExportsTest(tf.test.TestCase):
 
@@ -54,9 +46,8 @@ def test_each_plugin_has_an_export(self):
 
   def test_plugins_export_pb_functions(self):
     for plugin in STANDARD_PLUGINS:
-      if plugin not in PLUGINS_WITHOUT_PB_FUNCTIONS:
-        self.assertIsInstance(
-            getattr(summary, '%s_pb' % plugin), collections.Callable)
+      self.assertIsInstance(
+          getattr(summary, '%s_pb' % plugin), collections.Callable)
 
   def test_all_exports_correspond_to_plugins(self):
     exports = [name for name in dir(summary) if not name.startswith('_')]
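With `pr_curve_pb` exported, every standard plugin now has a `*_pb` function, which lets callers build a PR curve summary without constructing a graph. A hedged sketch of how a caller might log one; the numeric inputs and the logdir path are hypothetical, and the `FileWriter` pattern is standard TF 1.x:

```python
import numpy as np
import tensorflow as tf
from tensorboard import summary as tb_summary

# Hypothetical evaluation results.
labels = np.array([True, False, True, False])
predictions = np.float32([0.8, 0.4, 0.6, 0.1])

# Build the Summary proto in pure numpy -- no Session or graph needed.
summary_proto = tb_summary.pr_curve_pb(
    name='my_classifier', labels=labels, predictions=predictions,
    num_thresholds=5)

# A Summary proto can be handed directly to a FileWriter.
writer = tf.summary.FileWriter('/tmp/pr_demo')
writer.add_summary(summary_proto, global_step=0)
writer.close()
```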