@@ -138,19 +138,19 @@ def test_multi_batch_correctness(
138
138
def test_float_sample_weight (self , values , sample_weight , expected_variance ):
139
139
metric = variance .Variance ()
140
140
metric (values , sample_weight = sample_weight )
141
- self .assertEqual (expected_variance , metric .result ())
141
+ self .assertAllClose (expected_variance , metric .result ())
142
142
143
143
def test_empty_input (self ):
144
144
metric = variance .Variance ()
145
145
values = tf .constant ([0 , 1 , 2 , 3 ])
146
146
metric (values )
147
- self .assertEqual (1.25 , metric .result ())
147
+ self .assertAllClose (1.25 , metric .result ())
148
148
metric (tf .ones (shape = (0 ,)), sample_weight = None )
149
- self .assertEqual (1.25 , metric .result ())
149
+ self .assertAllClose (1.25 , metric .result ())
150
150
151
151
def test_initial_state (self ):
152
152
metric = variance .Variance ()
153
- self .assertEqual (0.0 , metric .result ())
153
+ self .assertAllClose (0.0 , metric .result ())
154
154
155
155
def test_dtype_correctness (self ):
156
156
# 1 << 128 overflows for float32 but fits in float64.
@@ -196,24 +196,26 @@ def test_multiple_result_calls(self):
196
196
values = tf .constant ([1 , 2 , 1 , 4 ])
197
197
metric .update_state (values )
198
198
199
- self .assertEqual (values .numpy ().var (), metric .result ())
200
- self .assertEqual (values .numpy ().var (), metric .result ())
199
+ self .assertAllClose (values .numpy ().var (), metric .result ())
200
+ self .assertAllClose (values .numpy ().var (), metric .result ())
201
201
202
202
metric .update_state (tf .constant ([- 1 , - 2 , 0 ]))
203
203
204
- self .assertEqual (np .array ([1 , 2 , 1 , 4 , - 1 , - 2 , 0 ]).var (), metric .result ())
204
+ self .assertAllClose (
205
+ np .array ([1 , 2 , 1 , 4 , - 1 , - 2 , 0 ]).var (), metric .result ()
206
+ )
205
207
206
208
def test_reset_state (self ):
207
209
metric = variance .Variance ()
208
210
values = tf .constant ([1 , 2 , 1 , 4 ])
209
211
210
212
metric .update_state (values )
211
- self .assertEqual (1.5 , metric .result ())
213
+ self .assertAllClose (1.5 , metric .result ())
212
214
213
215
metric .reset_state ()
214
216
215
217
metric .update_state (values , sample_weight = tf .constant ([1 , 0 , 1 , 0 ]))
216
- self .assertEqual (0.0 , metric .result ())
218
+ self .assertAllClose (0.0 , metric .result ())
217
219
218
220
def test_numpy_correctness (self ):
219
221
metric = variance .Variance ()
0 commit comments