|
22 | 22 | from __future__ import division
|
23 | 23 | from __future__ import print_function
|
24 | 24 |
|
| 25 | +import numpy as np |
25 | 26 | import tensorflow as tf
|
26 | 27 |
|
27 | 28 | from tensorboard.plugins.pr_curve import metadata
|
28 | 29 |
|
29 | 30 | # A value that we use as the minimum value during division of counts to prevent
|
30 |
| -# division by 0. 1 suffices because counts of course must be whole numbers. |
31 |
| -_MINIMUM_COUNT = 1.0 |
| 31 | +# division by 0. |
| 32 | +_MINIMUM_COUNT = 1e-7 |
| 33 | + |
| 34 | +# The default number of thresholds. |
| 35 | +_DEFAULT_NUM_THRESHOLDS = 200 |
32 | 36 |
|
33 | 37 | def op(
|
34 | 38 | tag,
|
@@ -78,7 +82,7 @@ def op(
|
78 | 82 |
|
79 | 83 | """
|
80 | 84 | if num_thresholds is None:
|
81 |
| - num_thresholds = 200 |
| 85 | + num_thresholds = _DEFAULT_NUM_THRESHOLDS |
82 | 86 |
|
83 | 87 | if weights is None:
|
84 | 88 | weights = 1.0
|
@@ -164,6 +168,74 @@ def op(
|
164 | 168 | description,
|
165 | 169 | collections)
|
166 | 170 |
|
def pb(tag,
       labels,
       predictions,
       num_thresholds=None,
       weights=None,
       display_name=None,
       description=None):
  """Creates a PR curves summary protobuf.

  Arguments:
    tag: A name for the generated summary. Will also serve as a series name in
      TensorBoard.
    labels: The ground truth values. A bool numpy array (or array-like).
    predictions: A float32 numpy array (or array-like) whose values are in the
      range `[0, 1]`. Dimensions must match those of `labels`.
    num_thresholds: Optional number of thresholds, evenly distributed in
      `[0, 1]`, to compute PR metrics for. Should be `>= 2`. This value should
      be a python int. Defaults to `_DEFAULT_NUM_THRESHOLDS`.
    weights: Optional python float or float32 numpy array. Individual counts
      are multiplied by this value. Must be either the same shape as or
      broadcastable to the `labels` numpy array.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `tag`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.

  Returns:
    A `tf.Summary` protobuf with a single value tagged `<tag>/pr_curves`. Its
    float32 tensor has shape `(6, num_thresholds)`; the rows are, in order:
    true positives, false positives, true negatives, false negatives,
    precision, and recall.
  """
  if num_thresholds is None:
    num_thresholds = _DEFAULT_NUM_THRESHOLDS

  if weights is None:
    weights = 1.0

  # Coerce inputs so that plain Python sequences are accepted as well as
  # numpy arrays.
  predictions = np.asarray(predictions)
  # `np.float` was removed in NumPy 1.24; `np.float64` is the same dtype.
  float_labels = np.asarray(labels).astype(np.float64)

  # Compute bins of true positives and false positives. Each prediction is
  # mapped to the index of the highest threshold that it meets.
  bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  histogram_range = (0, num_thresholds - 1)
  tp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=float_labels * weights)
  fp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=(1.0 - float_labels) * weights)

  # Obtain the reverse cumulative sum: entry i counts every example whose
  # bucket index is >= i.
  tp = np.cumsum(tp_buckets[::-1])[::-1]
  fp = np.cumsum(fp_buckets[::-1])[::-1]
  # Index 0 holds the totals, so the complements are the negatives.
  tn = fp[0] - fp
  fn = tp[0] - tp
  # Clamp the denominators to avoid division by zero.
  precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

  if display_name is None:
    display_name = tag
  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name,
      description=description or '',
      num_thresholds=num_thresholds)
  summary = tf.Summary()
  data = np.stack((tp, fp, tn, fn, precision, recall))
  tensor = tf.make_tensor_proto(data, dtype=tf.float32)
  summary.value.add(tag='%s/pr_curves' % tag,
                    metadata=summary_metadata,
                    tensor=tensor)
  return summary
167 | 239 |
|
168 | 240 | def streaming_op(tag,
|
169 | 241 | labels,
|
|
0 commit comments