
Commit 889e662

[misc] improve memory profiling (vllm-project#11809)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent ef68eb2 commit 889e662

4 files changed: +94 -8 lines


tests/test_utils.py

Lines changed: 18 additions & 1 deletion
@@ -5,6 +5,7 @@
 
 import pytest
 import torch
+from vllm_test_utils import monitor
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                         get_open_port, memory_profiling, merge_async_iterators,
@@ -289,16 +290,32 @@ def test_memory_profiling():
 
     weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
 
+    def measure_current_non_torch():
+        free, total = torch.cuda.mem_get_info()
+        current_used = total - free
+        current_torch = torch.cuda.memory_reserved()
+        current_non_torch = current_used - current_torch
+        return current_non_torch
+
     with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+            monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
         del spike
 
         # Add some extra non-torch memory 256 MiB (simulate NCCL)
         handle2 = lib.cudaMalloc(256 * 1024 * 1024)
 
+        # this is an analytic value, it is exact,
+        # we only have 256 MiB non-torch memory increase
+        measured_diff = monitored_values.values[-1] - monitored_values.values[0]
+        assert measured_diff == 256 * 1024 * 1024
+
         # Check that the memory usage is within 5% of the expected values
+        # 5% tolerance is caused by PyTorch caching allocator,
+        # we cannot control PyTorch's behavior of its internal buffers,
+        # which causes a small error (<10 MiB in practice)
         non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
         torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
         assert abs(non_torch_ratio - 1) <= 0.05
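
For readers without the vLLM test harness, here is a rough standalone sketch of what the new assertion checks: memory allocated outside PyTorch shows up one-for-one in the (total - free) - memory_reserved() quantity that measure_current_non_torch computes, because the driver reports it as used while PyTorch's caching allocator never sees it. The test uses vLLM's cudaMalloc wrapper; the sketch below instead loads the CUDA runtime via ctypes, and the library name "libcudart.so" is an assumption about your environment.

import ctypes

import torch

torch.cuda.init()

def current_non_torch_bytes() -> int:
    # (total - free) is everything the device reports as in use;
    # memory_reserved() is the share owned by PyTorch's caching allocator,
    # so the remainder is memory allocated outside PyTorch.
    free, total = torch.cuda.mem_get_info()
    return (total - free) - torch.cuda.memory_reserved()

# Assumption: the CUDA runtime shared library is discoverable under this name.
libcudart = ctypes.CDLL("libcudart.so")
libcudart.cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
libcudart.cudaMalloc.restype = ctypes.c_int

ptr = ctypes.c_void_p()
before = current_non_torch_bytes()
libcudart.cudaMalloc(ctypes.byref(ptr), 256 * 1024 * 1024)  # allocate outside PyTorch
after = current_non_torch_bytes()

# In practice this delta is exactly the 256 MiB handed to the driver,
# which is the "analytic value" the test relies on.
assert after - before == 256 * 1024 * 1024
libcudart.cudaFree(ptr)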

tests/vllm_test_utils/vllm_test_utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,5 +4,6 @@
 """
 
 from .blame import BlameResult, blame
+from .monitor import MonitoredValues, monitor
 
-__all__ = ["blame", "BlameResult"]
+__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
tests/vllm_test_utils/vllm_test_utils/monitor.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+import contextlib
+import dataclasses
+import sys
+import traceback
+from typing import Callable, Generator, Generic, TypeVar
+
+_T = TypeVar("_T")
+
+
+@dataclasses.dataclass
+class MonitoredValues(Generic[_T]):
+    values: list[_T] = dataclasses.field(default_factory=list)
+    trace_stacks: list[str] = dataclasses.field(default_factory=list)
+
+
+@contextlib.contextmanager
+def monitor(
+        measure_func: Callable[[],
+                               _T]) -> Generator[MonitoredValues[_T], None, None]:
+    """
+    Trace the function calls to continuously monitor the change of
+    a value.
+
+    Usage:
+
+    ```python
+
+    def measure_func():
+        ... # measure the current value
+        return current_value
+
+    with monitor(measure_func) as monitored_values:
+        # do something
+
+    monitored_values.values # all changes of the values
+    monitored_values.trace_stacks # trace stacks of every change
+    ```
+    """
+    monitored_values = MonitoredValues[_T]()
+
+    def _trace_calls(frame, event, arg=None):
+        nonlocal monitored_values
+        if event in ['line']:
+            # triggered by every line of Python code.
+            # only Python functions will trigger it,
+            # c/cpp functions will not trigger it.
+            try:
+                # Temporarily disable the trace function
+                sys.settrace(None)
+                # do a measurement
+                current_value = measure_func()
+                if len(monitored_values.values
+                       ) == 0 or current_value != monitored_values.values[-1]:
+                    monitored_values.values.append(current_value)
+                    monitored_values.trace_stacks.append("".join(
+                        traceback.format_stack()))
+                # Re-enable the trace function
+                sys.settrace(_trace_calls)
+            except NameError:
+                # modules are deleted during shutdown
+                pass
+        return _trace_calls
+
+    try:
+        sys.settrace(_trace_calls)
+        yield monitored_values
+    finally:
+        sys.settrace(None)
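
A CPU-only usage sketch of the new helper; the names below are made up for illustration. Because samples are taken on Python 'line' trace events, the monitored quantity has to change inside Python-level calls (as the module's own comment notes, C/C++ builtins alone will not trigger a sample), so the sketch routes each change through a small Python function.

from vllm_test_utils import monitor

data: list[int] = []

def measure() -> int:
    # the quantity being monitored; in the memory test this is
    # the non-torch GPU memory, here it is just the list length
    return len(data)

def add(x: int) -> None:
    # a Python-level frame, so its lines generate trace events
    data.append(x)

with monitor(measure) as monitored:
    add(1)
    add(2)
    add(3)

# Every distinct value observed, in the order it was seen.
print(monitored.values)
# The Python call stack captured when the first change was recorded,
# useful for locating the code responsible for a jump.
print(monitored.trace_stacks[0])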

vllm/utils.py

Lines changed: 6 additions & 6 deletions
@@ -1742,10 +1742,10 @@ class MemorySnapshot:
     timestamp: float = 0.0
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        self.torch_memory_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
+        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # torch.cuda.memory_reserved() is how many bytes
+        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
+        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
@@ -1822,10 +1822,10 @@ def memory_profiling(
 
     The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
 
-    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
 
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
     """ # noqa
     torch.cuda.reset_peak_memory_stats()
 
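
The substantive change in MemorySnapshot.measure is the move from memory_stats()["allocated_bytes.all.*"] to memory_reserved()/max_memory_reserved(): the reserved figures count what the caching allocator has actually obtained from the driver (via cudaMalloc), which is what non-torch components compete with, while the allocated figures only count bytes currently backing live tensors. Below is a rough, self-contained sketch of the accounting the docstring describes, omitting the baseline and weights terms that the real memory_profiling also subtracts; the variable names are illustrative only.

import torch

torch.cuda.reset_peak_memory_stats()

free, total = torch.cuda.mem_get_info()
used_before = total - free
reserved_before = torch.cuda.memory_reserved()

# stand-in workload: a large temporary "activation" tensor, then freed
x = torch.randn(256, 1024, 1024, device="cuda")
del x

free, total = torch.cuda.mem_get_info()
used_after = total - free

# (b.) peak PyTorch usage during the workload, measured here via the
# caching allocator's peak reservation (as the updated snapshot does)
torch_peak_increase = torch.cuda.max_memory_reserved() - reserved_before

# (c.) growth outside the caching allocator (NCCL buffers, CUDA context, ...),
# i.e. total growth minus whatever extra memory PyTorch kept reserved
non_torch_increase = (used_after - used_before) - (
    torch.cuda.memory_reserved() - reserved_before)

print(f"torch peak increase: {torch_peak_increase / 1024**2:.0f} MiB")
print(f"non-torch increase:  {non_torch_increase / 1024**2:.0f} MiB")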
