Skip to content

Commit a0ca8d0

Browse files
authored
Refactor early stopping test (#11866)
1 parent 25b5055 commit a0ca8d0

File tree

1 file changed

+45
-84
lines changed

1 file changed

+45
-84
lines changed

tests/callbacks/test_early_stopping.py

Lines changed: 45 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
import math
1516
import pickle
1617
from typing import List, Optional
1718
from unittest import mock
@@ -264,100 +265,60 @@ def validation_epoch_end(self, outputs):
264265
assert early_stopping.stopped_epoch == expected_stop_epoch
265266

266267

267-
@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
268-
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
269-
"""Expected Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when
270-
`early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` ==
271-
`min_steps`, and stop.
272-
273-
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
274-
when `early_stopping` is being triggered,
275-
THEN the trainer should continue until reaching
276-
`trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
277-
This test validates this expected behaviour
278-
279-
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
280-
when `early_stopping` is being triggered,
281-
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
282-
283-
Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
284-
285-
This test validates those expected behaviours
286-
"""
287-
288-
_logger.disabled = True
289-
290-
original_loss_value = 10
291-
limit_train_batches = 3
292-
patience = 3
293-
294-
class Model(BoringModel):
295-
def __init__(self, step_freeze):
296-
super().__init__()
297-
298-
self._step_freeze = step_freeze
299-
300-
self._loss_value = 10.0
301-
self._eps = 1e-1
302-
self._count_decrease = 0
303-
self._values = []
268+
@pytest.mark.parametrize("limit_train_batches", (3, 5))
269+
@pytest.mark.parametrize(
270+
["min_epochs", "min_steps"],
271+
[
272+
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
273+
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
274+
(0, 10),
275+
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
276+
# being triggered, THEN the trainer should continue until reaching
277+
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
278+
(2, 0),
279+
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
280+
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
281+
# `min_steps` would be reached
282+
(1, 10),
283+
(3, 10),
284+
],
285+
)
286+
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
287+
if min_steps:
288+
assert limit_train_batches < min_steps
304289

290+
class TestModel(BoringModel):
305291
def training_step(self, batch, batch_idx):
306-
output = self.layer(batch)
307-
loss = self.loss(batch, output)
308-
return {"loss": loss}
309-
310-
def validation_step(self, batch, batch_idx):
311-
return {"test_val_loss": self._loss_value}
292+
self.log("foo", batch_idx)
293+
return super().training_step(batch, batch_idx)
312294

313-
def validation_epoch_end(self, outputs):
314-
_mean = np.mean([x["test_val_loss"] for x in outputs])
315-
if self.trainer.global_step <= self._step_freeze:
316-
self._count_decrease += 1
317-
self._loss_value -= self._eps
318-
self._values.append(_mean)
319-
self.log("test_val_loss", _mean)
320-
321-
model = Model(step_freeze)
322-
model.training_step_end = None
323-
model.test_dataloader = None
324-
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
295+
es_callback = EarlyStopping("foo")
325296
trainer = Trainer(
326297
default_root_dir=tmpdir,
327-
callbacks=[early_stop_callback],
298+
callbacks=es_callback,
299+
limit_val_batches=0,
328300
limit_train_batches=limit_train_batches,
329-
limit_val_batches=2,
330-
min_steps=min_steps,
331301
min_epochs=min_epochs,
302+
min_steps=min_steps,
303+
logger=False,
304+
enable_checkpointing=False,
305+
enable_progress_bar=False,
306+
enable_model_summary=False,
332307
)
333-
trainer.fit(model)
334-
335-
# Make sure loss was properly decreased
336-
assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
337-
338-
pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
339-
340-
# Compute when the latest validation epoch end happened
341-
latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
342-
if pos_diff % limit_train_batches == 0:
343-
latest_validation_epoch_end += limit_train_batches
344-
345-
# Compute early stopping latest step
346-
by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
347-
348-
# Compute min_epochs latest step
349-
by_min_epochs = min_epochs * limit_train_batches
308+
model = TestModel()
350309

351-
# Make sure the trainer stops for the max of all minimum requirements
352-
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
353-
trainer.global_step,
354-
max(min_steps, by_early_stopping, by_min_epochs),
355-
step_freeze,
356-
min_steps,
357-
min_epochs,
358-
)
310+
expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
311+
# trigger early stopping directly after the first epoch
312+
side_effect = [(True, "")] * expected_epochs
313+
with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
314+
trainer.fit(model)
359315

360-
_logger.disabled = False
316+
assert trainer.should_stop
317+
# epochs continue until min steps are reached
318+
assert trainer.current_epoch == expected_epochs
319+
# steps continue until min steps are reached AND the epoch is exhausted
320+
# stopping mid-epoch is not supported
321+
assert trainer.global_step == limit_train_batches * expected_epochs
361322

362323

363324
def test_early_stopping_mode_options():

0 commit comments

Comments
 (0)