@@ -1118,52 +1118,6 @@ def test_dataloaders_load_only_once(tmpdir):
1118
1118
1119
1119
assert tracker .mock_calls == [call .val_dataloader (), call .train_dataloader ()]
1120
1120
1121
-
1122
- def test_dataloaders_load_only_once_val_interval (tmpdir ):
1123
- model = BoringModel ()
1124
-
1125
- # logger file to get meta
1126
- trainer = Trainer (
1127
- default_root_dir = tmpdir ,
1128
- limit_train_batches = 10 ,
1129
- limit_val_batches = 10 ,
1130
- val_check_interval = 0.3 ,
1131
- reload_dataloaders_every_n_epochs = 1 ,
1132
- max_epochs = 3 ,
1133
- )
1134
-
1135
- tracker = Mock ()
1136
- model .train_dataloader = Mock (wraps = model .train_dataloader )
1137
- model .val_dataloader = Mock (wraps = model .val_dataloader )
1138
- model .test_dataloader = Mock (wraps = model .test_dataloader )
1139
-
1140
- tracker .attach_mock (model .train_dataloader , "train_dataloader" )
1141
- tracker .attach_mock (model .val_dataloader , "val_dataloader" )
1142
- tracker .attach_mock (model .test_dataloader , "test_dataloader" )
1143
-
1144
- trainer .fit (model )
1145
- trainer .test (model )
1146
-
1147
- # verify the sequence
1148
- expected_sequence = [
1149
- call .val_dataloader (),
1150
- call .train_dataloader (),
1151
- call .val_dataloader (),
1152
- call .val_dataloader (),
1153
- call .val_dataloader (),
1154
- call .train_dataloader (),
1155
- call .val_dataloader (),
1156
- call .val_dataloader (),
1157
- call .val_dataloader (),
1158
- call .train_dataloader (),
1159
- call .val_dataloader (),
1160
- call .val_dataloader (),
1161
- call .val_dataloader (),
1162
- call .test_dataloader (),
1163
- ]
1164
- assert tracker .mock_calls == expected_sequence
1165
-
1166
-
1167
1121
def test_dataloaders_load_only_once_no_sanity_check (tmpdir ):
1168
1122
model = BoringModel ()
1169
1123
@@ -1190,14 +1144,25 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
1190
1144
1191
1145
@pytest .mark .parametrize ("n" , [1 , 2 ])
1192
1146
def test_dataloaders_load_every_n_epochs (tmpdir , n ):
1193
- model = BoringModel ()
1147
+ train_reload_epochs , val_reload_epochs = [], []
1148
+
1149
+ class TestModel (BoringModel ):
1150
+ def train_dataloader (self ):
1151
+ train_reload_epochs .append (self .current_epoch )
1152
+ return super ().train_dataloader ()
1153
+
1154
+ def val_dataloader (self ):
1155
+ val_reload_epochs .append (self .current_epoch )
1156
+ return super ().val_dataloader ()
1157
+
1158
+ model = TestModel ()
1194
1159
1195
1160
trainer = Trainer (
1196
1161
default_root_dir = tmpdir ,
1197
1162
limit_train_batches = 0.3 ,
1198
1163
limit_val_batches = 0.3 ,
1199
1164
reload_dataloaders_every_n_epochs = n ,
1200
- max_epochs = 3 ,
1165
+ max_epochs = 5 ,
1201
1166
)
1202
1167
1203
1168
tracker = Mock ()
@@ -1212,15 +1177,24 @@ def test_dataloaders_load_every_n_epochs(tmpdir, n):
1212
1177
trainer .fit (model )
1213
1178
trainer .test (model )
1214
1179
1215
- # verify the sequence
1216
- expected_sequence = [call .val_dataloader ()]
1180
+ # Verify the sequence
1181
+ expected_sequence = [call .val_dataloader (), call . train_dataloader ()] # Sanity check first
1217
1182
if n == 1 :
1218
- expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 3
1183
+ expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 4
1219
1184
elif n == 2 :
1220
1185
expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 2
1221
1186
expected_sequence += [call .test_dataloader ()]
1187
+
1222
1188
assert tracker .mock_calls == expected_sequence
1223
1189
1190
+ # Verify epoch of reloads
1191
+ if n == 1 :
1192
+ assert train_reload_epochs == [0 , 1 , 2 , 3 , 4 ]
1193
+ assert val_reload_epochs == [0 , 1 , 2 , 3 , 4 ]
1194
+ elif n == 2 :
1195
+ assert train_reload_epochs == [0 , 2 , 4 ]
1196
+ assert val_reload_epochs == [0 , 2 , 4 ]
1197
+
1224
1198
1225
1199
@pytest .mark .parametrize ("n" , [3 , 5 ])
1226
1200
def test_dataloaders_load_every_n_epochs_infrequent_val (tmpdir , n ):
@@ -1263,6 +1237,7 @@ def val_dataloader(self):
1263
1237
1264
1238
model .test_dataloader .assert_called_once ()
1265
1239
1240
+
1266
1241
def test_dataloaders_load_every_n_epochs_frequent_val (tmpdir ):
1267
1242
train_reload_epochs , val_reload_epochs , val_check_epochs = [], [], []
1268
1243
@@ -1303,6 +1278,7 @@ def validation_epoch_end(self, outputs):
1303
1278
# Verify validation happens 3 times per epoch + 1 for sanity check
1304
1279
assert val_check_epochs == [0 , 0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 , 2 ]
1305
1280
1281
+
1306
1282
@pytest .mark .parametrize ("n" , ["test" , - 1 ])
1307
1283
def test_dataloaders_load_every_n_epochs_exception (tmpdir , n ):
1308
1284
@@ -1347,15 +1323,6 @@ def validation_step(self, batch, batch_idx):
1347
1323
expected_calls = [
1348
1324
call .train_dataloader (),
1349
1325
call .val_dataloader (),
1350
- # This has subsequent calls to val_dataloader
1351
- # because the training loop runs the evaluation loop,
1352
- # which reloads the val dataloader again.
1353
- # We cannot yet rely on trainer.current_epoch=0 to skip reloading
1354
- # the val dataloader on the first epoch because this only tracks the training epoch
1355
- # meaning multiple passes through the validation data within a single training epoch
1356
- # would not have the dataloader reloaded.
1357
- # This breaks the assumption behind reload_dataloaders_every_n_epochs=1
1358
- call .val_dataloader (),
1359
1326
call .train_dataloader (),
1360
1327
call .val_dataloader (),
1361
1328
call .train_dataloader (),
0 commit comments