Skip to content

Commit b840408

Browse files
committed
Add pb method for PR curves
This change adds a `pb` implementation for the PR curves summary, which, like all `pb` implementations, lets users generate summaries without having to use TensorFlow. Also modified the test for PR curve summaries to use small test cases that are easy for a developer to reason through instead of using the demo data. This allows us to use the compute_and_check_summary_pb paradigm for the PR curves summary test, just like for other plugins. Updated the summary module to include `pr_curve_pb`. Fixes #445.
1 parent 2615072 commit b840408

File tree

5 files changed

+275
-272
lines changed

5 files changed

+275
-272
lines changed

Diff for: tensorboard/plugins/pr_curve/BUILD

+1-3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ py_library(
5959
visibility = ["//visibility:public"],
6060
deps = [
6161
":metadata",
62+
"//tensorboard:expect_numpy_installed",
6263
"//tensorboard:expect_tensorflow_installed",
6364
],
6465
)
@@ -69,12 +70,9 @@ py_test(
6970
srcs = ["summary_test.py"],
7071
srcs_version = "PY2AND3",
7172
deps = [
72-
":pr_curve_demo",
7373
":summary",
7474
"//tensorboard:expect_numpy_installed",
7575
"//tensorboard:expect_tensorflow_installed",
76-
"//tensorboard/backend:application",
77-
"//tensorboard/backend/event_processing:event_multiplexer",
7876
"//tensorboard/plugins:base_plugin",
7977
"@org_pocoo_werkzeug",
8078
"@org_pythonhosted_six",

Diff for: tensorboard/plugins/pr_curve/summary.py

+75-3
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@
2222
from __future__ import division
2323
from __future__ import print_function
2424

25+
import numpy as np
2526
import tensorflow as tf
2627

2728
from tensorboard.plugins.pr_curve import metadata
2829

2930
# A value that we use as the minimum value during division of counts to prevent
30-
# division by 0. 1 suffices because counts of course must be whole numbers.
31-
_MINIMUM_COUNT = 1.0
31+
# division by 0.
32+
_MINIMUM_COUNT = 1e-7
33+
34+
# The default number of thresholds.
35+
_DEFAULT_NUM_THRESHOLDS = 200
3236

3337
def op(
3438
tag,
@@ -78,7 +82,7 @@ def op(
7882
7983
"""
8084
if num_thresholds is None:
81-
num_thresholds = 200
85+
num_thresholds = _DEFAULT_NUM_THRESHOLDS
8286

8387
if weights is None:
8488
weights = 1.0
@@ -164,6 +168,74 @@ def op(
164168
description,
165169
collections)
166170

171+
def pb(tag,
       labels,
       predictions,
       num_thresholds=None,
       weights=None,
       display_name=None,
       description=None):
  """Creates a PR curves summary protobuf.

  Arguments:
    tag: A name for the generated node. Will also serve as a series name in
      TensorBoard.
    labels: The ground truth values. A bool numpy array (or array-like).
    predictions: A float32 numpy array (or array-like) whose values are in the
      range `[0, 1]`. Dimensions must match those of `labels`.
    num_thresholds: Optional number of thresholds, evenly distributed in
      `[0, 1]`, to compute PR metrics for. Should be `>= 2`. This value should
      be a python int. Defaults to 200.
    weights: Optional python float or float32 numpy array. Individual counts are
      multiplied by this value. This array must be either the same shape as
      or broadcastable to the `labels` numpy array.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `tag`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.

  Returns:
    A `tf.Summary` protobuf with a single value tagged `<tag>/pr_curves`. Its
    tensor has shape `(6, num_thresholds)`; the rows are, in order, true
    positives, false positives, true negatives, false negatives, precision,
    and recall.
  """
  if num_thresholds is None:
    num_thresholds = _DEFAULT_NUM_THRESHOLDS

  if weights is None:
    weights = 1.0

  # Coerce the inputs so that plain python sequences work as well as arrays.
  labels = np.asarray(labels)
  predictions = np.asarray(predictions)

  # Compute bins of true positives and false positives.
  bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  # `np.float` was merely an alias of the builtin `float` (float64) and has
  # been removed from NumPy, so name the concrete dtype explicitly.
  float_labels = labels.astype(np.float64)
  histogram_range = (0, num_thresholds - 1)
  tp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=float_labels * weights)
  fp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=(1.0 - float_labels) * weights)

  # Obtain the reverse cumulative sum: entry i counts every example whose
  # prediction falls in bucket i or any higher bucket.
  tp = np.cumsum(tp_buckets[::-1])[::-1]
  fp = np.cumsum(fp_buckets[::-1])[::-1]
  # Index 0 holds the totals, so the remainder at each threshold is what falls
  # below it.
  tn = fp[0] - fp
  fn = tp[0] - tp
  # Clamp the denominators away from zero to avoid division by 0.
  precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

  if display_name is None:
    display_name = tag
  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name,
      description=description or '',
      num_thresholds=num_thresholds)
  summary = tf.Summary()
  data = np.stack((tp, fp, tn, fn, precision, recall))
  tensor = tf.make_tensor_proto(data, dtype=tf.float32)
  summary.value.add(tag='%s/pr_curves' % tag,
                    metadata=summary_metadata,
                    tensor=tensor)
  return summary
167239

168240
def streaming_op(tag,
169241
labels,

0 commit comments

Comments
 (0)