@@ -31,6 +31,7 @@ class PrCurveTest(tf.test.TestCase):
31
31
def setUp (self ):
32
32
super (PrCurveTest , self ).setUp ()
33
33
tf .reset_default_graph ()
34
+ np .random .seed (42 )
34
35
35
36
def pb_via_op (self , summary_op , feed_dict = None ):
36
37
with tf .Session () as sess :
@@ -218,6 +219,25 @@ def test_many_values_with_weights(self):
218
219
values = tf .make_ndarray (pb .value [0 ].tensor )
219
220
self .verify_float_arrays_are_equal (expected , values )
220
221
222
+ def test_exhaustive_random_values (self ):
223
+ # Most other tests check for specific cases.
224
+ data_points = 420
225
+ pb = self .compute_and_check_summary_pb (
226
+ name = 'foo' ,
227
+ labels = np .random .uniform (size = (data_points ,)) > 0.5 ,
228
+ predictions = np .float32 (np .random .uniform (size = (data_points ,))),
229
+ num_thresholds = 5 )
230
+ expected = [
231
+ [218.0 , 162.0 , 111.0 , 55.0 , 0.0 ],
232
+ [202.0 , 148.0 , 98.0 , 51.0 , 0.0 ],
233
+ [0.0 , 54.0 , 104.0 , 151.0 , 202.0 ],
234
+ [0.0 , 56.0 , 107.0 , 163.0 , 218.0 ],
235
+ [0.5190476 , 0.5225806 , 0.5311005 , 0.5188679 , 0.0 ],
236
+ [1.0 , 0.7431192 , 0.5091743 , 0.2522936 , 0.0 ]
237
+ ]
238
+ values = tf .make_ndarray (pb .value [0 ].tensor )
239
+ self .verify_float_arrays_are_equal (expected , values )
240
+
221
241
def test_raw_data_op (self ):
222
242
# We pass raw counts and precision/recall values.
223
243
op = summary .raw_data_op (
0 commit comments