Follow-up changes to #10575 #10957
@@ -144,12 +144,10 @@ def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
             )

         if self._zero_grad_fn is not None:
-            with self._profiler.profile("zero_grad"):
Review thread on the removed profiler calls:
- "If we get rid of […]"
- "They are removed from here because each function will already profile them."
- "I think we could remove it. PyTorch Profiler would track those operations already at the PyTorch level."
(A toy sketch of where the profiling now happens follows this hunk.)
-                self._zero_grad_fn()
+            self._zero_grad_fn()

         if self._backward_fn is not None and step_output.closure_loss is not None:
-            with self._profiler.profile("backward"):
-                self._backward_fn(step_output.closure_loss)
+            self._backward_fn(step_output.closure_loss)

         return step_output
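The profiler calls removed above are covered elsewhere: as the thread notes, the Trainer's hook-calling helpers shown later in this diff wrap every hook call in `profiler.profile(hook_name)`. A toy sketch of that pattern, using invented stand-in classes rather than the real Trainer, profiler, or plugin:

# Toy sketch (invented classes, not the real Lightning code) of why the explicit
# profiler blocks in `closure` are redundant: the `_call_*_hook` helpers shown
# later in this diff wrap every hook call in `profiler.profile(hook_name)`.
from contextlib import contextmanager
from typing import Any, Callable, Iterator


class SketchProfiler:
    @contextmanager
    def profile(self, action_name: str) -> Iterator[None]:
        print(f"start profiling: {action_name}")
        yield
        print(f"stop profiling: {action_name}")


class SketchPlugin:
    def backward(self, loss: float) -> None:
        print(f"running backward on loss={loss}")


class SketchTrainer:
    def __init__(self, plugin: SketchPlugin) -> None:
        self.profiler = SketchProfiler()
        self.training_type_plugin = plugin

    def _call_ttp_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
        fn: Callable = getattr(self.training_type_plugin, hook_name)
        # the profiling happens here, once, for every hook call, so callers
        # such as the closure no longer need their own profiler blocks
        with self.profiler.profile(hook_name):
            return fn(*args, **kwargs)


trainer = SketchTrainer(SketchPlugin())
trainer._call_ttp_hook("backward", 0.5)

Running the snippet prints the profiling markers around the `backward` call, which is what the closure previously had to set up by hand.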
@@ -320,7 +318,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
             return None

         def backward_fn(loss: Tensor) -> None:
-            self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
+            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)

             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -362,16 +360,15 @@ def _optimizer_step(
             train_step_and_backward_closure: the closure function performing the train step and computing the
                 gradients. By default called by the optimizer (if possible)
         """
-        lightning_module = self.trainer.lightning_module
-
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         # wraps into LightningOptimizer only for running step
         optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

         self.optim_progress.optimizer.step.increment_ready()

         # model hook
-        lightning_module.optimizer_step(
+        self.trainer._call_lightning_module_hook(
+            "optimizer_step",
             self.trainer.current_epoch,
             batch_idx,
             optimizer,
@@ -403,7 +400,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
         self.optim_progress.optimizer.zero_grad.increment_completed()

     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -427,7 +424,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
            )

            # manually capture logged metrics
-           training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+           training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
            self.trainer.training_type_plugin.post_training_step()

            del step_kwargs
@@ -164,8 +164,7 @@ def optimizer_step(

     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
-        model_ref = self.lightning_module
-        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
+        self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
Review thread on this change:
- "why not […]"
- "Because it's called already with […]. No TTP subclasses this hook, so technically the optimizer_loop could call […]"
(A toy sketch of this delegation chain follows this file's hunk.)

     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.
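To make the thread above concrete, here is a toy sketch of the delegation chain after this PR: the optimizer loop asks the Trainer for the `optimizer_zero_grad` hook on the training type plugin, the plugin forwards to the LightningModule hook, and the default LightningModule behaviour is to call `optimizer.zero_grad()`. The classes below are invented stand-ins (the Trainer indirection is elided); only the hook names mirror the diff.

# Toy sketch of the optimizer_zero_grad delegation chain; simplified stand-ins,
# not the actual Lightning classes.
import torch
from torch import nn
from torch.optim import SGD, Optimizer


class SketchModule(nn.Linear):
    # default LightningModule behaviour: just zero this optimizer's gradients
    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
        optimizer.zero_grad()


class SketchTTP:
    """Stands in for the training type plugin: forwards the hook to the module."""

    def __init__(self, module: SketchModule) -> None:
        self.lightning_module = module

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
        # no plugin overrides this hook, which is why the loop could in principle
        # call the LightningModule hook directly instead of going through the plugin
        self.lightning_module.optimizer_zero_grad(epoch, batch_idx, optimizer, opt_idx)


module = SketchModule(2, 2)
plugin = SketchTTP(module)
opt = SGD(module.parameters(), lr=0.1)

module(torch.ones(1, 2)).sum().backward()  # populate gradients
plugin.optimizer_zero_grad(0, 0, opt, 0)   # loop -> plugin -> module -> optimizer.zero_grad()
print(module.weight.grad)                  # None, or zeros on older torch versions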
@@ -1452,15 +1452,15 @@ def _call_lightning_module_hook(
         *args: Any,
         pl_module: Optional["pl.LightningModule"] = None,
         **kwargs: Any,
-    ):
+    ) -> Any:
         pl_module = pl_module or self.lightning_module

         if pl_module is None:
             raise TypeError("No Lightning Module is available to call hooks on")

         fn = getattr(pl_module, hook_name)
         if not callable(fn):
-            return None
+            return

         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = hook_name
@@ -1479,16 +1479,15 @@ def _call_callback_hooks(
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        output = None
+    ) -> None:
Review thread on the return type change:
- "why are we returning None for callback hooks? I thought part of the motivation for #8506 was to allow returning something for callback hooks? i.e. 4) in the description"
- "It was described as a limitation, but it's not something we want to support with the refactor. However, the refactor allows for adding it easily when we want to."
(A hypothetical sketch of what collecting callback outputs could look like follows this hunk.)
         if hook_name in ("on_init_start", "on_init_end"):
             # these `Callback` hooks are the only ones that do not take a lightning module.
             # we also don't profile bc profiler hasn't been set yet
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
-                    output = fn(self, *args, **kwargs)
-            return output
+                    fn(self, *args, **kwargs)
+            return

         pl_module = self.lightning_module
         if pl_module:
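The reply above says the refactor would make it easy to add callback return values later. Purely as a hypothetical illustration (this is not part of the PR, and the class and method names below are invented), collecting outputs could look roughly like this:

# Hypothetical sketch only: NOT what this PR implements. The new
# `_call_callback_hooks` deliberately discards outputs and returns None;
# this shows how return values *could* be gathered if support were added later.
from typing import Any, Dict, List


class SketchTrainerWithCallbackOutputs:
    def __init__(self, callbacks: List[Any]) -> None:
        self.callbacks = callbacks

    def _call_callback_hooks_collect(self, hook_name: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # gather one output per callback, keyed by callback class name
        outputs: Dict[str, Any] = {}
        for callback in self.callbacks:
            fn = getattr(callback, hook_name, None)
            if callable(fn):
                outputs[type(callback).__name__] = fn(self, *args, **kwargs)
        return outputs


class PrintingCallback:
    def on_train_start(self, trainer: Any, *args: Any, **kwargs: Any) -> str:
        return "started"


trainer = SketchTrainerWithCallbackOutputs([PrintingCallback()])
print(trainer._call_callback_hooks_collect("on_train_start"))  # {'PrintingCallback': 'started'}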
@@ -1500,34 +1499,39 @@ def _call_callback_hooks(
             fn = getattr(self, hook_name)
             if callable(fn):
                 with self.profiler.profile(hook_name):
-                    output = fn(*args, **kwargs)
+                    fn(*args, **kwargs)
         else:
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
                     with self.profiler.profile(hook_name):
-                        output = fn(self, self.lightning_module, *args, **kwargs)
+                        fn(self, self.lightning_module, *args, **kwargs)

         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name

-        return output
-
     # TODO: rename to _call_strategy_hook and eventually no longer need this
     def _call_ttp_hook(
         self,
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
         fn = getattr(self.training_type_plugin, hook_name)
         if not callable(fn):
-            return None
+            return

         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)

+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output

     # TODO: eventually no longer need this
@@ -1536,15 +1540,21 @@ def _call_accelerator_hook(
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        self.lightning_module._current_fx_name = hook_name
-        fn = getattr(self.training_type_plugin, hook_name)
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
+        fn = getattr(self.accelerator, hook_name)
         if not callable(fn):
-            return None
+            return

         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)

+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output

     @staticmethod