[misc] improve memory profiling #11809
Changes from 6 commits
@@ -5,6 +5,7 @@
import pytest
import torch
+from vllm_test_utils import monitor

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                        get_open_port, memory_profiling, merge_async_iterators,
@@ -289,16 +290,32 @@ def test_memory_profiling():

    weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB

+    def measure_current_non_torch():
+        free, total = torch.cuda.mem_get_info()
+        current_used = total - free
+        current_torch = torch.cuda.memory_reserved()
+        current_non_torch = current_used - current_torch
+        return current_non_torch
+
    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+            monitor(measure_current_non_torch) as monitored_values:
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
        del spike

        # Add some extra non-torch memory 256 MiB (simulate NCCL)
        handle2 = lib.cudaMalloc(256 * 1024 * 1024)

+        # this is an analytic value, it is exact,
+        # we only have 256 MiB non-torch memory increase
+        measured_diff = monitored_values.values[-1] - monitored_values.values[0]
+        assert measured_diff == 256 * 1024 * 1024
Review comment: This is really cool! I'm a bit confused on how it's useful though, since we're testing a test utility that isn't used in the actual memory profiling? Did we want to enable …

Reply: For the current memory profiling we don't rely on this utility; it is more about future-proofing: we can get the ground-truth non-torch memory, which will help us profile which part of memory can be offloaded in RLHF workloads.

    # Check that the memory usage is within 5% of the expected values
    # 5% tolerance is caused by PyTorch caching allocator,
    # we cannot control PyTorch's behavior of its internal buffers,
    # which causes a small error (<10 MiB in practice)
    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
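For readers unfamiliar with the "non-torch" quantity the test asserts on, here is a hedged standalone sketch (not part of the diff) of the same measurement: memory allocated behind PyTorch's back shows up in `torch.cuda.mem_get_info()` but not in `torch.cuda.memory_reserved()`. The test allocates that memory through its own `lib` wrapper; the `ctypes` handle and library name below are assumptions and may need adjusting for your CUDA runtime.

```python
# Illustrative sketch, not from the PR: observe "non-torch" GPU memory directly.
import ctypes

import torch

torch.cuda.init()  # make sure the CUDA context exists before the baseline reading


def current_non_torch_bytes() -> int:
    free, total = torch.cuda.mem_get_info()
    return (total - free) - torch.cuda.memory_reserved()


before = current_non_torch_bytes()

# Allocate 256 MiB outside of PyTorch (assumed library name; adjust as needed).
libcudart = ctypes.CDLL("libcudart.so")
ptr = ctypes.c_void_p()
libcudart.cudaMalloc(ctypes.byref(ptr), ctypes.c_size_t(256 * 1024 * 1024))

after = current_non_torch_bytes()
print((after - before) // (1024 * 1024), "MiB of non-torch memory")  # ~256
```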
@@ -0,0 +1,65 @@
import contextlib
import dataclasses
import sys
import traceback
from typing import Any, Callable, Generator, List


@dataclasses.dataclass
class MonitoredValues:
    values: List[Any] = dataclasses.field(default_factory=list)
    trace_stacks: List[str] = dataclasses.field(default_factory=list)


@contextlib.contextmanager
def monitor(
        measure_func: Callable[[],
                               Any]) -> Generator[MonitoredValues, None, None]:
    """
    Trace the function calls to continuously monitor the change of
    a value.

    Usage:

    ```python

    def measure_func():
        ... # measure the current value
        return current_value

    with monitor(measure_func) as monitored_values:
        # do something

        monitored_values.values # all changes of the values
        monitored_values.trace_stacks # trace stacks of every change
    ```
    """
    monitored_values = MonitoredValues()

    def _trace_calls(frame, event, arg=None):
        nonlocal monitored_values
        if event in ['line']:
            # triggered by every line of Python code.
            # only Python functions will trigger it,
            # c/cpp functions will not trigger it.
            try:
                # Temporarily disable the trace function
                sys.settrace(None)
                # do a measurement
                current_value = measure_func()
                if len(monitored_values.values
                       ) == 0 or current_value != monitored_values.values[-1]:
                    monitored_values.values.append(current_value)
                    monitored_values.trace_stacks.append("".join(
                        traceback.format_stack()))
                # Re-enable the trace function
                sys.settrace(_trace_calls)
            except NameError:
                # modules are deleted during shutdown
                pass
        return _trace_calls

    try:
        sys.settrace(_trace_calls)
        yield monitored_values
    finally:
        sys.settrace(None)
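As a quick sanity check of how the tracer behaves, here is a hedged, CPU-only sketch (not part of the PR). `bump` and `checkpoint` are made-up helpers; measurements only happen when a traced Python line runs, so the final no-op call makes sure the last value is observed before the block exits.

```python
# Illustrative sketch, not part of the diff: exercise the monitor utility on a
# plain Python value (no GPU required).
from vllm_test_utils import monitor

counter = {"n": 0}


def bump():
    counter["n"] += 1  # a traced Python line; the monitor re-measures here


def checkpoint():
    pass  # one more traced line so the final value is definitely observed


with monitor(lambda: counter["n"]) as monitored_values:
    for _ in range(3):
        bump()
    checkpoint()

print(monitored_values.values)             # typically [0, 1, 2, 3]
print(len(monitored_values.trace_stacks))  # one captured stack per recorded change
```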
@@ -1742,10 +1742,10 @@ class MemorySnapshot:
    timestamp: float = 0.0

    def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        self.torch_memory_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
+        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # torch.cuda.memory_reserved() is how many bytes
+        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
+        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":

Review comment: ohhh nice, I didn't catch that there was a peak measurement for the total reserved memory as well
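The switch from the `allocated_bytes` statistics to `torch.cuda.memory_reserved()` / `torch.cuda.max_memory_reserved()` matters because the caching allocator can hold memory that is reserved but no longer allocated; counting only allocated bytes can misattribute that cached memory to the non-torch category. A hedged, standalone illustration (not part of the diff, assuming a fresh process with nothing else allocated):

```python
# Illustrative sketch: reserved vs. allocated bytes in PyTorch's caching allocator.
import torch

x = torch.empty(64, 1024, 1024, device="cuda")  # 256 MiB allocated (float32)
del x                                           # returned to the cache, not to CUDA

stats = torch.cuda.memory_stats()
print(stats["allocated_bytes.all.current"])  # ~0 after the delete
print(torch.cuda.memory_reserved())          # still ~256 MiB held by the allocator
print(torch.cuda.max_memory_reserved())      # peak reservation since the last reset
```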
@@ -1822,10 +1822,10 @@ def memory_profiling(

    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.

-    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).

    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
    """ # noqa
    torch.cuda.reset_peak_memory_stats()
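To make the subtraction described in that docstring concrete, here is a tiny numeric sketch. The numbers and variable names are invented for illustration and do not come from vLLM; treat it only as a reading of the sentence above, not as the implementation.

```python
# Invented numbers, mirroring the docstring: non-torch memory is what remains after
# removing the baseline, the weights, and PyTorch's own reserved-memory growth.
GiB = 1024**3

used_total = 10 * GiB      # torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
baseline = 1 * GiB         # memory already in use before this instance started
weights = 2 * GiB          # the `weights_memory_in_bytes` argument
reserved_growth = 5 * GiB  # growth of torch.cuda.memory_reserved() during profiling

non_torch = used_total - baseline - weights - reserved_growth
print(non_torch / GiB)  # 2.0 GiB, e.g. NCCL buffers and the CUDA context
```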