Skip to content

Commit 5806dd7

Browse files
authored
Lower _MINIMUM_COUNT to 1e-7 (#644)
PR curve summaries had used 1.0 as the minimum division value while computing precision and recall. That actually does not work because certain weights can cause TP, FP, TN, and FN values to be below 1.0. This change sets that minimum to 1e-7 and adds a test.
1 parent fa1ee26 commit 5806dd7

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

Diff for: tensorboard/plugins/pr_curve/summary.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from tensorboard.plugins.pr_curve import metadata
2929

3030
# A value that we use as the minimum value during division of counts to prevent
31-
# division by 0. 1 suffices because counts of course must be whole numbers.
32-
_MINIMUM_COUNT = 1.0
31+
# division by 0. 1.0 does not work: Certain weights could cause counts below 1.
32+
_MINIMUM_COUNT = 1e-7
3333

3434
# The default number of thresholds.
3535
_DEFAULT_NUM_THRESHOLDS = 201

Diff for: tensorboard/plugins/pr_curve/summary_test.py

+22
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,28 @@ def test_exhaustive_random_values(self):
239239
values = tf.make_ndarray(pb.value[0].tensor)
240240
self.verify_float_arrays_are_equal(expected, values)
241241

242+
def test_counts_below_1(self):
243+
"""Tests support for counts below 1.
244+
245+
Certain weights cause TP, FP, TN, FN counts to be below 1.
246+
"""
247+
pb = self.compute_and_check_summary_pb(
248+
name='foo',
249+
labels=np.array([True, False, False, True, True, True]),
250+
predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]),
251+
num_thresholds=3,
252+
weights=np.float32([0.0, 0.1, 0.2, 0.1, 0.1, 0.0]))
253+
expected = [
254+
[0.2, 0.2, 0.0],
255+
[0.3, 0.0, 0.0],
256+
[0.0, 0.3, 0.3],
257+
[0.0, 0.0, 0.2],
258+
[0.4, 1.0, 0.0],
259+
[1.0, 1.0, 0.0]
260+
]
261+
values = tf.make_ndarray(pb.value[0].tensor)
262+
self.verify_float_arrays_are_equal(expected, values)
263+
242264
def test_raw_data_op(self):
243265
# We pass raw counts and precision/recall values.
244266
op = summary.raw_data_op(

0 commit comments

Comments
 (0)