Skip to content

Commit 18b7e41

Browse files
authored
Add pb method for PR curves (#633)
This change adds a `pb` implementation for the PR curves summary, which, like all `pb` implementations, lets users generate summaries directly from NumPy data without having to construct a TensorFlow graph or run a session. It also modifies the test for PR curve summaries to use small test cases that are easy for a developer to reason through instead of using the demo data, which allows us to use the `compute_and_check_summary_pb` paradigm for the PR curves summary test, just like for other plugins. The summary module is updated to include `pr_curve_pb`. Fixes #445. As part of this change, the `tag` parameter of the summary ops is renamed to `name` to be consistent with other summaries.
1 parent f92092e commit 18b7e41

File tree

6 files changed

+322
-300
lines changed

6 files changed

+322
-300
lines changed

Diff for: tensorboard/plugins/pr_curve/BUILD

+1-3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ py_library(
5959
visibility = ["//visibility:public"],
6060
deps = [
6161
":metadata",
62+
"//tensorboard:expect_numpy_installed",
6263
"//tensorboard:expect_tensorflow_installed",
6364
],
6465
)
@@ -69,12 +70,9 @@ py_test(
6970
srcs = ["summary_test.py"],
7071
srcs_version = "PY2AND3",
7172
deps = [
72-
":pr_curve_demo",
7373
":summary",
7474
"//tensorboard:expect_numpy_installed",
7575
"//tensorboard:expect_tensorflow_installed",
76-
"//tensorboard/backend:application",
77-
"//tensorboard/backend/event_processing:event_multiplexer",
7876
"//tensorboard/plugins:base_plugin",
7977
"@org_pocoo_werkzeug",
8078
"@org_pythonhosted_six",

Diff for: tensorboard/plugins/pr_curve/pr_curve_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def start_runs(
161161
weights = tf.cast(consecutive_indices % 2, dtype=tf.float32)
162162

163163
summary.op(
164-
tag=color,
164+
name=color,
165165
labels=labels[:, i],
166166
predictions=predictions[i],
167167
num_thresholds=thresholds,

Diff for: tensorboard/plugins/pr_curve/summary.py

+87-15
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import division
2323
from __future__ import print_function
2424

25+
import numpy as np
2526
import tensorflow as tf
2627

2728
from tensorboard.plugins.pr_curve import metadata
@@ -30,8 +31,11 @@
3031
# division by 0. 1 suffices because counts of course must be whole numbers.
3132
_MINIMUM_COUNT = 1.0
3233

34+
# The default number of thresholds.
35+
_DEFAULT_NUM_THRESHOLDS = 200
36+
3337
def op(
34-
tag,
38+
name,
3539
labels,
3640
predictions,
3741
num_thresholds=None,
@@ -51,7 +55,7 @@ def op(
5155
used to reweight certain values, or more commonly used for masking values.
5256
5357
Args:
54-
tag: A tag attached to the summary. Used by TensorBoard for organization.
58+
name: A tag attached to the summary. Used by TensorBoard for organization.
5559
labels: The ground truth values. A Tensor of `bool` values with arbitrary
5660
shape.
5761
predictions: A float32 `Tensor` whose values are in the range `[0, 1]`.
@@ -78,14 +82,14 @@ def op(
7882
7983
"""
8084
if num_thresholds is None:
81-
num_thresholds = 200
85+
num_thresholds = _DEFAULT_NUM_THRESHOLDS
8286

8387
if weights is None:
8488
weights = 1.0
8589

8690
dtype = predictions.dtype
8791

88-
with tf.name_scope(tag, values=[labels, predictions, weights]):
92+
with tf.name_scope(name, values=[labels, predictions, weights]):
8993
tf.assert_type(labels, tf.bool)
9094
# We cast to float to ensure we have 0.0 or 1.0.
9195
f_labels = tf.cast(labels, dtype)
@@ -152,7 +156,7 @@ def op(
152156
recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)
153157

154158
return _create_tensor_summary(
155-
tag,
159+
name,
156160
tp,
157161
fp,
158162
tn,
@@ -164,8 +168,76 @@ def op(
164168
description,
165169
collections)
166170

171+
def pb(name,
       labels,
       predictions,
       num_thresholds=None,
       weights=None,
       display_name=None,
       description=None):
  """Create a PR curves summary protobuf.

  Arguments:
    name: A name for the generated node. Will also serve as a series name in
      TensorBoard.
    labels: The ground truth values. A bool numpy array.
    predictions: A float32 numpy array whose values are in the range `[0, 1]`.
      Dimensions must match those of `labels`.
    num_thresholds: Optional number of thresholds, evenly distributed in
      `[0, 1]`, to compute PR metrics for. When provided, should be an int of
      value at least 2. Defaults to 200.
    weights: Optional float or float32 numpy array. Individual counts are
      multiplied by this value. This array must be either the same shape as
      or broadcastable to the `labels` numpy array.
    display_name: Optional name for this summary in TensorBoard, as a `str`.
      Defaults to `name`.
    description: Optional long-form description for this summary, as a `str`.
      Markdown is supported. Defaults to empty.

  Returns:
    A `tf.Summary` protobuf with a single value tagged `<name>/pr_curves`.
    Its tensor is a float32 array of shape `(6, num_thresholds)` whose rows
    are, in order: true positives, false positives, true negatives, false
    negatives, precision, and recall.
  """
  if num_thresholds is None:
    num_thresholds = _DEFAULT_NUM_THRESHOLDS

  if weights is None:
    weights = 1.0

  # Compute bins of true positives and false positives. Each prediction is
  # mapped to the index of the highest threshold it still reaches.
  bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  # `np.float` was only an alias for the builtin `float` (i.e. float64) and
  # has been removed from NumPy; name the dtype explicitly instead.
  float_labels = labels.astype(np.float64)
  histogram_range = (0, num_thresholds - 1)
  tp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=float_labels * weights)
  fp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=(1.0 - float_labels) * weights)

  # Obtain the reverse cumulative sum: entry i accumulates every bucket at
  # index i or higher, i.e. all examples predicted positive at threshold i.
  tp = np.cumsum(tp_buckets[::-1])[::-1]
  fp = np.cumsum(fp_buckets[::-1])[::-1]
  # Index 0 covers every bucket, so fp[0] and tp[0] are the (weighted) totals
  # of negative and positive examples; whatever is not counted as predicted
  # positive at a threshold is a true/false negative there.
  tn = fp[0] - fp
  fn = tp[0] - tp
  # _MINIMUM_COUNT keeps the denominators away from zero.
  precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

  if display_name is None:
    display_name = name
  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name,
      description=description or '',
      num_thresholds=num_thresholds)
  summary = tf.Summary()
  data = np.stack((tp, fp, tn, fn, precision, recall))
  tensor = tf.make_tensor_proto(data, dtype=tf.float32)
  summary.value.add(tag='%s/pr_curves' % name,
                    metadata=summary_metadata,
                    tensor=tensor)
  return summary
239+
240+
def streaming_op(name,
169241
labels,
170242
predictions,
171243
num_thresholds=200,
@@ -186,7 +258,7 @@ def streaming_op(tag,
186258
updated with the returned update_op.
187259
188260
Args:
189-
tag: A tag attached to the summary. Used by TensorBoard for organization.
261+
name: A tag attached to the summary. Used by TensorBoard for organization.
190262
labels: The ground truth values, a `Tensor` whose dimensions must match
191263
`predictions`. Will be cast to `bool`.
192264
predictions: A floating point `Tensor` of arbitrary shape and whose values
@@ -216,7 +288,7 @@ def streaming_op(tag,
216288
thresholds = [i / float(num_thresholds - 1)
217289
for i in range(num_thresholds)]
218290

219-
with tf.name_scope(tag, values=[labels, predictions, weights]):
291+
with tf.name_scope(name, values=[labels, predictions, weights]):
220292
tp, update_tp = tf.metrics.true_positives_at_thresholds(
221293
labels=labels,
222294
predictions=predictions,
@@ -243,7 +315,7 @@ def compute_summary(tp, fp, tn, fn, collections):
243315
recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)
244316

245317
return _create_tensor_summary(
246-
tag,
318+
name,
247319
tp,
248320
fp,
249321
tn,
@@ -263,7 +335,7 @@ def compute_summary(tp, fp, tn, fn, collections):
263335

264336

265337
def raw_data_op(
266-
tag,
338+
name,
267339
true_positive_counts,
268340
false_positive_counts,
269341
true_negative_counts,
@@ -285,7 +357,7 @@ def raw_data_op(
285357
differently but still use the PR curves plugin.
286358
287359
Args:
288-
tag: A tag attached to the summary. Used by TensorBoard for organization.
360+
name: A tag attached to the summary. Used by TensorBoard for organization.
289361
true_positive_counts: A rank-1 tensor of true positive counts. Must contain
290362
`num_thresholds` elements and be castable to float32.
291363
false_positive_counts: A rank-1 tensor of false positive counts. Must
@@ -309,7 +381,7 @@ def raw_data_op(
309381
A summary operation for use in a TensorFlow graph. See docs for the `op`
310382
method for details on the float32 tensor produced by this summary.
311383
"""
312-
with tf.name_scope(tag, values=[
384+
with tf.name_scope(name, values=[
313385
true_positive_counts,
314386
false_positive_counts,
315387
true_negative_counts,
@@ -318,7 +390,7 @@ def raw_data_op(
318390
recall,
319391
]):
320392
return _create_tensor_summary(
321-
tag,
393+
name,
322394
true_positive_counts,
323395
false_positive_counts,
324396
true_negative_counts,
@@ -331,7 +403,7 @@ def raw_data_op(
331403
collections)
332404

333405
def _create_tensor_summary(
334-
tag,
406+
name,
335407
true_positive_counts,
336408
false_positive_counts,
337409
true_negative_counts,
@@ -355,7 +427,7 @@ def _create_tensor_summary(
355427
# Store the number of thresholds within the summary metadata because
356428
# that value is constant for all pr curve summaries with the same tag.
357429
summary_metadata = metadata.create_summary_metadata(
358-
display_name=display_name if display_name is not None else tag,
430+
display_name=display_name if display_name is not None else name,
359431
description=description or '',
360432
num_thresholds=num_thresholds)
361433

0 commit comments

Comments
 (0)