Callback for logging forward, backward and update time #19928
Unanswered
MattMcPartlon asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 1 comment
-
Thanks for the implementation. I changed it a bit and here is my implementation:

```python
import time
from typing import Any

from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT

# Label value that does not contribute to the loss (assumption: -100, the usual ignore_index).
NO_LOSS_INDEX = -100


class LogPerformanceCallback(Callback):
    def __init__(self):
        super().__init__()

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx):
        self.batch_start = time.perf_counter()

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ):
        # Note: this spans the whole training step for the batch (forward + backward,
        # plus the optimizer step when not accumulating), not just the forward pass.
        batch_time = time.perf_counter() - self.batch_start
        pl_module.log(
            "train/forward_time_seconds",
            batch_time,
            on_step=True,
            on_epoch=False,
            rank_zero_only=True,  # no need to log for all devices
        )
        # Also log the number of tokens processed per second which contribute to the loss
        num_loss_tokens = (batch["labels"] != NO_LOSS_INDEX).sum().item()
        pl_module.log(
            "train/tps_per_device",
            num_loss_tokens / batch_time,
            prog_bar=True,  # show in progress bar
            on_step=True,
            on_epoch=False,
            rank_zero_only=True,  # no need to log for all devices
        )

    def on_before_backward(self, trainer, pl_module, loss):
        self.backward_start = time.perf_counter()

    def on_after_backward(self, trainer, pl_module):
        backward_time = time.perf_counter() - self.backward_start
        pl_module.log(
            "train/backward_time_seconds",
            backward_time,
            on_step=True,
            on_epoch=False,
            rank_zero_only=True,  # no need to log for all devices
        )

    def on_validation_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx, dataloader_idx=0):
        self.val_batch_start = time.perf_counter()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx=0,
    ):
        batch_time = time.perf_counter() - self.val_batch_start
        pl_module.log(
            "validation/forward_time_seconds",
            batch_time,
            on_step=True,
            on_epoch=False,
            rank_zero_only=True,  # no need to log for all devices
        )
        # Also log the number of tokens processed per second which contribute to the loss
        num_loss_tokens = (batch["labels"] != NO_LOSS_INDEX).sum().item()
        pl_module.log(
            "validation/tps_per_device",
            num_loss_tokens / batch_time,
            on_step=True,
            on_epoch=False,
            rank_zero_only=True,  # no need to log for all devices
        )

    def on_before_optimizer_step(self, trainer: Trainer, pl_module: LightningModule, optimizer: Any) -> None:
        self.step_start = time.perf_counter()

    def on_before_zero_grad(self, trainer: Trainer, pl_module: LightningModule, optimizer: Any) -> None:
        # This gets called at the beginning of training to clear any gradients from tuning etc.
        # In those cases step_start is not set, so we do nothing.
        if not hasattr(self, "step_start"):
            return
        step_time = time.perf_counter() - self.step_start
        pl_module.log(
            "train/step_time_seconds",
            step_time,
            on_step=True,
            on_epoch=False,
            rank_zero_only=True,  # no need to log for all devices
        )
```

I have not tested it a lot, but it seems to work for me. There might be a bug w.r.t. …
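For completeness, a minimal usage sketch (not from the thread) of attaching the callback to a Trainer; `model` and `dm` are placeholders for your own LightningModule and DataModule.

```python
import lightning.pytorch as pl

# Attach the timing callback; the logged metrics go to whatever logger the Trainer uses.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[LogPerformanceCallback()],
)
# trainer.fit(model, datamodule=dm)  # placeholders: substitute your LightningModule / DataModule
```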
-
I'm trying to track forward/backward/update time with a Callback. My current implementation is showing strange behavior.
It seems that the callback order is (at least functionally) different when using gradient accumulation != 1. This is expected, but it's unclear how to handle both cases with a single callback.
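For reference, the setting in question is just the Trainer flag below (value of 4 assumed for illustration); with accumulation, the backward hooks still fire on every batch, while the optimizer-step hooks fire only on accumulation boundaries.

```python
from lightning.pytorch import Trainer

# With accumulate_grad_batches=4, on_before_backward/on_after_backward run on every batch,
# but on_before_optimizer_step and on_before_zero_grad run only on every 4th batch.
trainer = Trainer(accumulate_grad_batches=4)
```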
My ask
I'd really appreciate help coming up with an almost-correct implementation for tracking:
1. forward pass time,
2. backward pass time,
3. total time for an update (forward + backward + optimizer step), which may depend on gradient accumulation, and
4. the amount of time spent waiting on the dataloader to generate the next batch.

Alternatively, for (3) I'm happy to track only optimizer.step time, since this should tell me how long it takes for devices to sync and gradients to update. I'm open to tracking related metrics or other metrics entirely, as long as they're correlated with model throughput/performance.
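For (4), a rough sketch of one way to approximate dataloader wait time; the callback name and metric key are made up for illustration, and the measured gap also includes some training-loop overhead.

```python
import time

from lightning.pytorch import Callback, LightningModule, Trainer


class DataWaitCallback(Callback):
    """Hypothetical helper: logs the gap between finishing one batch and receiving the next."""

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx):
        # Mark the moment work on this batch finished.
        self._last_batch_end = time.perf_counter()

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx):
        # The gap since the previous batch ended ~= dataloader wait + loop overhead.
        if hasattr(self, "_last_batch_end"):
            wait = time.perf_counter() - self._last_batch_end
            pl_module.log("train/data_wait_seconds", wait, on_step=True, on_epoch=False, rank_zero_only=True)
```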
In addition, I'm wondering how these metrics should be logged. For example, should I set sync_dist=False since I only care about logging these metrics for training? Should I remove the rank_zero_only decorators? Any input is greatly appreciated.
Thank you!
NOTE: I already know that my implementation is not correct :).
How I'm currently implementing this
- last updates per second: measured as one divided by the time between consecutive calls to on_train_batch_end.
- average updates per second: measured as the number of calls to on_train_batch_end in the current epoch divided by the elapsed time since on_train_epoch_start.
- forward time: the difference in time between on_train_batch_start and on_before_backward.
- backward time: the difference in time between on_before_backward and on_after_backward.
- between-step time: the difference in time between on_train_batch_end and the next on_train_batch_start (meant to capture time spent waiting on the dataloader to generate the next example). I realize there is other overhead getting tracked here, but I couldn't figure out a better way.

This is what the metrics look like in WandB:
Note: both runs use gradient accumulation with a value of 4.
Here is the implementation