@@ -44,6 +44,7 @@ def __init__(self) -> None:
         self._outputs: _OUTPUTS_TYPE = []
         self._warning_cache: WarningCache = WarningCache()
         self._remaining_splits: Optional[List[Any]] = None
+        self._exit_signal: int = 0
 
     @property
     def done(self) -> bool:
@@ -58,35 +59,6 @@ def connect(
         if manual_loop is not None:
             self.manual_loop = manual_loop
 
-    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
-        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
-
-        Args:
-            batch: the current batch to run the train step on
-            batch_idx: the index of the current batch
-        """
-        if batch is None:
-            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
-            return AttributeDict(signal=0, outputs=[])
-
-        # hook
-        self.trainer.logger_connector.on_batch_start()
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
-
-        super().run(batch, batch_idx)
-
-        output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None  # free memory
-        return output
-
     def reset(self) -> None:
         """Resets the loop state."""
         self._outputs = []
@@ -108,13 +80,31 @@ def advance(self, batch, batch_idx):
             batch: the current batch to run the training on (this is not the split!)
             batch_idx: the index of the current batch
         """
-        void(batch)
+        if batch is None:
+            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
+            raise StopIteration
+
         split_idx, split_batch = self._remaining_splits.pop(0)
         self.split_idx = split_idx
 
         # let logger connector extract current batch size
         self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
 
+        # hook
+        self.trainer.logger_connector.on_batch_start()
+        response = self.trainer.call_hook("on_batch_start")
+        if response == -1:
+            self._exit_signal = -1
+            raise StopIteration
+
+        # hook
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        if response == -1:
+            self._exit_signal = -1
+            raise StopIteration
+
+        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
+
         # choose which loop will run the optimization
         if self.trainer.lightning_module.automatic_optimization:
             optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
@@ -131,6 +121,9 @@ def on_run_end(self) -> None:
         self.optimizer_loop._hiddens = None
         # this is not necessary as the manual loop runs for only 1 iteration, but just in case
         self.manual_loop._hiddens = None
+        output, self._outputs = AttributeDict(signal=self._exit_signal, outputs=self._outputs), None  # free memory
+        self._exit_signal = 0
+        return output
 
     def teardown(self) -> None:
         # release memory
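For context: the refactor deletes the overridden `run()` and moves its hook calls and early-exit checks into `advance()`, signalling an abort by raising `StopIteration` and stashing the exit code in the new `self._exit_signal`, which `on_run_end()` then folds into the returned `AttributeDict`. Below is a minimal, self-contained sketch of the control flow this relies on. It is not the actual Lightning `Loop` base class: `_SketchLoop`, `_call_hooks`, and the plain-dict output are illustrative stand-ins (for `trainer.call_hook` and `AttributeDict`), assumed only to show how a `StopIteration` raised from `advance()` breaks the `while` loop in `run()` and lets `on_run_end()` produce the final output.

```python
from typing import Any, Dict, List


class _SketchLoop:
    """Illustrative stand-in for a Lightning-style loop; not the real base class."""

    def __init__(self) -> None:
        self._outputs: List[Any] = []
        self._exit_signal: int = 0            # mirrors the attribute the diff adds
        self._remaining_splits: List[Any] = []

    @property
    def done(self) -> bool:
        return len(self._remaining_splits) == 0

    def reset(self, splits: List[Any]) -> None:
        self._remaining_splits = list(splits)

    def _call_hooks(self, split: Any, batch_idx: int) -> int:
        # hypothetical stand-in for trainer.call_hook(...); a real hook may veto with -1
        return 0

    def advance(self, batch: Any, batch_idx: int) -> None:
        if batch is None:                     # early exit now lives in advance(), not run()
            raise StopIteration
        split = self._remaining_splits.pop(0)
        if self._call_hooks(split, batch_idx) == -1:
            self._exit_signal = -1            # remember why we aborted ...
            raise StopIteration               # ... and stop the enclosing while-loop
        self._outputs.append(split)

    def on_run_end(self) -> Dict[str, Any]:
        # hand back the accumulated output exactly once, as the diff moves this into on_run_end()
        output, self._outputs = {"signal": self._exit_signal, "outputs": self._outputs}, []
        self._exit_signal = 0
        return output

    def run(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        while not self.done:
            try:
                self.advance(batch, batch_idx)
            except StopIteration:             # raised by advance() to end the loop early
                break
        return self.on_run_end()


if __name__ == "__main__":
    batch = {"x": 1}
    loop = _SketchLoop()
    loop.reset(splits=[batch])
    print(loop.run(batch, batch_idx=0))       # {'signal': 0, 'outputs': [{'x': 1}]}
```

Running the sketch prints `{'signal': 0, 'outputs': [{'x': 1}]}`; passing `batch=None` or having `_call_hooks` return `-1` instead ends the loop early, and the output is still produced exactly once in `on_run_end()`, which is why the diff can safely reset `_exit_signal` there.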