-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
Copy pathmodule.py
1622 lines (1321 loc) · 68.9 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The LightningModule - an nn.Module with many additional features."""
import logging
import numbers
import weakref
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Union,
cast,
overload,
)
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric, MetricCollection
from typing_extensions import Self, override
import lightning.fabric as lf
import lightning.pytorch as pl
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.core.saving import _load_from_checkpoint
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
STEP_OUTPUT,
LRSchedulerPLType,
LRSchedulerTypeUnion,
OptimizerLRScheduler,
)
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
# Requirement check for the optional `onnx` package (presumably truthy only when it is installed — confirm
# against `RequirementCache`).
_ONNX_AVAILABLE = RequirementCache("onnx")
# Module-level warning cache and logger shared by the code in this file.
warning_cache = WarningCache()
log = logging.getLogger(__name__)
# Return annotation of `LightningModule.optimizers()`: either a single optimizer (raw, Lightning-wrapped or
# Fabric-wrapped) or a list holding one of those kinds.
MODULE_OPTIMIZERS = Union[
    Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer]
]
class LightningModule(
_DeviceDtypeModuleMixin,
HyperparametersMixin,
ModelHooks,
DataHooks,
CheckpointHooks,
Module,
):
# Below is for property support of JIT
# since none of these are important when using JIT, we are going to ignore them.
# The list names the properties defined on this class and is extended with the lists declared by the two mixins.
__jit_unused_properties__: list[str] = (
    [
        "example_input_array",
        "on_gpu",
        "current_epoch",
        "global_step",
        "global_rank",
        "local_rank",
        "logger",
        "loggers",
        "automatic_optimization",
        "trainer",
        "fabric",
        "strict_loading",
        "device_mesh",
    ]
    + _DeviceDtypeModuleMixin.__jit_unused_properties__
    + HyperparametersMixin.__jit_unused_properties__
)
# Class-level default. While this is truthy, the `trainer` property does not raise when no trainer is attached.
_jit_is_scripting = False
# String constants for hyperparameter-related checkpoint entries (presumably used as checkpoint dict keys by the
# saving/loading code, which is outside this view — confirm).
CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the parent classes and the Lightning-specific bookkeeping attributes.

    Args:
        *args: Positional arguments forwarded unchanged to ``super().__init__``.
        **kwargs: Keyword arguments forwarded unchanged to ``super().__init__``.

    """
    super().__init__(*args, **kwargs)

    # pointer to the trainer object
    self._trainer: Optional[pl.Trainer] = None

    # attributes that can be set by user
    self._example_input_array: Optional[Union[Tensor, tuple, dict]] = None
    self._automatic_optimization: bool = True
    # `None` means "not set by the user"; the `strict_loading` property reports it as `True`
    self._strict_loading: Optional[bool] = None

    # attributes used internally
    self._current_fx_name: Optional[str] = None
    self._param_requires_grad_state: dict[str, bool] = {}
    # filled lazily by `log()`: maps `id(metric)` to the attribute name of each `Metric` submodule
    self._metric_attributes: Optional[dict[int, str]] = None
    self._compiler_ctx: Optional[dict[str, Any]] = None

    # attributes only used when using fabric
    self._fabric: Optional[lf.Fabric] = None
    self._fabric_optimizers: list[_FabricOptimizer] = []

    # access to device mesh in `configure_model()` hook
    self._device_mesh: Optional[DeviceMesh] = None
@overload
def optimizers(
    self, use_pl_optimizer: Literal[True] = True
) -> Union[LightningOptimizer, list[LightningOptimizer]]: ...

@overload
def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, list[Optimizer]]: ...

@overload
def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ...

def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
    """Return the optimizer(s) used during training. Mainly useful for manual optimization.

    When a Fabric object is attached, the Fabric optimizers stored on this module are used and
    ``use_pl_optimizer`` has no effect.

    Args:
        use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a
            :class:`~lightning.pytorch.core.optimizer.LightningOptimizer` for automatic handling of precision,
            profiling, and counting of step calls for proper logging and checkpointing. It specifically wraps the
            ``step`` method and custom optimizers that don't have this method are not supported.

    Returns:
        A single optimizer, or a list of optimizers in case multiple ones are present.

    """
    if self._fabric:
        found: MODULE_OPTIMIZERS = self._fabric_optimizers
    else:
        trainer = self.trainer
        found = trainer.strategy._lightning_optimizers if use_pl_optimizer else trainer.optimizers

    # A one-element list of a known optimizer type is unwrapped; anything else is handed back untouched.
    optimizer_types = (Optimizer, LightningOptimizer, _FabricOptimizer)
    holds_single_optimizer = isinstance(found, list) and len(found) == 1 and isinstance(found[0], optimizer_types)
    return found[0] if holds_single_optimizer else found
def lr_schedulers(self) -> Union[None, list[LRSchedulerPLType], LRSchedulerPLType]:
    """Return the learning rate scheduler(s) used during training. Mainly useful for manual optimization.

    Returns:
        A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no
        schedulers were returned in :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`.

    """
    configs = self.trainer.lr_scheduler_configs
    if not configs:
        return None

    # Only the scheduler object of each config is exposed; keys such as "interval" or "frequency" are dropped.
    schedulers: list[LRSchedulerPLType] = [config.scheduler for config in configs]
    return schedulers[0] if len(schedulers) == 1 else schedulers
@property
def trainer(self) -> "pl.Trainer":
    """The object driving this module.

    With a Fabric object attached, a ``_TrainerFabricShim`` built around it is returned instead of a real trainer.

    Raises:
        RuntimeError: If neither a Fabric object nor a trainer is attached and the module is not being scripted.

    """
    fabric = self._fabric
    if fabric is not None:
        return _TrainerFabricShim(fabric=fabric)  # type: ignore[return-value]

    detached = self._trainer is None
    if detached and not self._jit_is_scripting:
        raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
    return self._trainer  # type: ignore[return-value]

@trainer.setter
def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
    """Store ``trainer`` on this module after assigning it to every direct ``LightningModule`` child."""
    nested_modules = (child for child in self.children() if isinstance(child, LightningModule))
    for nested in nested_modules:
        # assigning through the property makes each child repeat this for its own children
        nested.trainer = trainer  # type: ignore[assignment]
    self._trainer = trainer
@property
def fabric(self) -> Optional["lf.Fabric"]:
    """The Fabric object attached to this module, or ``None``."""
    return self._fabric

@fabric.setter
def fabric(self, fabric: Optional["lf.Fabric"]) -> None:
    """Attach ``fabric`` to every direct ``LightningModule`` child and then to this module.

    Children receive the value exactly as passed in. This module stores a ``weakref.proxy`` of it unless the value
    is ``None`` or already a weak proxy.

    """
    nested_modules = (child for child in self.children() if isinstance(child, LightningModule))
    for nested in nested_modules:
        nested.fabric = fabric

    needs_proxy = fabric is not None and not isinstance(fabric, weakref.ProxyTypes)
    self._fabric = weakref.proxy(fabric) if needs_proxy else fabric
@property
def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]:
    """The example input array is a specification of what the module can consume in the :meth:`forward` method. The
    return type is interpreted as follows:

    -   Single tensor: It is assumed the model takes a single argument, i.e.,
        ``model.forward(model.example_input_array)``
    -   Tuple: The input array should be interpreted as a sequence of positional arguments, i.e.,
        ``model.forward(*model.example_input_array)``
    -   Dict: The input array represents named keyword arguments, i.e.,
        ``model.forward(**model.example_input_array)``

    """
    return self._example_input_array

@example_input_array.setter
def example_input_array(self, example: Optional[Union[Tensor, tuple, dict]]) -> None:
    """Store ``example`` as the example input; no validation or copying is done here."""
    self._example_input_array = example
@property
def current_epoch(self) -> int:
    """The current epoch in the ``Trainer``, or 0 if not attached."""
    if not self._trainer:
        return 0
    return self.trainer.current_epoch
@property
def global_step(self) -> int:
    """Total training batches seen across all epochs.

    If no Trainer is attached, this property is 0.

    """
    if not self._trainer:
        return 0
    return self.trainer.global_step
@property
def global_rank(self) -> int:
    """The index of the current process across all nodes and devices, or 0 if no Trainer is attached."""
    if not self._trainer:
        return 0
    return self.trainer.global_rank
@property
def local_rank(self) -> int:
    """The index of the current process within a single node, or 0 if no Trainer is attached."""
    if not self._trainer:
        return 0
    return self.trainer.local_rank
@property
def on_gpu(self) -> bool:
    """Returns ``True`` if this model is currently located on a GPU.

    Useful to set flags around the LightningModule for different CPU vs GPU behavior.

    """
    # only a device of type "cuda" counts as a GPU here
    return self.device.type == "cuda"
@property
def automatic_optimization(self) -> bool:
    """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
    return self._automatic_optimization

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
    """Store the flag as given; it is initialised to ``True`` in ``__init__``."""
    self._automatic_optimization = automatic_optimization
@property
def strict_loading(self) -> bool:
    """Determines how Lightning loads this model using `.load_state_dict(..., strict=model.strict_loading)`.

    Returns ``True`` both when the value was never set (stored as ``None``) and when it was set to ``True``.

    """
    # We use None as the default internally to determine whether the user has set a value
    return self._strict_loading in (None, True)

@strict_loading.setter
def strict_loading(self, strict_loading: bool) -> None:
    """Record the user's choice, replacing the internal ``None`` default."""
    self._strict_loading = strict_loading
@property
def logger(self) -> Optional[Union[Logger, FabricLogger]]:
    """Reference to the logger object of the attached Fabric or Trainer.

    Fabric takes precedence when both are set. ``None`` is returned when neither is attached.

    """
    fabric = self._fabric
    if fabric is not None:
        return fabric.logger

    trainer = self._trainer
    if trainer is None:
        return None
    return trainer.logger
@property
def loggers(self) -> Union[list[Logger], list[FabricLogger]]:
    """Reference to the list of loggers of the attached Fabric or Trainer.

    Fabric takes precedence when both are set. An empty list is returned when neither is attached.

    """
    fabric = self._fabric
    if fabric is not None:
        return fabric.loggers

    trainer = self._trainer
    return [] if trainer is None else trainer.loggers
@property
def device_mesh(self) -> Optional["DeviceMesh"]:
    """Strategies like ``ModelParallelStrategy`` will create a device mesh that can be accessed in the
    :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook to parallelize the LightningModule.

    The value is ``None`` until something assigns ``_device_mesh``; this class only initialises it to ``None``.

    """
    return self._device_mesh
def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
    """Invoke the batch hook ``hook_name`` with ``args`` and return its result.

    Without a trainer the hook is looked up on this module and called directly. With a trainer, the trainer's
    data-hook selector decides which object owns the hook, and the call goes through the matching helper in
    ``lightning.pytorch.trainer.call``.

    """
    trainer = self._trainer
    if not trainer:
        return getattr(self, hook_name)(*args)

    selector = trainer._data_connector._datahook_selector
    assert selector is not None
    hook_owner = selector.get_instance(hook_name)
    # an owner of this module's class goes through the module helper, anything else through the datamodule helper
    dispatch = (
        call._call_lightning_module_hook
        if isinstance(hook_owner, self.__class__)
        else call._call_lightning_datamodule_hook
    )
    return dispatch(trainer, hook_name, *args)
def _on_before_batch_transfer(self, batch: Any, dataloader_idx: int = 0) -> Any:
    """Run the ``on_before_batch_transfer`` hook on ``batch`` via ``_call_batch_hook`` and return its result."""
    return self._call_batch_hook("on_before_batch_transfer", batch, dataloader_idx)
def _apply_batch_transfer_handler(
    self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
) -> Any:
    """Run ``transfer_batch_to_device`` followed by ``on_after_batch_transfer`` on ``batch``.

    Args:
        batch: The batch handed to the hooks.
        device: Target device; when falsy, this module's own ``device`` is used.
        dataloader_idx: Index forwarded to both hooks.

    Returns:
        Whatever the ``on_after_batch_transfer`` hook returns.

    """
    target_device = device or self.device
    transferred = self._call_batch_hook("transfer_batch_to_device", batch, target_device, dataloader_idx)
    return self._call_batch_hook("on_after_batch_transfer", transferred, dataloader_idx)
def print(self, *args: Any, **kwargs: Any) -> None:
    r"""Prints only from process 0. Use this in any distributed mode to log only once.

    The output goes through the trainer's progress bar callback when one exists and is enabled, otherwise
    through Python's built-in ``print``.

    Args:
        *args: The thing to print. The same as for Python's built-in print function.
        **kwargs: The same as for Python's built-in print function.

    Example::

        def forward(self, x):
            self.print(x, 'in forward')

    """
    trainer = self.trainer
    if not trainer.is_global_zero:
        return

    bar = trainer.progress_bar_callback
    if bar is None or not bar.is_enabled:
        print(*args, **kwargs)
    else:
        bar.print(*args, **kwargs)
def log(
    self,
    name: str,
    value: _METRIC,
    prog_bar: bool = False,
    logger: Optional[bool] = None,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable] = "mean",
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    metric_attribute: Optional[str] = None,
    rank_zero_only: bool = False,
) -> None:
    """Log a key, value pair.

    Example::

        self.log('train_loss', loss)

    The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.

    When a Fabric object is attached, only ``name``, ``value`` and ``logger`` are used and the call is forwarded
    to Fabric. Without a trainer, or with ``Trainer(barebones=True)``, a warning is emitted and nothing is logged.

    Args:
        name: key to log. Must be identical across all processes if using DDP or any other distributed strategy.
        value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
        prog_bar: if ``True`` logs to the progress bar.
        logger: if ``True`` logs to the logger. ``None`` is treated as ``True`` on the Trainer path.
        on_step: if ``True`` logs at this step. The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        on_epoch: if ``True`` logs epoch accumulated metrics. The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
        enable_graph: if ``True``, will not auto detach the graph.
        sync_dist: if ``True``, reduces the metric across devices. Use with care as this may lead to a significant
            communication overhead.
        sync_dist_group: the DDP group to sync across.
        add_dataloader_idx: if ``True``, appends the index of the current dataloader to
            the name (when using multiple dataloaders). If False, user needs to give unique names for
            each dataloader to not mix the values.
        batch_size: Current batch_size. This will be directly inferred from the loaded batch,
            but for some data structures you might need to explicitly provide it.
        metric_attribute: To restore the metric state, Lightning requires the reference of the
            :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
        rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
            rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
            (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
            :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.

    Raises:
        ValueError: If ``value`` contains a nested dictionary, a value of an unsupported type, or a tensor that
            does not have exactly one element.
        MisconfigurationException: If the loop's result collection is not registered, the call is not managed by
            the Trainer control flow, ``name`` contains ``"/dataloader_idx_"``, a logged ``Metric`` cannot be
            matched to an attribute of this module, or ``training_step`` takes ``dataloader_iter`` and no
            ``batch_size`` is given while training.

    """
    # Fabric path: hand the single pair over to the Fabric logging helper and stop.
    if self._fabric is not None:
        self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
        return

    # check for invalid values
    apply_to_collection(value, dict, self.__check_not_nested, name)
    apply_to_collection(
        value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor)
    )

    trainer = self._trainer
    if trainer is None:
        # not an error to support testing the `*_step` methods without a `Trainer` reference
        rank_zero_warn(
            "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
            " This is most likely because the model hasn't been passed to the `Trainer`"
        )
        return
    if trainer.barebones:
        rank_zero_warn(
            "You are trying to `self.log()` but `Trainer(barebones=True)` is configured."
            " Logging can impact raw speed so it is disabled under this setting."
        )
        return
    results = trainer._results
    if results is None:
        raise MisconfigurationException(
            "You are trying to `self.log()` but the loop's result collection is not registered"
            " yet. This is most likely because you are trying to log in a `predict` hook,"
            " but it doesn't support logging"
        )
    if self._current_fx_name is None:
        raise MisconfigurationException(
            "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
        )

    # resolve the `None` defaults of `on_step` / `on_epoch` for the hook that is currently running
    on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
        self._current_fx_name, on_step=on_step, on_epoch=on_epoch
    )

    # make sure user doesn't introduce logic for multi-dataloaders
    if "/dataloader_idx_" in name:
        raise MisconfigurationException(
            f"You called `self.log` with the key `{name}`"
            " but it should not contain information about `dataloader_idx`"
        )

    # tensors and plain numbers are converted by `__to_tensor`; `Metric` objects pass through unchanged
    value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)

    if trainer._logger_connector.should_reset_tensors(self._current_fx_name):
        # if we started a new epoch (running its first batch) the hook name has changed
        # reset any tensors for the new hook name
        results.reset(metrics=False, fx=self._current_fx_name)

    if metric_attribute is None and isinstance(value, Metric):
        if self._metric_attributes is None:
            # compute once
            # (the comprehension's `name` is scoped to the comprehension and does not overwrite the argument)
            self._metric_attributes = {
                id(module): name for name, module in self.named_modules() if isinstance(module, Metric)
            }
            if not self._metric_attributes:
                raise MisconfigurationException(
                    "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                    " You can fix this by setting an attribute for the metric in your `LightningModule`."
                )
        # try to find the passed metric in the LightningModule
        metric_attribute = self._metric_attributes.get(id(value), None)
        if metric_attribute is None:
            raise MisconfigurationException(
                "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
                f" of {list(self._metric_attributes.values())}"
            )

    if (
        trainer.training
        and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
        and batch_size is None
    ):
        raise MisconfigurationException(
            "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
        )

    # only an explicit `logger=True` triggers this warning; the `None` default is resolved below
    if logger and trainer.logger is None:
        rank_zero_warn(
            f"You called `self.log({name!r}, ..., logger=True)` but have no logger configured. You can enable one"
            " by doing `Trainer(logger=ALogger(...))`"
        )
    if logger is None:
        # we could set false here if there's no configured logger, however, we still need to compute the "logged"
        # metrics anyway because that's what the evaluation loops use as return value
        logger = True

    results.log(
        self._current_fx_name,
        name,
        value,
        prog_bar=prog_bar,
        logger=logger,
        on_step=on_step,
        on_epoch=on_epoch,
        reduce_fx=reduce_fx,
        enable_graph=enable_graph,
        add_dataloader_idx=add_dataloader_idx,
        batch_size=batch_size,
        # syncing is requested only when the user asked for it and the trainer reports a distributed setup
        sync_dist=sync_dist and trainer._accelerator_connector.is_distributed,
        sync_dist_fn=trainer.strategy.reduce,
        sync_dist_group=sync_dist_group,
        metric_attribute=metric_attribute,
        rank_zero_only=rank_zero_only,
    )

    # record on the logger connector which hook logged last
    trainer._logger_connector._current_fx = self._current_fx_name
def log_dict(
    self,
    dictionary: Union[Mapping[str, _METRIC], MetricCollection],
    prog_bar: bool = False,
    logger: Optional[bool] = None,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable] = "mean",
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    rank_zero_only: bool = False,
) -> None:
    """Log a dictionary of values at once.

    Every entry is passed to :meth:`log` with the same options. When a Fabric object is attached, only
    ``dictionary`` and ``logger`` are used and the call is forwarded to Fabric.

    Example::

        values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
        self.log_dict(values)

    Args:
        dictionary: key value pairs.
            Keys must be identical across all processes if using DDP or any other distributed strategy.
            The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
        prog_bar: if ``True`` logs to the progress bar.
        logger: if ``True`` logs to the logger.
        on_step: if ``True`` logs at this step.
            ``None`` auto-logs for training_step but not validation/test_step.
            The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        on_epoch: if ``True`` logs epoch accumulated metrics.
            ``None`` auto-logs for val/test step but not ``training_step``.
            The default value is determined by the hook.
            See :ref:`extensions/logging:Automatic Logging` for details.
        reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
        enable_graph: if ``True``, will not auto-detach the graph
        sync_dist: if ``True``, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant
            communication overhead.
        sync_dist_group: the ddp group to sync across.
        add_dataloader_idx: if ``True``, appends the index of the current dataloader to
            the name (when using multiple). If ``False``, user needs to give unique names for
            each dataloader to not mix values.
        batch_size: Current batch size. This will be directly inferred from the loaded batch,
            but some data structures might need to explicitly provide it.
        rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
            rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
            (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
            :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.

    """
    if self._fabric is not None:
        return self._log_dict_through_fabric(dictionary=dictionary, logger=logger)

    # A `MetricCollection` gets extra arguments for its `items()` call; a plain mapping gets none.
    items_kwargs: dict[str, bool] = {}
    if isinstance(dictionary, MetricCollection):
        items_kwargs["keep_base"] = False
        if _TORCHMETRICS_GREATER_EQUAL_0_9_1 and dictionary._enable_compute_groups:
            items_kwargs["copy_state"] = False

    # Options that are identical for every entry.
    shared_options: dict[str, Any] = {
        "prog_bar": prog_bar,
        "logger": logger,
        "on_step": on_step,
        "on_epoch": on_epoch,
        "reduce_fx": reduce_fx,
        "enable_graph": enable_graph,
        "sync_dist": sync_dist,
        "sync_dist_group": sync_dist_group,
        "add_dataloader_idx": add_dataloader_idx,
        "batch_size": batch_size,
        "rank_zero_only": rank_zero_only,
    }
    for key, metric in dictionary.items(**items_kwargs):
        self.log(name=key, value=metric, **shared_options)
    return None
def _log_dict_through_fabric(
    self, dictionary: Union[Mapping[str, _METRIC], MetricCollection], logger: Optional[bool] = None
) -> None:
    """Validate ``dictionary`` and forward it to the attached Fabric object's ``log_dict``.

    Args:
        dictionary: The metrics to log.
        logger: When exactly ``False``, nothing is validated or logged.

    Raises:
        ValueError: If any value is a dictionary, or contains something other than a number or a ``Tensor``.

    """
    if logger is False:
        # Passing `logger=False` with Fabric does not make much sense because there is no other destination to
        # log to, but we support it in case the original code was written for Trainer use
        return

    for entry in dictionary.values():
        if isinstance(entry, dict):
            raise ValueError(f"`self.log_dict({dictionary})` was called, but nested dictionaries cannot be logged")

    loggable_types = (numbers.Number, Tensor)
    for name, value in dictionary.items():
        # `__check_allowed` runs (and raises) for every object that is not of a loggable type
        apply_to_collection(value, object, self.__check_allowed, name, value, wrong_dtype=loggable_types)

    fabric = self._fabric
    assert fabric is not None
    fabric.log_dict(metrics=dictionary)  # type: ignore[arg-type]
@staticmethod
def __check_not_nested(value: dict, name: str) -> None:
    """Reject a dictionary that itself holds dictionaries.

    Args:
        value: The dictionary whose values are inspected.
        name: The metric name, used only in the error message.

    Raises:
        ValueError: If at least one value of ``value`` is a ``dict``.

    """
    # self-imposed restriction. for simplicity
    for inner in value.values():
        if isinstance(inner, dict):
            raise ValueError(f"`self.log({name}, {value})` was called, but nested dictionaries cannot be logged")
@staticmethod
def __check_allowed(v: Any, name: str, value: Any) -> None:
    """Always raise: callers in this class invoke this only for objects that are not of a loggable type.

    Args:
        v: The offending object; its type name appears in the message.
        name: The metric name, used only in the message.
        value: The complete value passed to the logging call, used only in the message.

    Raises:
        ValueError: Unconditionally.

    """
    raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
    """Convert a logged value into a detached, squeezed single-element tensor.

    Args:
        value: A tensor, which is cloned and detached, or a number, which becomes a new tensor on this module's
            device with the dtype returned by ``_get_default_dtype()``.
        name: The metric name, used only in the error message.

    Returns:
        The converted tensor after ``squeeze()``.

    Raises:
        ValueError: If the tensor does not contain exactly one element.

    """
    if isinstance(value, Tensor):
        value = value.clone().detach()
    else:
        value = torch.tensor(value, device=self.device, dtype=_get_default_dtype())

    if torch.numel(value) != 1:
        raise ValueError(
            f"`self.log({name}, {value})` was called, but the tensor must have a single element."
            f" You can try doing `self.log({name}, {value}.mean())`"
        )
    return value.squeeze()
def all_gather(
    self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False
) -> Union[Tensor, dict, list, tuple]:
    r"""Gather tensors or collections of tensors from multiple processes.

    This method needs to be called on all processes and the tensors need to have the same shape across all
    processes, otherwise your program will stall forever.

    Args:
        data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather operation

    Return:
        A tensor of shape (world_size, batch, ...), or if the input was a collection
        the output will also be a collection with tensors of this shape. For the special case where
        world_size is 1, no additional dimension is added to the tensor(s).

    """
    if group is None:
        group = torch.distributed.group.WORLD

    gather_fn = self.trainer.strategy.all_gather
    # the input is first converted to tensors on this module's device, then each tensor is gathered
    tensors = convert_to_tensors(data, device=self.device)
    return apply_to_collection(tensors, Tensor, gather_fn, group=group, sync_grads=sync_grads)
@override
def forward(self, *args: Any, **kwargs: Any) -> Any:
    r"""Same as :meth:`torch.nn.Module.forward`.

    This implementation only delegates to the parent class's ``forward``.

    Args:
        *args: Whatever you decide to pass into the forward method.
        **kwargs: Keyword arguments are also possible.

    Return:
        Your model's output

    """
    return super().forward(*args, **kwargs)
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
    logger.

    This base implementation only emits a warning that the method must be implemented and returns nothing.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
          automatic optimization.
        - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
          multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
          the loss is not required.

    In this step you'd normally do the forward pass and calculate the loss for a batch.
    You can also do fancier things like multiple forward passes or something model specific.

    Example::

        def training_step(self, batch, batch_idx):
            x, y, z = batch
            out = self.encoder(x)
            loss = self.loss(out, x)
            return loss

    To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

    .. code-block:: python

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False


        # Multiple optimizers (e.g.: GANs)
        def training_step(self, batch, batch_idx):
            opt1, opt2 = self.optimizers()

            # do training_step with encoder
            ...
            opt1.step()
            # do training_step with decoder
            ...
            opt2.step()

    Note:
        When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
        normalized by ``accumulate_grad_batches`` internally.

    """
    rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Operates on a single batch of data from the validation set. In this step you might generate examples or
    calculate anything of interest like accuracy.

    This base implementation has no body and returns nothing.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
        - ``None`` - Skip to the next batch.

    .. code-block:: python

        # if you have one val dataloader:
        def validation_step(self, batch, batch_idx): ...


        # if you have multiple val dataloaders:
        def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

    Examples::

        # CASE 1: A single validation dataset
        def validation_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'val_loss': loss, 'val_acc': val_acc})

    If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
    setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

    .. code-block:: python

        # CASE 2: multiple validation dataloaders
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        If you don't need to validate you don't need to implement this method.

    Note:
        When the :meth:`validation_step` is called, the model has been put in eval mode
        and PyTorch gradients have been disabled. At the end of validation,
        the model goes back to training mode and gradients are enabled.

    """
def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Process one batch of data drawn from the test set.

    This is the place where you would normally generate examples or compute anything of interest, such as
    accuracy.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch
            (only if multiple dataloaders are used).

    Return:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
        - ``None`` - Skip to the next batch.

    .. code-block:: python

        # with a single test dataloader:
        def test_step(self, batch, batch_idx): ...


        # with multiple test dataloaders:
        def test_step(self, batch, batch_idx, dataloader_idx=0): ...

    Examples::

        # CASE 1: A single test dataset
        def test_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'test_loss': loss, 'test_acc': test_acc})

    When several test dataloaders are passed in, :meth:`test_step` receives one additional argument.
    Giving it a default value of 0 is recommended, so that you can quickly switch between single and
    multiple dataloaders.

    .. code-block:: python

        # CASE 2: multiple test dataloaders
        def test_step(self, batch, batch_idx, dataloader_idx=0):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        If you don't need to test you don't need to implement this method.

    Note:
        By the time :meth:`test_step` is called, the model has been put in eval mode and
        PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
        to training mode and gradients are enabled.

    """
def predict_step(self, *args: Any, **kwargs: Any) -> Any:
    """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
    :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.

    The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
    to scale inference on multi-devices.

    To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
    callback to write the predictions to disk or database after each batch or on epoch end.

    The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
    based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
    or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.

    The default implementation takes the batch from the ``batch`` keyword argument when it is present and
    otherwise from the first positional argument, then returns ``self(batch)``.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: The index of the dataloader that produced this batch.
            (only if multiple dataloaders used)

    Return:
        Predicted output (optional).

    Raises:
        IndexError: If the batch is passed neither as the ``batch`` keyword nor as the first positional argument.

    Example ::

        class MyModel(LightningModule):

            def predict_step(self, batch, batch_idx, dataloader_idx=0):
                return self(batch)

        dm = ...
        model = MyModel()
        trainer = Trainer(accelerator="gpu", devices=2)
        predictions = trainer.predict(model, dm)

    """
    # For backwards compatibility the batch may be given by keyword or positionally.
    # Index ``args`` only when the keyword is missing: ``kwargs.get("batch", args[0])``
    # evaluates ``args[0]`` eagerly and fails with ``IndexError`` on ``predict_step(batch=...)``.
    batch = kwargs["batch"] if "batch" in kwargs else args[0]
    return self(batch)
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
    """Return the callbacks that belong to this model.

    When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, the list or single callback
    returned here is merged with the list of callbacks passed to the Trainer's ``callbacks`` argument. If a
    callback returned here has the same type as one or several callbacks already present in the Trainer's
    callbacks list, it takes priority and replaces them. In addition, Lightning will make sure
    :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.

    Return:
        A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

    Example::

        def configure_callbacks(self):
            early_stop = EarlyStopping(monitor="val_acc", mode="max")
            checkpoint = ModelCheckpoint(monitor="val_loss")
            return [early_stop, checkpoint]

    """
    # The base implementation contributes no callbacks of its own.
    model_callbacks: list = []
    return model_callbacks
def configure_optimizers(self) -> OptimizerLRScheduler:
r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one.
But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in
the manual optimization mode.
Return:
Any of these 5 options.
- **Single optimizer**.
- **List or Tuple** of optimizers.
- **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers
(or multiple ``lr_scheduler_config``).
- **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"``
key whose value is a single LR scheduler or ``lr_scheduler_config``.
- **None** - Fit will run without any optimizer.
The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration.
The default configuration is shown below.
.. code-block:: python
lr_scheduler_config = {
# REQUIRED: The scheduler instance
"scheduler": lr_scheduler,
# The unit of the scheduler's step size, could also be 'step'.
# 'epoch' updates the scheduler on epoch end whereas 'step'
# updates it after an optimizer update.
"interval": "epoch",
# How many epochs/steps should pass between calls to
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
"frequency": 1,
# Metric to monitor for schedulers like `ReduceLROnPlateau`
"monitor": "val_loss",
# If set to `True`, will enforce that the value specified by 'monitor'
# is available when the scheduler is updated, thus stopping
# training if not found. If set to `False`, it will only produce a warning
"strict": True,
# If using the `LearningRateMonitor` callback to monitor the
# learning rate progress, this keyword can be used to specify
# a custom logged name
"name": None,
}
When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the
:class:`torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the
``lr_scheduler_config`` contains the keyword ``"monitor"`` set to the metric name that the scheduler
should be conditioned on.
.. testcode::