@@ -332,6 +332,16 @@ def execute_something():
 
         return RemovableEventHandle(event_name, handler, self)
 
+    @staticmethod
+    def _assert_non_filtered_event(event_name: Any) -> None:
+        if (
+            isinstance(event_name, CallableEventWithFilter)
+            and event_name.filter != CallableEventWithFilter.default_event_filter
+        ):
+            raise TypeError(
+                "Argument event_name should not be a filtered event, please use event without any event filtering"
+            )
+
     def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
         """Check if the specified event has the specified handler.
 
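For context, the guard above distinguishes plain events from filtered ones such as Events.ITERATION_COMPLETED(every=2). A minimal sketch of the same condition, assuming (as in ignite's events module) that plain Events members are CallableEventWithFilter instances carrying the default filter:

from ignite.engine import Events
from ignite.engine.events import CallableEventWithFilter

def is_filtered(event_name) -> bool:
    # Same condition as _assert_non_filtered_event in the hunk above
    return (
        isinstance(event_name, CallableEventWithFilter)
        and event_name.filter != CallableEventWithFilter.default_event_filter
    )

print(is_filtered(Events.ITERATION_COMPLETED))           # False: plain event
print(is_filtered(Events.ITERATION_COMPLETED(every=2)))  # True: the guard would raise TypeError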
@@ -932,6 +942,53 @@ def _setup_dataloader_iter(self) -> None:
         else:
             self._dataloader_iter = iter(self.state.dataloader)
 
+    def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
+        if max_epochs is not None:
+            if max_epochs < 1:
+                raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
+            if self.state.max_epochs is not None and max_epochs <= self.state.epoch:
+                raise ValueError(
+                    "Argument max_epochs should be larger than the current epoch "
+                    f"defined in the state: {max_epochs} vs {self.state.epoch}. "
+                    "Please, set engine.state.max_epochs = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_epochs = max_epochs
+
+    def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
+        if max_iters is not None:
+            if max_iters < 1:
+                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
+            if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
+                raise ValueError(
+                    "Argument max_iters should be larger than the current iteration "
+                    f"defined in the state: {max_iters} vs {self.state.iteration}. "
+                    "Please, set engine.state.max_iters = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_iters = max_iters
+
+    def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
+        # Can't we accept a redefinition?
+        if self.state.epoch_length is not None:
+            if epoch_length is not None:
+                if epoch_length != self.state.epoch_length:
+                    raise ValueError(
+                        "Argument epoch_length should be same as in the state, "
+                        f"but given {epoch_length} vs {self.state.epoch_length}"
+                    )
+        else:
+            if epoch_length is None:
+                epoch_length = self._get_data_length(data)
+
+            if epoch_length is not None and epoch_length < 1:
+                raise ValueError(
+                    "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
+                    "check if input data has non-zero size."
+                )
+
+            self.state.epoch_length = epoch_length
+
     def _setup_engine(self) -> None:
         self._setup_dataloader_iter()
 
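The three validators above formalize what happens when run() is called again on an engine whose state is mid- or post-training. A hedged usage sketch of the max_epochs path (placeholder step function and data; the restart behavior is what the error message itself prescribes):

from ignite.engine import Engine

trainer = Engine(lambda engine, batch: None)
trainer.run(range(10), max_epochs=5)      # completes with state.epoch == 5

try:
    trainer.run(range(10), max_epochs=3)  # 3 <= state.epoch: _check_and_set_max_epochs raises
except ValueError as e:
    print(e)

trainer.state.max_epochs = None           # reset, as the error message advises
trainer.run(range(10), max_epochs=3)      # restarts training from the beginning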
@@ -1291,6 +1348,16 @@ def _run_once_on_dataset_legacy(self) -> float:
 
         return time.time() - start_time
 
+    def debug(self, enabled: bool = True) -> None:
+        """Enables/disables engine's logging debug mode"""
+        from ignite.utils import setup_logger
+
+        if enabled:
+            setattr(self, "_stored_logger", self.logger)
+            self.logger = setup_logger(level=logging.DEBUG)
+        elif hasattr(self, "_stored_logger"):
+            self.logger = getattr(self, "_stored_logger")
+
 
 def _get_none_data_iter(size: int) -> Iterator:
     # Sized iterator for data as None
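A hedged usage sketch of the new toggle, assuming debug() lands on Engine as shown above (setup_logger is ignite.utils.setup_logger, imported lazily inside the method):

from ignite.engine import Engine

trainer = Engine(lambda engine, batch: None)

trainer.debug()                       # stash the current logger, switch to DEBUG level
trainer.run(range(4), max_epochs=1)   # the run now emits the engine's debug-level logs

trainer.debug(enabled=False)          # restore the previously stored logger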