@@ -1122,17 +1122,12 @@ def test_dataloaders_load_only_once(tmpdir):
1122
1122
assert tracker .mock_calls == [call .val_dataloader (), call .train_dataloader ()]
1123
1123
1124
1124
1125
- def test_dataloaders_load_only_once_val_interval (tmpdir ):
1125
+ def test_dataloaders_load_only_once_no_sanity_check (tmpdir ):
1126
1126
model = BoringModel ()
1127
1127
1128
1128
# logger file to get meta
1129
1129
trainer = Trainer (
1130
- default_root_dir = tmpdir ,
1131
- limit_train_batches = 10 ,
1132
- limit_val_batches = 10 ,
1133
- val_check_interval = 0.3 ,
1134
- reload_dataloaders_every_n_epochs = True ,
1135
- max_epochs = 3 ,
1130
+ default_root_dir = tmpdir , limit_train_batches = 0.3 , limit_val_batches = 0.3 , num_sanity_val_steps = 0 , max_epochs = 3
1136
1131
)
1137
1132
1138
1133
tracker = Mock ()
@@ -1145,34 +1140,33 @@ def test_dataloaders_load_only_once_val_interval(tmpdir):
1145
1140
tracker .attach_mock (model .test_dataloader , "test_dataloader" )
1146
1141
1147
1142
trainer .fit (model )
1148
- trainer .test (model )
1149
1143
1150
1144
# verify the sequence
1151
- expected_sequence = [
1152
- call .val_dataloader (),
1153
- call .train_dataloader (),
1154
- call .val_dataloader (),
1155
- call .val_dataloader (),
1156
- call .val_dataloader (),
1157
- call .train_dataloader (),
1158
- call .val_dataloader (),
1159
- call .val_dataloader (),
1160
- call .val_dataloader (),
1161
- call .train_dataloader (),
1162
- call .val_dataloader (),
1163
- call .val_dataloader (),
1164
- call .val_dataloader (),
1165
- call .test_dataloader (),
1166
- ]
1145
+ expected_sequence = [call .train_dataloader (), call .val_dataloader ()]
1167
1146
assert tracker .mock_calls == expected_sequence
1168
1147
1169
1148
1170
- def test_dataloaders_load_only_once_no_sanity_check (tmpdir ):
1171
- model = BoringModel ()
1149
+ @pytest .mark .parametrize ("n" , [1 , 2 ])
1150
+ def test_dataloaders_load_every_n_epochs (tmpdir , n ):
1151
+ train_reload_epochs , val_reload_epochs = [], []
1152
+
1153
+ class TestModel (BoringModel ):
1154
+ def train_dataloader (self ):
1155
+ train_reload_epochs .append (self .current_epoch )
1156
+ return super ().train_dataloader ()
1157
+
1158
+ def val_dataloader (self ):
1159
+ val_reload_epochs .append (self .current_epoch )
1160
+ return super ().val_dataloader ()
1161
+
1162
+ model = TestModel ()
1172
1163
1173
- # logger file to get meta
1174
1164
trainer = Trainer (
1175
- default_root_dir = tmpdir , limit_train_batches = 0.3 , limit_val_batches = 0.3 , num_sanity_val_steps = 0 , max_epochs = 3
1165
+ default_root_dir = tmpdir ,
1166
+ limit_train_batches = 0.3 ,
1167
+ limit_val_batches = 0.3 ,
1168
+ reload_dataloaders_every_n_epochs = n ,
1169
+ max_epochs = 5 ,
1176
1170
)
1177
1171
1178
1172
tracker = Mock ()
@@ -1185,44 +1179,113 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
1185
1179
tracker .attach_mock (model .test_dataloader , "test_dataloader" )
1186
1180
1187
1181
trainer .fit (model )
1182
+ trainer .test (model )
1183
+
1184
+ # Verify the sequence
1185
+ expected_sequence = [call .val_dataloader (), call .train_dataloader ()] # Sanity check first
1186
+ if n == 1 :
1187
+ expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 4
1188
+ elif n == 2 :
1189
+ expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 2
1190
+ expected_sequence += [call .test_dataloader ()]
1188
1191
1189
- # verify the sequence
1190
- expected_sequence = [call .train_dataloader (), call .val_dataloader ()]
1191
1192
assert tracker .mock_calls == expected_sequence
1192
1193
1194
+ # Verify epoch of reloads
1195
+ if n == 1 :
1196
+ assert train_reload_epochs == [0 , 1 , 2 , 3 , 4 ]
1197
+ assert val_reload_epochs == [0 , 1 , 2 , 3 , 4 ]
1198
+ elif n == 2 :
1199
+ assert train_reload_epochs == [0 , 2 , 4 ]
1200
+ assert val_reload_epochs == [0 , 2 , 4 ]
1193
1201
1194
- @pytest .mark .parametrize ("n" , [1 , 2 ])
1195
- def test_dataloaders_load_every_n_epochs (tmpdir , n ):
1196
- model = BoringModel ()
1202
+
1203
+ @pytest .mark .parametrize (
1204
+ "n, train_reload_epochs_expect, val_reload_epochs_expect" ,
1205
+ [
1206
+ # Sanity check at epoch 0 creates a validation dataloader, but validation is
1207
+ # checked (and in this case reloaded) every n epochs starting from epoch n-1
1208
+ (3 , [0 , 2 , 4 , 6 , 8 ], [0 , 2 , 5 , 8 ]),
1209
+ (5 , [0 , 2 , 4 , 6 , 8 ], [0 , 4 , 9 ]),
1210
+ ],
1211
+ )
1212
+ def test_dataloaders_load_every_n_epochs_infrequent_val (
1213
+ tmpdir , n , train_reload_epochs_expect , val_reload_epochs_expect
1214
+ ):
1215
+ """Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
1216
+ train_reload_epochs , val_reload_epochs = [], []
1217
+
1218
+ class TestModel (BoringModel ):
1219
+ def train_dataloader (self ):
1220
+ train_reload_epochs .append (self .current_epoch )
1221
+ return super ().train_dataloader ()
1222
+
1223
+ def val_dataloader (self ):
1224
+ val_reload_epochs .append (self .current_epoch )
1225
+ return super ().val_dataloader ()
1226
+
1227
+ model = TestModel ()
1197
1228
1198
1229
trainer = Trainer (
1199
1230
default_root_dir = tmpdir ,
1200
1231
limit_train_batches = 0.3 ,
1201
1232
limit_val_batches = 0.3 ,
1202
- reload_dataloaders_every_n_epochs = n ,
1233
+ check_val_every_n_epoch = n ,
1234
+ reload_dataloaders_every_n_epochs = 2 ,
1235
+ max_epochs = 10 ,
1236
+ )
1237
+ model .test_dataloader = Mock (wraps = model .test_dataloader )
1238
+
1239
+ trainer .fit (model )
1240
+ trainer .test (model )
1241
+
1242
+ # Verify epoch of reloads
1243
+ assert train_reload_epochs == train_reload_epochs_expect
1244
+ assert val_reload_epochs == val_reload_epochs_expect
1245
+
1246
+ model .test_dataloader .assert_called_once ()
1247
+
1248
+
1249
+ def test_dataloaders_load_every_n_epochs_frequent_val (tmpdir ):
1250
+ """Test dataloader reload behavior when frequently checking validation set (via val_check_interval)"""
1251
+ train_reload_epochs , val_reload_epochs , val_check_epochs = [], [], []
1252
+
1253
+ class TestModel (BoringModel ):
1254
+ def train_dataloader (self ):
1255
+ train_reload_epochs .append (self .current_epoch )
1256
+ return super ().train_dataloader ()
1257
+
1258
+ def val_dataloader (self ):
1259
+ val_reload_epochs .append (self .current_epoch )
1260
+ return super ().val_dataloader ()
1261
+
1262
+ def validation_epoch_end (self , outputs ):
1263
+ val_check_epochs .append (self .current_epoch )
1264
+ return super ().validation_epoch_end (outputs )
1265
+
1266
+ model = TestModel ()
1267
+
1268
+ trainer = Trainer (
1269
+ default_root_dir = tmpdir ,
1270
+ limit_train_batches = 0.3 ,
1271
+ limit_val_batches = 0.3 ,
1272
+ val_check_interval = 0.3 ,
1273
+ reload_dataloaders_every_n_epochs = 1 ,
1203
1274
max_epochs = 3 ,
1204
1275
)
1205
1276
1206
- tracker = Mock ()
1207
- model .train_dataloader = Mock (wraps = model .train_dataloader )
1208
- model .val_dataloader = Mock (wraps = model .val_dataloader )
1209
1277
model .test_dataloader = Mock (wraps = model .test_dataloader )
1210
1278
1211
- tracker .attach_mock (model .train_dataloader , "train_dataloader" )
1212
- tracker .attach_mock (model .val_dataloader , "val_dataloader" )
1213
- tracker .attach_mock (model .test_dataloader , "test_dataloader" )
1214
-
1215
1279
trainer .fit (model )
1216
1280
trainer .test (model )
1217
1281
1218
- # verify the sequence
1219
- expected_sequence = [call .val_dataloader ()]
1220
- if n == 1 :
1221
- expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 3
1222
- elif n == 2 :
1223
- expected_sequence += [call .train_dataloader (), call .val_dataloader ()] * 2
1224
- expected_sequence += [call .test_dataloader ()]
1225
- assert tracker .mock_calls == expected_sequence
1282
+ # Verify epoch of reloads
1283
+ assert train_reload_epochs == [0 , 1 , 2 ]
1284
+ assert val_reload_epochs == [0 , 1 , 2 ]
1285
+ model .test_dataloader .assert_called_once ()
1286
+
1287
+ # Verify validation happens 3 times per epoch + 1 for sanity check
1288
+ assert val_check_epochs == [0 , 0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 , 2 ]
1226
1289
1227
1290
1228
1291
@pytest .mark .parametrize ("n" , ["test" , - 1 ])
@@ -1269,15 +1332,6 @@ def validation_step(self, batch, batch_idx):
1269
1332
expected_calls = [
1270
1333
call .train_dataloader (),
1271
1334
call .val_dataloader (),
1272
- # This has subsequent calls to val_dataloader
1273
- # because the training loop runs the evaluation loop,
1274
- # which reloads the val dataloader again.
1275
- # We cannot yet rely on trainer.current_epoch=0 to skip reloading
1276
- # the val dataloader on the first epoch because this only tracks the training epoch
1277
- # meaning multiple passes through the validation data within a single training epoch
1278
- # would not have the dataloader reloaded.
1279
- # This breaks the assumption behind reload_dataloaders_every_epoch=True
1280
- call .val_dataloader (),
1281
1335
call .train_dataloader (),
1282
1336
call .val_dataloader (),
1283
1337
call .train_dataloader (),
0 commit comments