@@ -117,11 +117,11 @@ def pre_step(self, current_action: str) -> None:
117
117
118
118
def reset (self ):
119
119
# handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
120
- self ._num_optimizer_step_and_closure = 0
120
+ self ._num_optimizer_step_with_closure = 0
121
121
self ._num_validation_step = 0
122
122
self ._num_test_step = 0
123
123
self ._num_predict_step = 0
124
- self ._optimizer_step_and_closure_reached_end = False
124
+ self ._optimizer_step_with_closure_reached_end = False
125
125
self ._validation_step_reached_end = False
126
126
self ._test_step_reached_end = False
127
127
self ._predict_step_reached_end = False
@@ -132,13 +132,13 @@ def reset(self):
132
132
@property
133
133
def is_training (self ) -> bool :
134
134
return self ._current_action is not None and (
135
- self ._current_action .startswith ("optimizer_step_and_closure_ " ) or self ._current_action == "training_step"
135
+ self ._current_action .startswith ("optimizer_step_with_closure_ " ) or self ._current_action == "training_step"
136
136
)
137
137
138
138
@property
139
139
def num_step (self ) -> int :
140
140
if self .is_training :
141
- return self ._num_optimizer_step_and_closure
141
+ return self ._num_optimizer_step_with_closure
142
142
if self ._current_action == "validation_step" :
143
143
return self ._num_validation_step
144
144
if self ._current_action == "test_step" :
@@ -149,10 +149,10 @@ def num_step(self) -> int:
149
149
150
150
def _step (self ) -> None :
151
151
if self .is_training :
152
- self ._num_optimizer_step_and_closure += 1
152
+ self ._num_optimizer_step_with_closure += 1
153
153
elif self ._current_action == "validation_step" :
154
154
if self ._start_action_name == "on_fit_start" :
155
- if self ._num_optimizer_step_and_closure > 0 :
155
+ if self ._num_optimizer_step_with_closure > 0 :
156
156
self ._num_validation_step += 1
157
157
else :
158
158
self ._num_validation_step += 1
@@ -164,7 +164,7 @@ def _step(self) -> None:
164
164
@property
165
165
def has_finished (self ) -> bool :
166
166
if self .is_training :
167
- return self ._optimizer_step_and_closure_reached_end
167
+ return self ._optimizer_step_with_closure_reached_end
168
168
if self ._current_action == "validation_step" :
169
169
return self ._validation_step_reached_end
170
170
if self ._current_action == "test_step" :
@@ -182,7 +182,7 @@ def __call__(self, num_step: int) -> "ProfilerAction":
182
182
action = self ._schedule (max (self .num_step , 0 ))
183
183
if action == ProfilerAction .RECORD_AND_SAVE :
184
184
if self .is_training :
185
- self ._optimizer_step_and_closure_reached_end = True
185
+ self ._optimizer_step_with_closure_reached_end = True
186
186
elif self ._current_action == "validation_step" :
187
187
self ._validation_step_reached_end = True
188
188
elif self ._current_action == "test_step" :
@@ -202,9 +202,9 @@ class PyTorchProfiler(BaseProfiler):
202
202
"test_step" ,
203
203
"predict_step" ,
204
204
}
205
- RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_ "
205
+ RECORD_FUNCTION_PREFIX = "optimizer_step_with_closure_ "
206
206
STEP_FUNCTIONS = {"training_step" , "validation_step" , "test_step" , "predict_step" }
207
- STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_ "
207
+ STEP_FUNCTION_PREFIX = "optimizer_step_with_closure_ "
208
208
AVAILABLE_SORT_KEYS = {
209
209
"cpu_time" ,
210
210
"cuda_time" ,
@@ -383,8 +383,8 @@ def start(self, action_name: str) -> None:
383
383
self ._register .__enter__ ()
384
384
385
385
if self ._lightning_module is not None :
386
- # when the model is used in automatic optimization,
387
- # we use `optimizer_step_and_closure` to step the model.
386
+ # when the model is used in automatic optimization, we use `optimizer_step_with_closure` to step the model.
387
+ # this profiler event is generated in the `LightningOptimizer.step` method
388
388
if self ._lightning_module .automatic_optimization and "training_step" in self .STEP_FUNCTIONS :
389
389
self .STEP_FUNCTIONS .remove ("training_step" )
390
390
0 commit comments