Skip to content

Commit c57f619

Browse files
author
yicongd
committed
Merge branch 'master' of github.com:DuYicong515/pytorch-lightning
2 parents 33f6357 + bc463c6 commit c57f619

File tree

1 file changed

+5
-18
lines changed

1 file changed

+5
-18
lines changed

pytorch_lightning/utilities/memory.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,15 @@
1717
import os
1818
import shutil
1919
import subprocess
20-
from typing import Any, cast, Dict, IO
20+
from io import BytesIO
21+
from typing import Any, Dict
2122

2223
import torch
2324
from torch.nn import Module
2425

2526
from pytorch_lightning.utilities.apply_func import apply_to_collection
2627

2728

28-
class _ByteCounter:
29-
"""Accumulate and stores the total bytes of an object."""
30-
31-
def __init__(self) -> None:
32-
self.nbytes: int = 0
33-
34-
def write(self, data: bytes) -> None:
35-
"""Stores the total bytes of the data."""
36-
self.nbytes += len(data)
37-
38-
def flush(self) -> None:
39-
pass
40-
41-
4229
def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
4330
"""Detach all tensors in `in_dict`.
4431
def get_model_size_mb(model: Module) -> float:
    """Calculate the size of a ``Module`` in megabytes.

    The computation includes everything in the
    :meth:`~torch.nn.Module.state_dict`, i.e., by default the parameters and buffers.

    Args:
        model: The module whose serialized size should be measured.

    Returns:
        Number of megabytes occupied by the serialized ``state_dict`` of the input module.
    """
    # Serialize the state dict into an in-memory buffer; the buffer's byte
    # count is the on-disk size the checkpointed weights would occupy.
    model_size = BytesIO()
    torch.save(model.state_dict(), model_size)
    # getbuffer() gives a zero-copy memoryview over the written bytes.
    size_mb = model_size.getbuffer().nbytes / 1e6
    return size_mb

0 commit comments

Comments
 (0)