Skip to content

Commit 8ae0742

Browse files
committed
add checker for test
1 parent ad167fe commit 8ae0742

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

tests/trainer/test_trainer_tricks.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,46 @@ def test_overfit_batch_limits(tmpdir):
112112
# ------------------------------------------------------
113113
for split in (RunningStage.VALIDATING, RunningStage.TESTING):
114114

115+
# ------------------------------------------------------
116+
# test overfit_batches action
117+
# ------------------------------------------------------
118+
115119
# ------------------------------------------------------
116120
# test overfit_batches as percent
117121
# ------------------------------------------------------
118122
trainer = Trainer(overfit_batches=0.11)
119123
trainer._data_connector.attach_dataloaders(model)
120124
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
121-
assert loader_num_batches[0] == 0
125+
if split == RunningStage.VALIDATING:
126+
assert loader_num_batches[0] == 0
127+
else:
128+
assert loader_num_batches[0] == len(test_loader)
122129

123130
# ------------------------------------------------------
124131
# test overfit_batches as int
125132
# ------------------------------------------------------
126133
trainer = Trainer(overfit_batches=1)
127134
trainer._data_connector.attach_dataloaders(model)
128-
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
129-
assert loader_num_batches[0] == 0
135+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
136+
if split == RunningStage.VALIDATING:
137+
assert loader_num_batches[0] == 0
138+
else:
139+
assert loader_num_batches[0] == len(test_loader)
140+
# make sure we turned off shuffle for the user
141+
assert isinstance(dataloaders[0].sampler, SequentialSampler)
142+
143+
# make sure the loaders are the same
144+
(xb, yb) = next(iter(dataloaders[0]))
145+
assert torch.eq(xa, xb).all()
146+
assert torch.eq(ya, yb).all()
147+
130148
trainer = Trainer(overfit_batches=5)
131149
trainer._data_connector.attach_dataloaders(model)
132150
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
133-
assert loader_num_batches[0] == 0
151+
if split == RunningStage.VALIDATING:
152+
assert loader_num_batches[0] == 0
153+
else:
154+
assert loader_num_batches[0] == len(test_loader)
134155

135156
# ------------------------------------------------------
136157
# test limit_xxx_batches as percent AND int

0 commit comments

Comments
 (0)