
Follow-up changes to #10575 (#10957)


Merged: 10 commits, Dec 7, 2021
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -186,11 +186,10 @@ def _on_evaluation_model_eval(self) -> None:
 
     def _on_evaluation_model_train(self) -> None:
         """Sets model to train mode."""
-        model_ref = self.trainer.lightning_module
         if self.trainer.testing:
-            model_ref.on_test_model_train()
+            self.trainer._call_lightning_module_hook("on_test_model_train")
         else:
-            model_ref.on_validation_model_train()
+            self.trainer._call_lightning_module_hook("on_validation_model_train")
 
     def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_end`` hook."""
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -218,9 +218,9 @@ def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
             the outputs of the step
         """
         if self.trainer.testing:
-            output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
+            output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
         else:
-            output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())
+            output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
 
         return output

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -130,7 +130,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
 
         self.batch_progress.increment_started()
 
-        predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())
+        predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
 
         self.batch_progress.increment_processed()

5 changes: 2 additions & 3 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -180,11 +180,10 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
 
             # hook
             self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-            model_response = self.trainer._call_lightning_module_hook(
+            response = self.trainer._call_lightning_module_hook(
                 "on_train_batch_start", batch, batch_idx, **extra_kwargs
             )
-            ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-            response = ttp_response if model_response is None else model_response
+            self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
            if response == -1:
                 self.batch_progress.increment_processed()
                 raise StopIteration
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/optimization/manual_loop.py
@@ -102,7 +102,7 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         )
 
         # manually capture logged metrics
-        training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
         self.trainer.training_type_plugin.post_training_step()
 
         del step_kwargs
17 changes: 7 additions & 10 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -144,12 +144,10 @@ def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
             )
 
         if self._zero_grad_fn is not None:
-            with self._profiler.profile("zero_grad"):
-                self._zero_grad_fn()
+            self._zero_grad_fn()
 
         if self._backward_fn is not None and step_output.closure_loss is not None:
-            with self._profiler.profile("backward"):
-                self._backward_fn(step_output.closure_loss)
+            self._backward_fn(step_output.closure_loss)
 
         return step_output

[Review thread on the removed profile("zero_grad") block]
carmocca (Contributor Author), Dec 6, 2021: If we get rid of profile("training_step_and_backward") then we could remove the profiler reference from this class. Just pointing it out for a possible follow-up. They are removed from here because each function will already profile them.
Reviewer (Contributor): I think we could remove it. PyTorch Profiler would track those operations already at the PyTorch level.
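The thread above explains why the explicit `with self._profiler.profile(...)` wrappers can go away: every hook now reaches the plugin or module through a trainer-level dispatcher that does the profiling itself. A minimal sketch of that idea follows; `RecordingProfiler` and `call_hook` are illustrative stand-ins, not the actual Lightning API.

```python
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List


class RecordingProfiler:
    """Toy profiler that records which actions were profiled (stand-in for the Trainer profiler)."""

    def __init__(self) -> None:
        self.recorded: List[str] = []

    @contextmanager
    def profile(self, action_name: str) -> Iterator[None]:
        self.recorded.append(action_name)
        yield


def call_hook(profiler: RecordingProfiler, hooks: Dict[str, Callable[..., Any]], name: str, *args: Any) -> Any:
    """Dispatcher-style call: profiling happens here, not at the call site."""
    fn = hooks.get(name)
    if fn is None:
        return None
    with profiler.profile(name):
        return fn(*args)


profiler = RecordingProfiler()
hooks = {"optimizer_zero_grad": lambda: None, "backward": lambda loss: loss}
call_hook(profiler, hooks, "optimizer_zero_grad")
call_hook(profiler, hooks, "backward", 0.5)
# The closure itself no longer needs its own profile("zero_grad") / profile("backward") blocks.
assert profiler.recorded == ["optimizer_zero_grad", "backward"]
```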

@@ -320,7 +318,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
             return None
 
         def backward_fn(loss: Tensor) -> None:
-            self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
+            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
 
             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -362,16 +360,15 @@ def _optimizer_step(
             train_step_and_backward_closure: the closure function performing the train step and computing the
                 gradients. By default called by the optimizer (if possible)
         """
-        lightning_module = self.trainer.lightning_module
-
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         # wraps into LightningOptimizer only for running step
         optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
 
         self.optim_progress.optimizer.step.increment_ready()
 
         # model hook
-        lightning_module.optimizer_step(
+        self.trainer._call_lightning_module_hook(
+            "optimizer_step",
             self.trainer.current_epoch,
             batch_idx,
             optimizer,
@@ -403,7 +400,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
         self.optim_progress.optimizer.zero_grad.increment_completed()
 
     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -427,7 +424,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
         )
 
         # manually capture logged metrics
-        training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
         self.trainer.training_type_plugin.post_training_step()
 
         del step_kwargs
[file name not captured]
@@ -164,8 +164,7 @@ def optimizer_step(
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
-        model_ref = self.lightning_module
-        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
+        self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
 
     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.

[Review thread on the new optimizer_zero_grad call]
Reviewer (Contributor): why not _call_LM_hook?
carmocca (Contributor Author): Because it's called already with _call_ttp_hook: https://github.com/PyTorchLightning/pytorch-lightning/blob/eec2bae6c984895d1d4df4e75cf628c58128928a/pytorch_lightning/loops/optimization/optimizer_loop.py#L403. No TTP subclasses this hook, so technically the optimizer_loop could call _call_LM_hook directly and this could be removed.
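To make the call chain from that discussion concrete: the optimizer loop now calls `_call_ttp_hook("optimizer_zero_grad", ...)`, the training type plugin forwards to the LightningModule hook, and the module's default implementation finally clears the gradients. A rough sketch of that delegation, using simplified stand-in classes rather than the real Lightning signatures:

```python
from typing import Any


class ToyLightningModule:
    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Any, opt_idx: int) -> None:
        # The user-overridable hook; the default just clears gradients.
        optimizer.zero_grad()


class ToyTrainingTypePlugin:
    """Plugin layer: forwards the hook to the LightningModule, as the hunk above does."""

    def __init__(self, module: ToyLightningModule) -> None:
        self.lightning_module = module

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Any, opt_idx: int) -> None:
        self.lightning_module.optimizer_zero_grad(epoch, batch_idx, optimizer, opt_idx)


class ToyTrainer:
    def __init__(self, plugin: ToyTrainingTypePlugin) -> None:
        self.training_type_plugin = plugin

    def _call_ttp_hook(self, hook_name: str, *args: Any) -> Any:
        # The optimizer loop calls this; a plugin may override the hook or,
        # as here, simply delegate to the module.
        return getattr(self.training_type_plugin, hook_name)(*args)
```

If no plugin ever overrides `optimizer_zero_grad`, the loop could indeed call the LightningModule hook directly, which is the follow-up suggested in the thread.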
[file name not captured]
@@ -31,15 +31,24 @@ class _LogOptions(TypedDict):
         "on_before_backward": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_after_backward": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
         "on_before_optimizer_step": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "optimizer_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_before_zero_grad": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "optimizer_zero_grad": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_init_start": None,
         "on_init_end": None,
         "on_fit_start": None,
@@ -160,6 +169,8 @@ class _LogOptions(TypedDict):
         "configure_callbacks": None,
         "on_validation_model_eval": None,
         "on_test_model_eval": None,
+        "on_validation_model_train": None,
+        "on_test_model_train": None,
     }
 
     @classmethod
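The new entries allow `self.log()` calls inside `backward`, `optimizer_step`, and `optimizer_zero_grad`, defaulting to step-level logging, while `on_validation_model_train` and `on_test_model_train` join the hooks that cannot log at all. A rough sketch of how such a table can be consulted; the `check_logging` helper and `FUNCTIONS` dict are illustrative, not the exact `_FxValidator` API:

```python
from typing import Dict, Optional, Tuple, TypedDict


class LogOptions(TypedDict):
    allowed_on_step: Tuple[bool, ...]
    allowed_on_epoch: Tuple[bool, ...]
    default_on_step: bool
    default_on_epoch: bool


# Hooks mapped to None cannot log at all; the others carry their allowed/default flags.
FUNCTIONS: Dict[str, Optional[LogOptions]] = {
    "backward": LogOptions(
        allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
    ),
    "on_validation_model_train": None,
}


def check_logging(fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]) -> Tuple[bool, bool]:
    options = FUNCTIONS.get(fx_name)
    if options is None:
        raise RuntimeError(f"You can't `self.log()` inside `{fx_name}`")
    # Fall back to the defaults when the user did not pass on_step/on_epoch explicitly.
    on_step = options["default_on_step"] if on_step is None else on_step
    on_epoch = options["default_on_epoch"] if on_epoch is None else on_epoch
    if on_step not in options["allowed_on_step"]:
        raise RuntimeError(f"on_step={on_step} is not allowed in `{fx_name}`")
    if on_epoch not in options["allowed_on_epoch"]:
        raise RuntimeError(f"on_epoch={on_epoch} is not allowed in `{fx_name}`")
    return on_step, on_epoch


assert check_logging("backward", None, None) == (True, False)
```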
42 changes: 26 additions & 16 deletions pytorch_lightning/trainer/trainer.py
@@ -1452,15 +1452,15 @@ def _call_lightning_module_hook(
         *args: Any,
         pl_module: Optional["pl.LightningModule"] = None,
         **kwargs: Any,
-    ):
+    ) -> Any:
         pl_module = pl_module or self.lightning_module
 
         if pl_module is None:
             raise TypeError("No Lightning Module is available to call hooks on")
 
         fn = getattr(pl_module, hook_name)
         if not callable(fn):
-            return None
+            return
 
         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = hook_name
@@ -1479,16 +1479,15 @@ def _call_callback_hooks(
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        output = None
+    ) -> None:
         if hook_name in ("on_init_start", "on_init_end"):
             # these `Callback` hooks are the only ones that do not take a lightning module.
             # we also don't profile bc profiler hasn't been set yet
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
-                    output = fn(self, *args, **kwargs)
-            return output
+                    fn(self, *args, **kwargs)
+            return
 
         pl_module = self.lightning_module
         if pl_module:

[Review thread on the return-type change]
daniellepintz (Contributor), Dec 9, 2021: why are we returning None for callback hooks? I thought part of the motivation for #8506 was to allow returning something for callback hooks? i.e. 4) in the description
carmocca (Contributor Author), Dec 14, 2021: It was described as a limitation, but it's not something we want to support with the refactor. However, the refactor allows for adding it easily when we want to.
@@ -1500,34 +1499,39 @@ def _call_callback_hooks(
             fn = getattr(self, hook_name)
             if callable(fn):
                 with self.profiler.profile(hook_name):
-                    fn(*args, **kwargs)
+                    fn(*args, **kwargs)
         else:
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
                     with self.profiler.profile(hook_name):
-                        output = fn(self, self.lightning_module, *args, **kwargs)
+                        fn(self, self.lightning_module, *args, **kwargs)
 
         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name
 
-        return output
-
+    # TODO: rename to _call_strategy_hook and eventually no longer need this
     def _call_ttp_hook(
         self,
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
         fn = getattr(self.training_type_plugin, hook_name)
         if not callable(fn):
-            return None
+            return
 
         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)
 
+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output
 
     # TODO: eventually no longer need this
@@ -1536,15 +1540,21 @@ def _call_accelerator_hook(
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        self.lightning_module._current_fx_name = hook_name
-        fn = getattr(self.training_type_plugin, hook_name)
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
+        fn = getattr(self.accelerator, hook_name)
         if not callable(fn):
-            return None
+            return
 
         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)
 
+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output
 
     @staticmethod
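All three dispatchers in this file now share the same bookkeeping: remember the module's `_current_fx_name`, set it to the hook being called so logging is attributed to the right hook, run the hook, then restore the previous name so nested hook calls don't clobber the outer one (profiling works the same way, as in the earlier sketch). A condensed sketch of that pattern, assuming toy stand-ins rather than the real Trainer internals:

```python
from contextlib import contextmanager
from typing import Any, Iterator, Optional


class ToyModule:
    """Stand-in for a LightningModule that tracks which hook is currently running."""

    _current_fx_name: Optional[str] = None


@contextmanager
def _fx_context(module: ToyModule, hook_name: str) -> Iterator[None]:
    # Save and restore _current_fx_name so a hook that triggers another hook
    # (a "nested context") does not clobber the outer hook's name.
    prev_fx_name = module._current_fx_name
    module._current_fx_name = hook_name
    try:
        yield
    finally:
        module._current_fx_name = prev_fx_name


def call_hook(module: ToyModule, target: Any, hook_name: str, *args: Any, **kwargs: Any) -> Any:
    fn = getattr(target, hook_name, None)
    if not callable(fn):
        return None
    with _fx_context(module, hook_name):
        return fn(*args, **kwargs)


class ToyCallbackTarget:
    def on_train_start(self) -> None:
        pass


module = ToyModule()
call_hook(module, ToyCallbackTarget(), "on_train_start")
assert module._current_fx_name is None  # restored after the call
```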
17 changes: 9 additions & 8 deletions tests/trainer/logging_/test_logger_connector.py
@@ -31,9 +31,9 @@
 
 
 def test_fx_validator(tmpdir):
-    funcs_name = sorted(get_members(Callback))
+    funcs_name = get_members(Callback)
 
-    callbacks_func = [
+    callbacks_func = {
         "on_before_backward",
         "on_after_backward",
         "on_before_optimizer_step",
@@ -82,9 +82,9 @@ def test_fx_validator(tmpdir):
         "on_predict_start",
         "setup",
         "teardown",
-    ]
+    }
 
-    not_supported = [
+    not_supported = {
         "on_before_accelerator_backend_setup",
         "on_fit_end",
         "on_fit_start",
@@ -110,11 +110,10 @@ def test_fx_validator(tmpdir):
         "on_validation_end",
         "setup",
         "teardown",
-    ]
+    }
 
-    assert funcs_name == sorted(
-        callbacks_func
-    ), "Detected new callback function. Need to add its logging permission to FxValidator and update this test"
+    # Detected new callback function. Need to add its logging permission to FxValidator and update this test
+    assert funcs_name == callbacks_func
 
     validator = _FxValidator()

@@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir):
         "prepare_data": "You can't",
         "configure_callbacks": "You can't",
         "on_validation_model_eval": "You can't",
+        "on_validation_model_train": "You can't",
         "summarize": "not managed by the `Trainer",
     }
     model = HookedModel(not_supported)
@@ -260,6 +260,7 @@ def test_fx_validator_integration(tmpdir):
             "on_test_dataloader": "You can't",
             "test_dataloader": "You can't",
             "on_test_model_eval": "You can't",
+            "on_test_model_train": "You can't",
             "on_test_end": "You can't",
         }
     )
3 changes: 3 additions & 0 deletions tests/trainer/logging_/test_loop_logging.py
@@ -50,9 +50,12 @@ def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs):
     trainer.state.stage = RunningStage.TRAINING
     hooks = [
         "on_before_backward",
+        "backward",
         "on_after_backward",
         "on_before_optimizer_step",
+        "optimizer_step",
         "on_before_zero_grad",
+        "optimizer_zero_grad",
         "training_step",
         "training_step_end",
         "on_batch_start",