@@ -25,6 +25,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache
 
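Aside on the new import: `_is_dtensor` is a private helper from `lightning.fabric.utilities.distributed`. A minimal sketch of the check it plausibly performs, assuming a torch build that ships DTensor (the function name below is illustrative, not the actual implementation, which also guards on the installed torch version):

    import torch
    from torch.distributed._tensor import DTensor

    def is_dtensor(tensor: torch.Tensor) -> bool:
        # A DTensor wraps per-rank local shards but still reports a global
        # shape, so its parameters can be counted like ordinary tensors.
        return isinstance(tensor, DTensor)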
@@ -135,7 +136,7 @@ def layer_type(self) -> str:
     @property
     def num_parameters(self) -> int:
         """Returns the number of parameters in this module."""
-        return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
 
     @property
     def training(self) -> bool:
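The change from `math.prod(p.shape)` to `p.numel()` is behavior-preserving for materialized parameters, since `Tensor.numel()` returns the total number of elements, i.e. the product of the shape. A quick sanity check:

    import math
    import torch

    p = torch.nn.Parameter(torch.empty(3, 4, 5))
    # numel() counts elements: 3 * 4 * 5 = 60, the product of the shape.
    assert p.numel() == math.prod(p.shape) == 60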
@@ -264,13 +265,11 @@ def total_training_modes(self) -> Dict[str, int]:
 
     @property
     def total_parameters(self) -> int:
-        return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
 
     @property
     def trainable_parameters(self) -> int:
-        return sum(
-            p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
-        )
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)
 
     @property
     def total_layer_params(self) -> int:
@@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
     return f"{number:,.1f} {labels[index]}"
 
 
-def _is_lazy_weight_tensor(p: Tensor) -> bool:
+def _tensor_has_shape(p: Tensor) -> bool:
     from torch.nn.parameter import UninitializedParameter
 
-    if isinstance(p, UninitializedParameter):
+    # DTensor is a subtype of `UninitializedParameter`, but the shape is known
+    if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
         warning_cache.warn(
             "The total number of parameters detected may be inaccurate because the model contains"
             " an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"
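For context on the guard that `_tensor_has_shape` (formerly `_is_lazy_weight_tensor`) provides: lazy modules hold `UninitializedParameter` instances with no defined shape until the first forward pass, so counting their elements would fail, and they are skipped with the warning above instead. A small illustration using `torch.nn.LazyLinear` (an assumed example, not taken from this PR):

    import torch
    from torch.nn.parameter import UninitializedParameter

    lazy = torch.nn.LazyLinear(out_features=8)
    # Before the first forward pass the weight has no shape to count.
    assert isinstance(lazy.weight, UninitializedParameter)

    # A forward pass materializes parameters from the input's feature dim.
    lazy(torch.randn(2, 4))
    # weight (8, 4) plus bias (8,) -> 40 countable parameters
    assert sum(p.numel() for p in lazy.parameters()) == 40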