Skip to content

Commit 501f037

Browse files
committed
Introduce a raw_data_pb for PR curve summaries
This method lets users generate PR curve summaries from raw TP, FP, TN, FN, precision, and recall values without the use of TensorFlow.
1 parent 5806dd7 commit 501f037

File tree

3 files changed

+110
-29
lines changed

3 files changed

+110
-29
lines changed

Diff for: tensorboard/plugins/pr_curve/summary.py

+68-14
Original file line numberDiff line numberDiff line change
@@ -223,19 +223,16 @@ def pb(name,
223223
precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
224224
recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
225225

226-
if display_name is None:
227-
display_name = name
228-
summary_metadata = metadata.create_summary_metadata(
229-
display_name=display_name if display_name is not None else name,
230-
description=description or '',
231-
num_thresholds=num_thresholds)
232-
summary = tf.Summary()
233-
data = np.stack((tp, fp, tn, fn, precision, recall))
234-
tensor = tf.make_tensor_proto(data, dtype=tf.float32)
235-
summary.value.add(tag='%s/pr_curves' % name,
236-
metadata=summary_metadata,
237-
tensor=tensor)
238-
return summary
226+
return raw_data_pb(name,
227+
true_positive_counts=tp,
228+
false_positive_counts=fp,
229+
true_negative_counts=tn,
230+
false_negative_counts=fn,
231+
precision=precision,
232+
recall=recall,
233+
num_thresholds=num_thresholds,
234+
display_name=display_name,
235+
description=description)
239236

240237
def streaming_op(name,
241238
labels,
@@ -336,7 +333,6 @@ def compute_summary(tp, fp, tn, fn, collections):
336333

337334
return pr_curve, update_op
338335

339-
340336
def raw_data_op(
341337
name,
342338
true_positive_counts,
@@ -405,6 +401,64 @@ def raw_data_op(
405401
description,
406402
collections)
407403

404+
def raw_data_pb(
405+
name,
406+
true_positive_counts,
407+
false_positive_counts,
408+
true_negative_counts,
409+
false_negative_counts,
410+
precision,
411+
recall,
412+
num_thresholds=None,
413+
display_name=None,
414+
description=None):
415+
"""Create a PR curves summary protobuf from raw data values.
416+
417+
Args:
418+
name: A tag attached to the summary. Used by TensorBoard for organization.
419+
true_positive_counts: A rank-1 numpy array of true positive counts. Must
420+
contain `num_thresholds` elements and be castable to float32.
421+
false_positive_counts: A rank-1 numpy array of false positive counts. Must
422+
contain `num_thresholds` elements and be castable to float32.
423+
true_negative_counts: A rank-1 numpy array of true negative counts. Must
424+
contain `num_thresholds` elements and be castable to float32.
425+
false_negative_counts: A rank-1 numpy array of false negative counts. Must
426+
contain `num_thresholds` elements and be castable to float32.
427+
precision: A rank-1 numpy array of precision values. Must contain
428+
`num_thresholds` elements and be castable to float32.
429+
recall: A rank-1 numpy array of recall values. Must contain `num_thresholds`
430+
elements and be castable to float32.
431+
num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
432+
compute PR metrics for. Should be an int `>= 2`.
433+
display_name: Optional name for this summary in TensorBoard, as a `str`.
434+
Defaults to `name`.
435+
description: Optional long-form description for this summary, as a `str`.
436+
Markdown is supported. Defaults to empty.
437+
438+
Returns:
439+
A `tf.Summary` protobuf holding the PR curve data. See docs for the `op`
440+
method for details on the float32 tensor stored in this summary.
441+
"""
442+
if display_name is None:
443+
display_name = name
444+
summary_metadata = metadata.create_summary_metadata(
445+
display_name=display_name if display_name is not None else name,
446+
description=description or '',
447+
num_thresholds=num_thresholds)
448+
summary = tf.Summary()
449+
data = np.stack(
450+
(true_positive_counts,
451+
false_positive_counts,
452+
true_negative_counts,
453+
false_negative_counts,
454+
precision,
455+
recall))
456+
tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32)
457+
summary.value.add(tag='%s/pr_curves' % name,
458+
metadata=summary_metadata,
459+
tensor=tensor)
460+
return summary
461+
408462
def _create_tensor_summary(
409463
name,
410464
true_positive_counts,

Diff for: tensorboard/plugins/pr_curve/summary_test.py

+40-14
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,47 @@ def test_counts_below_1(self):
261261
values = tf.make_ndarray(pb.value[0].tensor)
262262
self.verify_float_arrays_are_equal(expected, values)
263263

264-
def test_raw_data_op(self):
265-
# We pass raw counts and precision/recall values.
264+
def test_raw_data(self):
265+
# We pass these raw counts and precision/recall values.
266+
name = 'foo'
267+
true_positive_counts = [75, 64, 21, 5, 0]
268+
false_positive_counts = [150, 105, 18, 0, 0]
269+
true_negative_counts = [0, 45, 132, 150, 150]
270+
false_negative_counts = [0, 11, 54, 70, 75]
271+
precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]
272+
recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0]
273+
num_thresholds = 5
274+
display_name = 'some_raw_values'
275+
description = 'We passed raw values into a summary op.'
276+
266277
op = summary.raw_data_op(
267-
name='foo',
268-
true_positive_counts=tf.constant([75, 64, 21, 5, 0]),
269-
false_positive_counts=tf.constant([150, 105, 18, 0, 0]),
270-
true_negative_counts=tf.constant([0, 45, 132, 150, 150]),
271-
false_negative_counts=tf.constant([0, 11, 54, 70, 75]),
272-
precision=tf.constant(
273-
[0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]),
274-
recall=tf.constant([1.0, 0.8533334, 0.28, 0.0666667, 0.0]),
275-
num_thresholds=5,
276-
display_name='some_raw_values',
277-
description='We passed raw values into a summary op.')
278-
pb = self.pb_via_op(op)
278+
name=name,
279+
true_positive_counts=tf.constant(true_positive_counts),
280+
false_positive_counts=tf.constant(false_positive_counts),
281+
true_negative_counts=tf.constant(true_negative_counts),
282+
false_negative_counts=tf.constant(false_negative_counts),
283+
precision=tf.constant(precision),
284+
recall=tf.constant(recall),
285+
num_thresholds=num_thresholds,
286+
display_name=display_name,
287+
description=description)
288+
pb_via_op = self.normalize_summary_pb(self.pb_via_op(op))
289+
290+
# Call the corresponding method that is decoupled from TensorFlow.
291+
pb = self.normalize_summary_pb(summary.raw_data_pb(
292+
name=name,
293+
true_positive_counts=true_positive_counts,
294+
false_positive_counts=false_positive_counts,
295+
true_negative_counts=true_negative_counts,
296+
false_negative_counts=false_negative_counts,
297+
precision=precision,
298+
recall=recall,
299+
num_thresholds=num_thresholds,
300+
display_name=display_name,
301+
description=description))
302+
303+
# The 2 methods above should write summaries with the same data.
304+
self.assertProtoEquals(pb, pb_via_op)
279305

280306
# Test the metadata.
281307
summary_metadata = pb.value[0].metadata

Diff for: tensorboard/summary.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
pr_curve = _pr_curve_summary.op
4242
pr_curve_pb = _pr_curve_summary.pb
4343
pr_curve_streaming_op = _pr_curve_summary.streaming_op
44-
pr_curve_raw_data = _pr_curve_summary.raw_data_op
44+
pr_curve_raw_data_op = _pr_curve_summary.raw_data_op
45+
pr_curve_raw_data_pb = _pr_curve_summary.raw_data_pb
4546

4647
scalar = _scalar_summary.op
4748
scalar_pb = _scalar_summary.pb

0 commit comments

Comments
 (0)