@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import pickle
 from typing import List, Optional
 from unittest import mock
@@ -264,100 +265,60 @@ def validation_epoch_end(self, outputs):
     assert early_stopping.stopped_epoch == expected_stop_epoch


-@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
-def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
-    """Excepted Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when
-    `early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` ==
-    `min_steps`, and stop.
-
-    IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
-    when `early_stopping` is being triggered,
-    THEN the trainer should continue until reaching
-    `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
-    This test validate this expected behaviour
-
-    IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
-    when `early_stopping` is being triggered,
-    THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
-
-    Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
-
-    This test validate those expected behaviours
-    """
-
-    _logger.disabled = True
-
-    original_loss_value = 10
-    limit_train_batches = 3
-    patience = 3
-
-    class Model(BoringModel):
-        def __init__(self, step_freeze):
-            super().__init__()
-
-            self._step_freeze = step_freeze
-
-            self._loss_value = 10.0
-            self._eps = 1e-1
-            self._count_decrease = 0
-            self._values = []
+@pytest.mark.parametrize("limit_train_batches", (3, 5))
+@pytest.mark.parametrize(
+    ["min_epochs", "min_steps"],
+    [
+        # IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
+        # triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
+        (0, 10),
+        # IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
+        # being triggered, THEN the trainer should continue until reaching
+        # `trainer.global_step` == `min_epochs * len(train_dataloader)`
+        (2, 0),
+        # IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
+        # `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
+        # `min_steps` would be reached
+        (1, 10),
+        (3, 10),
+    ],
+)
+def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
+    if min_steps:
+        assert limit_train_batches < min_steps

+    class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"loss": loss}
-
-        def validation_step(self, batch, batch_idx):
-            return {"test_val_loss": self._loss_value}
+            self.log("foo", batch_idx)
+            return super().training_step(batch, batch_idx)

-        def validation_epoch_end(self, outputs):
-            _mean = np.mean([x["test_val_loss"] for x in outputs])
-            if self.trainer.global_step <= self._step_freeze:
-                self._count_decrease += 1
-                self._loss_value -= self._eps
-            self._values.append(_mean)
-            self.log("test_val_loss", _mean)
-
-    model = Model(step_freeze)
-    model.training_step_end = None
-    model.test_dataloader = None
-    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    es_callback = EarlyStopping("foo")
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[early_stop_callback],
+        callbacks=es_callback,
+        limit_val_batches=0,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
-        min_steps=min_steps,
         min_epochs=min_epochs,
+        min_steps=min_steps,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
     )
-    trainer.fit(model)
-
-    # Make sure loss was properly decreased
-    assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
-
-    pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
-
-    # Compute when the latest validation epoch end happened
-    latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
-    if pos_diff % limit_train_batches == 0:
-        latest_validation_epoch_end += limit_train_batches
-
-    # Compute early stopping latest step
-    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
-
-    # Compute min_epochs latest step
-    by_min_epochs = min_epochs * limit_train_batches
+    model = TestModel()

-    # Make sure the trainer stops for the max of all minimum requirements
-    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
-        trainer.global_step,
-        max(min_steps, by_early_stopping, by_min_epochs),
-        step_freeze,
-        min_steps,
-        min_epochs,
-    )
+    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
+    # trigger early stopping directly after the first epoch
+    side_effect = [(True, "")] * expected_epochs
+    with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
+        trainer.fit(model)

-    _logger.disabled = False
+    assert trainer.should_stop
+    # epochs continue until min steps are reached
+    assert trainer.current_epoch == expected_epochs
+    # steps continue until min steps are reached AND the epoch is exhausted
+    # stopping mid-epoch is not supported
+    assert trainer.global_step == limit_train_batches * expected_epochs


 def test_early_stopping_mode_options():
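
For reference, the step arithmetic that the new assertions encode: the mocked `_evaluate_stopping_criteria` returns `(True, "")` every time it is called, so early stopping is requested from the first epoch on, but the `Trainer` keeps running complete epochs until both `min_epochs` and `min_steps` are satisfied, and it never stops mid-epoch. A minimal standalone sketch of that arithmetic, assuming `limit_train_batches=3` and the parametrized `(min_epochs, min_steps)` cases from the diff:

import math

limit_train_batches = 3
cases = [(0, 10), (2, 0), (1, 10), (3, 10)]  # (min_epochs, min_steps)

for min_epochs, min_steps in cases:
    # Epochs needed to satisfy min_steps, rounded up to a full epoch,
    # then bounded below by min_epochs.
    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
    # Stopping only happens on epoch boundaries, so the final global_step
    # is always a whole multiple of the epoch length.
    expected_global_step = limit_train_batches * expected_epochs
    print(f"min_epochs={min_epochs}, min_steps={min_steps} "
          f"-> {expected_epochs} epochs, {expected_global_step} steps")

# Expected output:
# min_epochs=0, min_steps=10 -> 4 epochs, 12 steps
# min_epochs=2, min_steps=0 -> 2 epochs, 6 steps
# min_epochs=1, min_steps=10 -> 4 epochs, 12 steps
# min_epochs=3, min_steps=10 -> 4 epochs, 12 steps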