diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 3fb49e7d4b1a0..14ec50c489a86 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -68,8 +68,9 @@ def on_run_start( # type: ignore[override] void(dataloader_iter, dataloader_idx) self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - self._seen_batch_indices = self._get_batch_indices(dataloader_idx) self.return_predictions = return_predictions + # this call requires that `self.return_predictions` is set + self._seen_batch_indices = self._get_batch_indices(dataloader_idx) def advance( # type: ignore[override] self,