Skip to content

Commit 952b8f0

Browse files
Change variance tests to use assertAllClose.
PiperOrigin-RevId: 644161903
1 parent d2427a5 commit 952b8f0

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

official/recommendation/uplift/metrics/variance_test.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,19 @@ def test_multi_batch_correctness(
138138
def test_float_sample_weight(self, values, sample_weight, expected_variance):
139139
metric = variance.Variance()
140140
metric(values, sample_weight=sample_weight)
141-
self.assertEqual(expected_variance, metric.result())
141+
self.assertAllClose(expected_variance, metric.result())
142142

143143
def test_empty_input(self):
144144
metric = variance.Variance()
145145
values = tf.constant([0, 1, 2, 3])
146146
metric(values)
147-
self.assertEqual(1.25, metric.result())
147+
self.assertAllClose(1.25, metric.result())
148148
metric(tf.ones(shape=(0,)), sample_weight=None)
149-
self.assertEqual(1.25, metric.result())
149+
self.assertAllClose(1.25, metric.result())
150150

151151
def test_initial_state(self):
152152
metric = variance.Variance()
153-
self.assertEqual(0.0, metric.result())
153+
self.assertAllClose(0.0, metric.result())
154154

155155
def test_dtype_correctness(self):
156156
# 1 << 128 overflows for float32 but fits in float64.
@@ -196,24 +196,26 @@ def test_multiple_result_calls(self):
196196
values = tf.constant([1, 2, 1, 4])
197197
metric.update_state(values)
198198

199-
self.assertEqual(values.numpy().var(), metric.result())
200-
self.assertEqual(values.numpy().var(), metric.result())
199+
self.assertAllClose(values.numpy().var(), metric.result())
200+
self.assertAllClose(values.numpy().var(), metric.result())
201201

202202
metric.update_state(tf.constant([-1, -2, 0]))
203203

204-
self.assertEqual(np.array([1, 2, 1, 4, -1, -2, 0]).var(), metric.result())
204+
self.assertAllClose(
205+
np.array([1, 2, 1, 4, -1, -2, 0]).var(), metric.result()
206+
)
205207

206208
def test_reset_state(self):
207209
metric = variance.Variance()
208210
values = tf.constant([1, 2, 1, 4])
209211

210212
metric.update_state(values)
211-
self.assertEqual(1.5, metric.result())
213+
self.assertAllClose(1.5, metric.result())
212214

213215
metric.reset_state()
214216

215217
metric.update_state(values, sample_weight=tf.constant([1, 0, 1, 0]))
216-
self.assertEqual(0.0, metric.result())
218+
self.assertAllClose(0.0, metric.result())
217219

218220
def test_numpy_correctness(self):
219221
metric = variance.Variance()

0 commit comments

Comments
 (0)