Skip to content

Commit 2e7f5a2

Browse files
committed
Add test_exhaustive_random_values
1 parent fef4b13 commit 2e7f5a2

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

Diff for: tensorboard/plugins/pr_curve/summary.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,15 @@ def pb(name,
188188
predictions: A float32 numpy array whose values are in the range `[0, 1]`.
189189
Dimensions must match those of `labels`.
190190
num_thresholds: Optional number of thresholds, evenly distributed in
191-
`[0, 1]`, to compute PR metrics for. Should be `>= 2`. This value should
192-
be a python int. Defaults to 200.
193-
weights: Optional python float or float32 numpy array. Individual counts are
191+
`[0, 1]`, to compute PR metrics for. When provided, should be an int of
192+
value at least 2. Defaults to 200.
193+
weights: Optional float or float32 numpy array. Individual counts are
194194
multiplied by this value. This tensor must be either the same shape as
195195
or broadcastable to the `labels` numpy array.
196-
display_name: Optional name for this summary in TensorBoard, as a
197-
constant `str`. Defaults to `name`.
198-
description: Optional long-form description for this summary, as a
199-
constant `str`. Markdown is supported. Defaults to empty.
196+
display_name: Optional name for this summary in TensorBoard, as a `str`.
197+
Defaults to `name`.
198+
description: Optional long-form description for this summary, as a `str`.
199+
Markdown is supported. Defaults to empty.
200200
"""
201201
if num_thresholds is None:
202202
num_thresholds = _DEFAULT_NUM_THRESHOLDS

Diff for: tensorboard/plugins/pr_curve/summary_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class PrCurveTest(tf.test.TestCase):
3131
def setUp(self):
3232
super(PrCurveTest, self).setUp()
3333
tf.reset_default_graph()
34+
np.random.seed(42)
3435

3536
def pb_via_op(self, summary_op, feed_dict=None):
3637
with tf.Session() as sess:
@@ -218,6 +219,25 @@ def test_many_values_with_weights(self):
218219
values = tf.make_ndarray(pb.value[0].tensor)
219220
self.verify_float_arrays_are_equal(expected, values)
220221

222+
def test_exhaustive_random_values(self):
223+
# Most other tests check for specific cases.
224+
data_points = 420
225+
pb = self.compute_and_check_summary_pb(
226+
name='foo',
227+
labels=np.random.uniform(size=(data_points,)) > 0.5,
228+
predictions=np.float32(np.random.uniform(size=(data_points,))),
229+
num_thresholds=5)
230+
expected = [
231+
[218.0, 162.0, 111.0, 55.0, 0.0],
232+
[202.0, 148.0, 98.0, 51.0, 0.0],
233+
[0.0, 54.0, 104.0, 151.0, 202.0],
234+
[0.0, 56.0, 107.0, 163.0, 218.0],
235+
[0.5190476, 0.5225806, 0.5311005, 0.5188679, 0.0],
236+
[1.0, 0.7431192, 0.5091743, 0.2522936, 0.0]
237+
]
238+
values = tf.make_ndarray(pb.value[0].tensor)
239+
self.verify_float_arrays_are_equal(expected, values)
240+
221241
def test_raw_data_op(self):
222242
# We pass raw counts and precision/recall values.
223243
op = summary.raw_data_op(

0 commit comments

Comments
 (0)