File tree 1 file changed +5
-18
lines changed
pytorch_lightning/utilities 1 file changed +5
-18
lines changed Original file line number Diff line number Diff line change 17
17
import os
18
18
import shutil
19
19
import subprocess
20
- from typing import Any , cast , Dict , IO
20
+ from io import BytesIO
21
+ from typing import Any , Dict
21
22
22
23
import torch
23
24
from torch .nn import Module
24
25
25
26
from pytorch_lightning .utilities .apply_func import apply_to_collection
26
27
27
28
28
- class _ByteCounter :
29
- """Accumulate and stores the total bytes of an object."""
30
-
31
- def __init__ (self ) -> None :
32
- self .nbytes : int = 0
33
-
34
- def write (self , data : bytes ) -> None :
35
- """Stores the total bytes of the data."""
36
- self .nbytes += len (data )
37
-
38
- def flush (self ) -> None :
39
- pass
40
-
41
-
42
29
def recursive_detach (in_dict : Any , to_cpu : bool = False ) -> Any :
43
30
"""Detach all tensors in `in_dict`.
44
31
@@ -183,7 +170,7 @@ def get_model_size_mb(model: Module) -> float:
183
170
Returns:
184
171
Number of megabytes in the parameters of the input module.
185
172
"""
186
- model_size = _ByteCounter ()
187
- torch .save (model .state_dict (), cast ( IO [ bytes ], model_size ) )
188
- size_mb = model_size .nbytes / 1e6
173
+ model_size = BytesIO ()
174
+ torch .save (model .state_dict (), model_size )
175
+ size_mb = model_size .getbuffer (). nbytes / 1e6
189
176
return size_mb
You can’t perform that action at this time.
0 commit comments