diff --git a/tensorboard/plugins/pr_curve/summary.py b/tensorboard/plugins/pr_curve/summary.py index 099096e167..4a9625b80c 100644 --- a/tensorboard/plugins/pr_curve/summary.py +++ b/tensorboard/plugins/pr_curve/summary.py @@ -28,8 +28,8 @@ from tensorboard.plugins.pr_curve import metadata # A value that we use as the minimum value during division of counts to prevent -# division by 0. 1 suffices because counts of course must be whole numbers. -_MINIMUM_COUNT = 1.0 +# division by 0. 1.0 does not work: Certain weights could cause counts below 1. +_MINIMUM_COUNT = 1e-7 # The default number of thresholds. _DEFAULT_NUM_THRESHOLDS = 200 diff --git a/tensorboard/plugins/pr_curve/summary_test.py b/tensorboard/plugins/pr_curve/summary_test.py index 02ac78cad5..3afea2b71a 100644 --- a/tensorboard/plugins/pr_curve/summary_test.py +++ b/tensorboard/plugins/pr_curve/summary_test.py @@ -239,6 +239,28 @@ def test_exhaustive_random_values(self): values = tf.make_ndarray(pb.value[0].tensor) self.verify_float_arrays_are_equal(expected, values) + def test_counts_below_1(self): + """Tests support for counts below 1. + + Certain weights cause TP, FP, TN, FN counts to be below 1. + """ + pb = self.compute_and_check_summary_pb( + name='foo', + labels=np.array([True, False, False, True, True, True]), + predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), + num_thresholds=3, + weights=np.float32([0.0, 0.1, 0.2, 0.1, 0.1, 0.0])) + expected = [ + [0.2, 0.2, 0.0], + [0.3, 0.0, 0.0], + [0.0, 0.3, 0.3], + [0.0, 0.0, 0.2], + [0.4, 1.0, 0.0], + [1.0, 1.0, 0.0] + ] + values = tf.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + def test_raw_data_op(self): # We pass raw counts and precision/recall values. op = summary.raw_data_op(