@@ -112,25 +112,46 @@ def test_overfit_batch_limits(tmpdir):
    # ------------------------------------------------------
    for split in (RunningStage.VALIDATING, RunningStage.TESTING):

+        # ------------------------------------------------------
+        # test overfit_batches action
+        # ------------------------------------------------------
+
        # ------------------------------------------------------
        # test overfit_batches as percent
        # ------------------------------------------------------
        trainer = Trainer(overfit_batches=0.11)
        trainer._data_connector.attach_dataloaders(model)
        loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
-        assert loader_num_batches[0] == 0
+        if split == RunningStage.VALIDATING:
+            assert loader_num_batches[0] == 0
+        else:
+            assert loader_num_batches[0] == len(test_loader)

        # ------------------------------------------------------
        # test overfit_batches as int
        # ------------------------------------------------------
        trainer = Trainer(overfit_batches=1)
        trainer._data_connector.attach_dataloaders(model)
-        loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
-        assert loader_num_batches[0] == 0
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
+        if split == RunningStage.VALIDATING:
+            assert loader_num_batches[0] == 0
+        else:
+            assert loader_num_batches[0] == len(test_loader)
+            # make sure we turned off shuffle for the user
+            assert isinstance(dataloaders[0].sampler, SequentialSampler)
+
+            # make sure the loaders are the same
+            (xb, yb) = next(iter(dataloaders[0]))
+            assert torch.eq(xa, xb).all()
+            assert torch.eq(ya, yb).all()
+
        trainer = Trainer(overfit_batches=5)
        trainer._data_connector.attach_dataloaders(model)
        loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
-        assert loader_num_batches[0] == 0
+        if split == RunningStage.VALIDATING:
+            assert loader_num_batches[0] == 0
+        else:
+            assert loader_num_batches[0] == len(test_loader)

    # ------------------------------------------------------
    # test limit_xxx_batches as percent AND int