
Commit 345450b

Fix parameter count in ModelSummary when parameters are DTensors (#20163)
1 parent 3de60f4 commit 345450b
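In short: `ModelSummary` treated DTensor parameters like uninitialized (lazy) weights and counted them as 0, because `DTensor` is a subtype of `UninitializedParameter`. After this change the summary utilities count any parameter with a known shape via `Tensor.numel()`, which for a DTensor reports the global (unsharded) size. A minimal usage sketch of the affected API, illustrative only and not part of the commit, with `BoringModel` from `lightning.pytorch.demos` assumed as a stand-in model:

from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.model_summary import ModelSummary

# Build a summary for any LightningModule; the totals below are the values this
# commit fixes for models whose parameters are DTensors.
model = BoringModel()
summary = ModelSummary(model, max_depth=1)
print(summary.total_parameters)      # global parameter count
print(summary.trainable_parameters)  # only parameters with requires_grad=True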

File tree: 6 files changed (+52 -13 lines)


src/lightning/fabric/utilities/distributed.py (+12 -1)
@@ -14,10 +14,11 @@
 from lightning_utilities.core.imports import package_available
 from torch import Tensor
 from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import Self, override
+from typing_extensions import Self, TypeGuard, override

 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.data import _num_cpus_available
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.rank_zero import rank_zero_info
 from lightning.fabric.utilities.types import _PATH, ReduceOp

@@ -30,6 +31,8 @@ class group:  # type: ignore


 if TYPE_CHECKING:
+    from torch.distributed._tensor import DTensor
+
     from lightning.fabric.plugins import ClusterEnvironment
     from lightning.fabric.strategies import Strategy

@@ -427,3 +430,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         self.barrier()
         if self.group is not None:
             torch.distributed.destroy_process_group(self.group)
+
+
+def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
+    if _TORCH_GREATER_EQUAL_2_4:
+        from torch.distributed._tensor import DTensor
+
+        return isinstance(tensor, DTensor)
+    return False
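The new `_is_dtensor` helper only attempts the `DTensor` import on PyTorch 2.4 or newer and otherwise returns `False`; the `TypeGuard["DTensor"]` return type lets static type checkers narrow the argument inside the check. A hedged usage sketch (illustrative, not from the commit), assuming torch >= 2.4:

import torch
from lightning.fabric.utilities.distributed import _is_dtensor

p = torch.nn.Parameter(torch.randn(4, 4))  # a plain, non-distributed parameter
if _is_dtensor(p):
    # Inside this branch type checkers treat `p` as a DTensor, so DTensor-only
    # attributes such as `p.device_mesh` are accepted.
    print(p.device_mesh)
else:
    print(p.numel())  # plain tensors keep the regular element count (16 here)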

src/lightning/pytorch/CHANGELOG.md (+1)
@@ -49,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))

+- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))


 ## [2.3.0] - 2024-06-13

src/lightning/pytorch/utilities/model_summary/model_summary.py (+7 -7)
@@ -25,6 +25,7 @@
 from torch.utils.hooks import RemovableHandle

 import lightning.pytorch as pl
+from lightning.fabric.utilities.distributed import _is_dtensor
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache

@@ -135,7 +136,7 @@ def layer_type(self) -> str:
     @property
     def num_parameters(self) -> int:
         """Returns the number of parameters in this module."""
-        return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

     @property
     def training(self) -> bool:
@@ -264,13 +265,11 @@ def total_training_modes(self) -> Dict[str, int]:

     @property
     def total_parameters(self) -> int:
-        return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

     @property
     def trainable_parameters(self) -> int:
-        return sum(
-            p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
-        )
+        return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)

     @property
     def total_layer_params(self) -> int:
@@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
     return f"{number:,.1f} {labels[index]}"


-def _is_lazy_weight_tensor(p: Tensor) -> bool:
+def _tensor_has_shape(p: Tensor) -> bool:
     from torch.nn.parameter import UninitializedParameter

-    if isinstance(p, UninitializedParameter):
+    # DTensor is a subtype of `UninitializedParameter`, but the shape is known
+    if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
         warning_cache.warn(
             "The total number of parameters detected may be inaccurate because the model contains"
             " an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"

src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py (+5 -5)
@@ -25,7 +25,7 @@
     NOT_APPLICABLE,
     LayerSummary,
     ModelSummary,
-    _is_lazy_weight_tensor,
+    _tensor_has_shape,
     get_human_readable_count,
 )

@@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
     @override
     def num_parameters(self) -> int:
         """Returns the number of parameters in this module."""
-        return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

     @property
     def average_shard_parameters(self) -> int:
@@ -49,7 +49,7 @@ def average_shard_parameters(self) -> int:
         def partitioned_size(p: Parameter) -> int:
             return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()

-        return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+        return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())


 class DeepSpeedSummary(ModelSummary):
@@ -71,13 +71,13 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]:  # type: ignore[override]
     @property
     @override
     def total_parameters(self) -> int:
-        return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+        return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

     @property
     @override
     def trainable_parameters(self) -> int:
         return sum(
-            deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
+            deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
             for p in self._model.parameters()
             if p.requires_grad
         )

tests/tests_fabric/utilities/test_distributed.py (+14)
@@ -3,7 +3,9 @@
 from functools import partial
 from pathlib import Path
 from unittest import mock
+from unittest.mock import Mock

+import lightning.fabric
 import pytest
 import torch
 from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
@@ -15,6 +17,7 @@
     _gather_all_tensors,
     _InfiniteBarrier,
     _init_dist_connection,
+    _is_dtensor,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
     _sync_ddp,
@@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.reset_mock()
     _init_dist_connection(LightningEnvironment(), "gloo")
     atexit_mock.register.assert_not_called()
+
+
+@RunIf(min_torch="2.4")
+def test_is_dtensor(monkeypatch):
+    from torch.distributed._tensor import DTensor
+
+    assert _is_dtensor(Mock(spec=DTensor))
+    assert not _is_dtensor(torch.zeros(2, 2))
+
+    monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
+    assert not _is_dtensor(Mock(spec=DTensor))

tests/tests_pytorch/utilities/test_model_summary.py (+13)
@@ -13,6 +13,7 @@
 # limitations under the License.
 from collections import OrderedDict
 from typing import Any
+from unittest import mock

 import pytest
 import torch
@@ -345,6 +346,18 @@ def test_lazy_model_summary():
     assert summary.trainable_parameters == 0


+@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True)
+def test_dtensor_model_summary(_):
+    """Test that the model summary can work with layers that have DTensor parameters."""
+    # We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors
+    # would require setting up distributed
+    dtensor_model = UnorderedModel()
+    summary = ModelSummary(dtensor_model)
+    assert summary.total_layer_params > 0
+    assert summary.total_parameters > 0
+    assert summary.trainable_parameters > 0
+
+
 @pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
 def test_max_depth_param(max_depth):
     """Test that only the modules up to the desired depth are shown."""
