-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Add pb method for PR curves #633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9bf6095
b840408
0c4057a
10793bf
22f4e29
fef4b13
2e7f5a2
aaa3212
7c65864
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from tensorboard.plugins.pr_curve import metadata | ||
|
@@ -30,8 +31,11 @@ | |
# division by 0. 1 suffices because counts of course must be whole numbers. | ||
_MINIMUM_COUNT = 1.0 | ||
|
||
# The default number of thresholds. | ||
_DEFAULT_NUM_THRESHOLDS = 200 | ||
|
||
def op( | ||
tag, | ||
name, | ||
labels, | ||
predictions, | ||
num_thresholds=None, | ||
|
@@ -51,7 +55,7 @@ def op( | |
used to reweight certain values, or more commonly used for masking values. | ||
|
||
Args: | ||
tag: A tag attached to the summary. Used by TensorBoard for organization. | ||
name: A tag attached to the summary. Used by TensorBoard for organization. | ||
labels: The ground truth values. A Tensor of `bool` values with arbitrary | ||
shape. | ||
predictions: A float32 `Tensor` whose values are in the range `[0, 1]`. | ||
|
@@ -78,14 +82,14 @@ def op( | |
|
||
""" | ||
if num_thresholds is None: | ||
num_thresholds = 200 | ||
num_thresholds = _DEFAULT_NUM_THRESHOLDS | ||
|
||
if weights is None: | ||
weights = 1.0 | ||
|
||
dtype = predictions.dtype | ||
|
||
with tf.name_scope(tag, values=[labels, predictions, weights]): | ||
with tf.name_scope(name, values=[labels, predictions, weights]): | ||
tf.assert_type(labels, tf.bool) | ||
# We cast to float to ensure we have 0.0 or 1.0. | ||
f_labels = tf.cast(labels, dtype) | ||
|
@@ -152,7 +156,7 @@ def op( | |
recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) | ||
|
||
return _create_tensor_summary( | ||
tag, | ||
name, | ||
tp, | ||
fp, | ||
tn, | ||
|
@@ -164,8 +168,76 @@ def op( | |
description, | ||
collections) | ||
|
||
def pb(name,
       labels,
       predictions,
       num_thresholds=None,
       weights=None,
       display_name=None,
       description=None):
  """Create a PR curves summary protobuf.

  Unlike `op`, this function does not build TensorFlow graph ops; the PR
  data is computed eagerly with numpy.

  Arguments:
    name: A name for the generated node. Will also serve as a series name in
      TensorBoard.
    labels: The ground truth values. A bool numpy array.
    predictions: A float32 numpy array whose values are in the range `[0, 1]`.
      Dimensions must match those of `labels`.
    num_thresholds: Optional number of thresholds, evenly distributed in
      `[0, 1]`, to compute PR metrics for. When provided, should be an int of
      value at least 2. Defaults to 200.
    weights: Optional float or float32 numpy array. Individual counts are
      multiplied by this value. This value must be either the same shape as
      or broadcastable to the `labels` numpy array.
    display_name: Optional name for this summary in TensorBoard, as a `str`.
      Defaults to `name`.
    description: Optional long-form description for this summary, as a `str`.
      Markdown is supported. Defaults to empty.

  Returns:
    A `tf.Summary` protobuf with a single value tagged `<name>/pr_curves`.
    Its float32 tensor stacks six rows, in order: true positives, false
    positives, true negatives, false negatives, precision, and recall. Each
    row has `num_thresholds` entries.
  """
  if num_thresholds is None:
    num_thresholds = _DEFAULT_NUM_THRESHOLDS

  if weights is None:
    weights = 1.0

  if display_name is None:
    display_name = name

  # Compute bins of true positives and false positives. Each prediction is
  # mapped to the index of the highest threshold that it meets.
  bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  # `np.float` was only an alias of the builtin `float` and has been removed
  # from NumPy; `np.float64` is the same dtype and works on every version.
  float_labels = labels.astype(np.float64)
  histogram_range = (0, num_thresholds - 1)
  tp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=float_labels * weights)
  fp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=(1.0 - float_labels) * weights)

  # Obtain the reverse cumulative sum: the count at threshold i includes every
  # bucket at or above i.
  tp = np.cumsum(tp_buckets[::-1])[::-1]
  fp = np.cumsum(fp_buckets[::-1])[::-1]
  # Index 0 holds the totals, so the remainder below each threshold is the
  # negative-prediction count.
  tn = fp[0] - fp
  fn = tp[0] - tp
  # Clamp the denominators so that empty buckets do not divide by zero.
  precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name,
      description=description or '',
      num_thresholds=num_thresholds)
  summary = tf.Summary()
  data = np.stack((tp, fp, tn, fn, precision, recall))
  tensor = tf.make_tensor_proto(data, dtype=tf.float32)
  summary.value.add(tag='%s/pr_curves' % name,
                    metadata=summary_metadata,
                    tensor=tensor)
  return summary
|
||
def streaming_op(name, | ||
labels, | ||
predictions, | ||
num_thresholds=200, | ||
|
@@ -186,7 +258,7 @@ def streaming_op(tag, | |
updated with the returned update_op. | ||
|
||
Args: | ||
tag: A tag attached to the summary. Used by TensorBoard for organization. | ||
name: A tag attached to the summary. Used by TensorBoard for organization. | ||
labels: The ground truth values, a `Tensor` whose dimensions must match | ||
`predictions`. Will be cast to `bool`. | ||
predictions: A floating point `Tensor` of arbitrary shape and whose values | ||
|
@@ -216,7 +288,7 @@ def streaming_op(tag, | |
thresholds = [i / float(num_thresholds - 1) | ||
for i in range(num_thresholds)] | ||
|
||
with tf.name_scope(tag, values=[labels, predictions, weights]): | ||
with tf.name_scope(name, values=[labels, predictions, weights]): | ||
tp, update_tp = tf.metrics.true_positives_at_thresholds( | ||
labels=labels, | ||
predictions=predictions, | ||
|
@@ -243,7 +315,7 @@ def compute_summary(tp, fp, tn, fn, collections): | |
recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) | ||
|
||
return _create_tensor_summary( | ||
tag, | ||
name, | ||
tp, | ||
fp, | ||
tn, | ||
|
@@ -263,7 +335,7 @@ def compute_summary(tp, fp, tn, fn, collections): | |
|
||
|
||
def raw_data_op( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that we probably want to provide a There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. SG. Yeah - maybe separate PR? Just because this one's getting big, albeit that is related. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, fine with me. |
||
tag, | ||
name, | ||
true_positive_counts, | ||
false_positive_counts, | ||
true_negative_counts, | ||
|
@@ -285,7 +357,7 @@ def raw_data_op( | |
differently but still use the PR curves plugin. | ||
|
||
Args: | ||
tag: A tag attached to the summary. Used by TensorBoard for organization. | ||
name: A tag attached to the summary. Used by TensorBoard for organization. | ||
true_positive_counts: A rank-1 tensor of true positive counts. Must contain | ||
`num_thresholds` elements and be castable to float32. | ||
false_positive_counts: A rank-1 tensor of false positive counts. Must | ||
|
@@ -309,7 +381,7 @@ def raw_data_op( | |
A summary operation for use in a TensorFlow graph. See docs for the `op` | ||
method for details on the float32 tensor produced by this summary. | ||
""" | ||
with tf.name_scope(tag, values=[ | ||
with tf.name_scope(name, values=[ | ||
true_positive_counts, | ||
false_positive_counts, | ||
true_negative_counts, | ||
|
@@ -318,7 +390,7 @@ def raw_data_op( | |
recall, | ||
]): | ||
return _create_tensor_summary( | ||
tag, | ||
name, | ||
true_positive_counts, | ||
false_positive_counts, | ||
true_negative_counts, | ||
|
@@ -331,7 +403,7 @@ def raw_data_op( | |
collections) | ||
|
||
def _create_tensor_summary( | ||
tag, | ||
name, | ||
true_positive_counts, | ||
false_positive_counts, | ||
true_negative_counts, | ||
|
@@ -355,7 +427,7 @@ def _create_tensor_summary( | |
# Store the number of thresholds within the summary metadata because | ||
# that value is constant for all pr curve summaries with the same tag. | ||
summary_metadata = metadata.create_summary_metadata( | ||
display_name=display_name if display_name is not None else tag, | ||
display_name=display_name if display_name is not None else name, | ||
description=description or '', | ||
num_thresholds=num_thresholds) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mentions of "python ints" and "constant `str`s" in the `pb` function seem potentially confusing; this is not in a TensorFlow context, so the distinction that you intend doesn't exist, and instead folks might wonder why they can't pass a string from a variable or something (they of course can). If you look at another summary's `pb` function, you'll see that the documentation is changed appropriately.

Suggested changes:

- `num_thresholds`: Optional […] metrics for. When provided, should be an `int` of value at least `2`. Defaults to `200`.
- `weights`: Optional `float` or float32 numpy array. […] This value must be […].
- `display_name`: […] as a `str`. […]
- `description`: […] as a `str`. […]

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Indeed, that seems much clearer, and using `python` could be confusing here because `pb` methods inherently don't rely on TensorFlow.