Skip to content

Commit 135b844

Browse files
committed
initial engine update
1 parent af4f548 commit 135b844

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

ignite/engine/engine.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ def execute_something():
332332

333333
return RemovableEventHandle(event_name, handler, self)
334334

335+
@staticmethod
336+
def _assert_non_filtered_event(event_name: Any) -> None:
337+
if (
338+
isinstance(event_name, CallableEventWithFilter)
339+
and event_name.filter != CallableEventWithFilter.default_event_filter
340+
):
341+
raise TypeError(
342+
"Argument event_name should not be a filtered event, " "please use event without any event filtering"
343+
)
344+
335345
def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
336346
"""Check if the specified event has the specified handler.
337347
@@ -932,6 +942,53 @@ def _setup_dataloader_iter(self) -> None:
932942
else:
933943
self._dataloader_iter = iter(self.state.dataloader)
934944

945+
def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
946+
if max_epochs is not None:
947+
if max_epochs < 1:
948+
raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
949+
if self.state.max_epochs is not None and max_epochs <= self.state.epoch:
950+
raise ValueError(
951+
"Argument max_epochs should be larger than the current epoch "
952+
f"defined in the state: {max_epochs} vs {self.state.epoch}. "
953+
"Please, set engine.state.max_epochs = None "
954+
"before calling engine.run() in order to restart the training from the beginning."
955+
)
956+
self.state.max_epochs = max_epochs
957+
958+
def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
959+
if max_iters is not None:
960+
if max_iters < 1:
961+
raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
962+
if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
963+
raise ValueError(
964+
"Argument max_iters should be larger than the current iteration "
965+
f"defined in the state: {max_iters} vs {self.state.iteration}. "
966+
"Please, set engine.state.max_iters = None "
967+
"before calling engine.run() in order to restart the training from the beginning."
968+
)
969+
self.state.max_iters = max_iters
970+
971+
def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
972+
# Can't we accept a redefinition ?
973+
if self.state.epoch_length is not None:
974+
if epoch_length is not None:
975+
if epoch_length != self.state.epoch_length:
976+
raise ValueError(
977+
"Argument epoch_length should be same as in the state, "
978+
f"but given {epoch_length} vs {self.state.epoch_length}"
979+
)
980+
else:
981+
if epoch_length is None:
982+
epoch_length = self._get_data_length(data)
983+
984+
if epoch_length is not None and epoch_length < 1:
985+
raise ValueError(
986+
"Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
987+
"check if input data has non-zero size."
988+
)
989+
990+
self.state.epoch_length = epoch_length
991+
935992
def _setup_engine(self) -> None:
936993
self._setup_dataloader_iter()
937994

@@ -1291,6 +1348,16 @@ def _run_once_on_dataset_legacy(self) -> float:
12911348

12921349
return time.time() - start_time
12931350

1351+
def debug(self, enabled: bool = True) -> None:
1352+
"""Enables/disables engine's logging debug mode"""
1353+
from ignite.utils import setup_logger
1354+
1355+
if enabled:
1356+
setattr(self, "_stored_logger", self.logger)
1357+
self.logger = setup_logger(level=logging.DEBUG)
1358+
elif hasattr(self, "_stored_logger"):
1359+
self.logger = getattr(self, "_stored_logger")
1360+
12941361

12951362
def _get_none_data_iter(size: int) -> Iterator:
12961363
# Sized iterator for data as None

0 commit comments

Comments
 (0)