22
22
from __future__ import division
23
23
from __future__ import print_function
24
24
25
+ import numpy as np
25
26
import tensorflow as tf
26
27
27
28
from tensorboard .plugins .pr_curve import metadata
30
31
# division by 0. 1 suffices because counts of course must be whole numbers.
31
32
_MINIMUM_COUNT = 1.0
32
33
34
+ # The default number of thresholds.
35
+ _DEFAULT_NUM_THRESHOLDS = 200
36
+
33
37
def op (
34
- tag ,
38
+ name ,
35
39
labels ,
36
40
predictions ,
37
41
num_thresholds = None ,
@@ -51,7 +55,7 @@ def op(
51
55
used to reweight certain values, or more commonly used for masking values.
52
56
53
57
Args:
54
- tag : A tag attached to the summary. Used by TensorBoard for organization.
58
+ name : A tag attached to the summary. Used by TensorBoard for organization.
55
59
labels: The ground truth values. A Tensor of `bool` values with arbitrary
56
60
shape.
57
61
predictions: A float32 `Tensor` whose values are in the range `[0, 1]`.
@@ -78,14 +82,14 @@ def op(
78
82
79
83
"""
80
84
if num_thresholds is None :
81
- num_thresholds = 200
85
+ num_thresholds = _DEFAULT_NUM_THRESHOLDS
82
86
83
87
if weights is None :
84
88
weights = 1.0
85
89
86
90
dtype = predictions .dtype
87
91
88
- with tf .name_scope (tag , values = [labels , predictions , weights ]):
92
+ with tf .name_scope (name , values = [labels , predictions , weights ]):
89
93
tf .assert_type (labels , tf .bool )
90
94
# We cast to float to ensure we have 0.0 or 1.0.
91
95
f_labels = tf .cast (labels , dtype )
@@ -152,7 +156,7 @@ def op(
152
156
recall = tp / tf .maximum (_MINIMUM_COUNT , tp + fn )
153
157
154
158
return _create_tensor_summary (
155
- tag ,
159
+ name ,
156
160
tp ,
157
161
fp ,
158
162
tn ,
@@ -164,8 +168,76 @@ def op(
164
168
description ,
165
169
collections )
166
170
171
def pb(name,
       labels,
       predictions,
       num_thresholds=None,
       weights=None,
       display_name=None,
       description=None):
  """Create a PR curves summary protobuf.

  Arguments:
    name: A name for the generated node. Will also serve as a series name in
      TensorBoard.
    labels: The ground truth values. A bool numpy array (or array-like).
    predictions: A float32 numpy array (or array-like) whose values are in
      the range `[0, 1]`. Dimensions must match those of `labels`.
    num_thresholds: Optional number of thresholds, evenly distributed in
      `[0, 1]`, to compute PR metrics for. When provided, should be an int of
      value at least 2. Defaults to 200.
    weights: Optional float or float32 numpy array. Individual counts are
      multiplied by this value. This must be either the same shape as or
      broadcastable to the `labels` numpy array.
    display_name: Optional name for this summary in TensorBoard, as a `str`.
      Defaults to `name`.
    description: Optional long-form description for this summary, as a `str`.
      Markdown is supported. Defaults to empty.

  Returns:
    A `tf.Summary` protobuf holding a single value tagged `<name>/pr_curves`.
    Its float32 tensor has shape `(6, num_thresholds)`; the rows are, in
    order: true positives, false positives, true negatives, false negatives,
    precision and recall.
  """
  if num_thresholds is None:
    num_thresholds = _DEFAULT_NUM_THRESHOLDS

  if weights is None:
    weights = 1.0

  # Coerce array-likes (e.g. plain lists) so the arithmetic and `.astype`
  # calls below behave the same as for ndarray input.
  labels = np.asarray(labels)
  predictions = np.asarray(predictions)

  # Compute bins of true positives and false positives.
  bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  # `np.float` was only an alias of the builtin `float` and no longer exists
  # in current NumPy releases; `np.float64` is the same dtype.
  float_labels = labels.astype(np.float64)
  histogram_range = (0, num_thresholds - 1)
  tp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=float_labels * weights)
  fp_buckets, _ = np.histogram(
      bucket_indices,
      bins=num_thresholds,
      range=histogram_range,
      weights=(1.0 - float_labels) * weights)

  # Obtain the reverse cumulative sum: entry i accumulates the weight of
  # every example whose bucket index is >= i.
  tp = np.cumsum(tp_buckets[::-1])[::-1]
  fp = np.cumsum(fp_buckets[::-1])[::-1]
  # Index 0 holds the total over all buckets, so the remainder is what falls
  # below each threshold.
  tn = fp[0] - fp
  fn = tp[0] - tp
  # Clamp denominators to `_MINIMUM_COUNT` to avoid dividing by zero.
  precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

  if display_name is None:
    display_name = name
  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name,
      description=description or '',
      num_thresholds=num_thresholds)
  summary = tf.Summary()
  data = np.stack((tp, fp, tn, fn, precision, recall))
  tensor = tf.make_tensor_proto(data, dtype=tf.float32)
  summary.value.add(tag='%s/pr_curves' % name,
                    metadata=summary_metadata,
                    tensor=tensor)
  return summary
239
+
240
+ def streaming_op (name ,
169
241
labels ,
170
242
predictions ,
171
243
num_thresholds = 200 ,
@@ -186,7 +258,7 @@ def streaming_op(tag,
186
258
updated with the returned update_op.
187
259
188
260
Args:
189
- tag : A tag attached to the summary. Used by TensorBoard for organization.
261
+ name : A tag attached to the summary. Used by TensorBoard for organization.
190
262
labels: The ground truth values, a `Tensor` whose dimensions must match
191
263
`predictions`. Will be cast to `bool`.
192
264
predictions: A floating point `Tensor` of arbitrary shape and whose values
@@ -216,7 +288,7 @@ def streaming_op(tag,
216
288
thresholds = [i / float (num_thresholds - 1 )
217
289
for i in range (num_thresholds )]
218
290
219
- with tf .name_scope (tag , values = [labels , predictions , weights ]):
291
+ with tf .name_scope (name , values = [labels , predictions , weights ]):
220
292
tp , update_tp = tf .metrics .true_positives_at_thresholds (
221
293
labels = labels ,
222
294
predictions = predictions ,
@@ -243,7 +315,7 @@ def compute_summary(tp, fp, tn, fn, collections):
243
315
recall = tp / tf .maximum (_MINIMUM_COUNT , tp + fn )
244
316
245
317
return _create_tensor_summary (
246
- tag ,
318
+ name ,
247
319
tp ,
248
320
fp ,
249
321
tn ,
@@ -263,7 +335,7 @@ def compute_summary(tp, fp, tn, fn, collections):
263
335
264
336
265
337
def raw_data_op (
266
- tag ,
338
+ name ,
267
339
true_positive_counts ,
268
340
false_positive_counts ,
269
341
true_negative_counts ,
@@ -285,7 +357,7 @@ def raw_data_op(
285
357
differently but still use the PR curves plugin.
286
358
287
359
Args:
288
- tag : A tag attached to the summary. Used by TensorBoard for organization.
360
+ name : A tag attached to the summary. Used by TensorBoard for organization.
289
361
true_positive_counts: A rank-1 tensor of true positive counts. Must contain
290
362
`num_thresholds` elements and be castable to float32.
291
363
false_positive_counts: A rank-1 tensor of false positive counts. Must
@@ -309,7 +381,7 @@ def raw_data_op(
309
381
A summary operation for use in a TensorFlow graph. See docs for the `op`
310
382
method for details on the float32 tensor produced by this summary.
311
383
"""
312
- with tf .name_scope (tag , values = [
384
+ with tf .name_scope (name , values = [
313
385
true_positive_counts ,
314
386
false_positive_counts ,
315
387
true_negative_counts ,
@@ -318,7 +390,7 @@ def raw_data_op(
318
390
recall ,
319
391
]):
320
392
return _create_tensor_summary (
321
- tag ,
393
+ name ,
322
394
true_positive_counts ,
323
395
false_positive_counts ,
324
396
true_negative_counts ,
@@ -331,7 +403,7 @@ def raw_data_op(
331
403
collections )
332
404
333
405
def _create_tensor_summary (
334
- tag ,
406
+ name ,
335
407
true_positive_counts ,
336
408
false_positive_counts ,
337
409
true_negative_counts ,
@@ -355,7 +427,7 @@ def _create_tensor_summary(
355
427
# Store the number of thresholds within the summary metadata because
356
428
# that value is constant for all pr curve summaries with the same tag.
357
429
summary_metadata = metadata .create_summary_metadata (
358
- display_name = display_name if display_name is not None else tag ,
430
+ display_name = display_name if display_name is not None else name ,
359
431
description = description or '' ,
360
432
num_thresholds = num_thresholds )
361
433
0 commit comments