[Bug]: When using the VLLM framework to load visual models, CPU memory overflow occurs while continuously processing data with images. #12973

Closed
woshiwanlei1 opened this issue Feb 9, 2025 · 21 comments
Labels
bug Something isn't working

Comments

@woshiwanlei1

woshiwanlei1 commented Feb 9, 2025

The problem I encountered

After deploying Qwen2-VL-7B-Instruct-GPTQ-Int4 with vLLM, continuous requests from clients cause CPU memory usage to keep rising. Is some memory not being reclaimed?

My specific usage scenario:
I have two GPUs. When I use the Ray framework for distributed deployment, CPU memory usage grows as more VL requests are processed, eventually causing Ray actors to crash.

I have tested loading Qwen2-VL-7B-Instruct-GPTQ-Int4 natively (without vLLM) and it does not cause CPU memory overflow. As soon as vLLM is used for loading, CPU memory keeps growing.

[Special note]: When you test, be sure to use a different image for each request so that the CPU memory growth is clearly visible. If the same image is reused, it only leaks once, which makes the overflow look inconspicuous.
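
For illustration, one way to get a different image per request is to generate throwaway noise images. A minimal sketch (the Pillow-based helper, size, and path are just examples, not what I actually used):

import random
from PIL import Image

def fresh_image(path="/tmp/leak_probe.png", size=(896, 896)):
    # Each call produces an image with different content, so the
    # multimodal preprocessor cache never gets a hit.
    color = tuple(random.randint(0, 255) for _ in range(3))
    Image.new("RGB", size, color).save(path)
    return path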

My code and environment

Here is my code

import os

from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info


def getMessage(pic_file):
    messages = [
        {'role': 'system',
         'content': 'You are a very useful assistant, please strictly follow the requirements to complete the task!'},
        {'role': 'user',
         'content': [
             {'type': 'image_url', 'image_url': pic_file,
              'min_pixels': 50176, 'max_pixels': 1411200},
             {'type': 'text',
              'text': "Don't worry about the prompt words here, they are just examples"},
         ]},
    ]
    return messages


def vllm_extract_text(result_list, model_path, temperature, top_p, max_token, min_pixels, max_pixels):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    model_path = "/mnt/data/programdata/vl_model/Qwen2-VL-7B-Instruct-GPTQ-Int4"
    llm = LLM(model=model_path, limit_mm_per_prompt={"image": 5, "video": 0})
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_token, stop_token_ids=[])
    processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)

    # Ignore result_list; it is the data returned from MongoDB
    for doc in result_list:
        messages = getMessage(doc['pic'])
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        mm_data = {}
        if image_inputs is not None:
            mm_data["image"] = image_inputs
        llm_inputs = {
            "prompt": text,
            "multi_modal_data": mm_data,
        }
        outputs = llm.generate([llm_inputs], sampling_params=sampling_params, use_tqdm=False)
        for output in outputs:
            generated_text = output.outputs[0].text

        del llm_inputs, outputs

This is the vLLM version information

Name: vllm
Version: 0.7.2

This is my GPU info

[screenshot]

This is the memory leak information

[screenshots]

@woshiwanlei1 woshiwanlei1 added the bug Something isn't working label Feb 9, 2025
@ywang96
Member

ywang96 commented Mar 4, 2025

@woshiwanlei1 Can you try specifying disable_mm_preprocessor_cache=True and see if the host memory overflow issue still persists?
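
For the offline LLM class used above, a minimal sketch (assuming the flag is accepted as a keyword argument like other engine arguments):

from vllm import LLM

llm = LLM(
    model="/mnt/data/programdata/vl_model/Qwen2-VL-7B-Instruct-GPTQ-Int4",
    limit_mm_per_prompt={"image": 5, "video": 0},
    disable_mm_preprocessor_cache=True,  # skip the multimodal preprocessor cache
)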

@woshiwanlei1
Author

@woshiwanlei1 Can you try specifying disable_mm_preprocessor_cache=True and see if the host memory overflow issue still persists?

I'll come back and give it a try, thank you.
I later found that the leak doesn't seem to be unlimited: once it reaches a certain point, it stops growing. I installed additional memory modules, which worked around this problem.

@DarkLight1337
Member

Can you also try out #14336 and see if it alleviates the issue?

@ktobah

ktobah commented Mar 17, 2025

If I set disable_mm_preprocessor_cache=True, memory is stable but inference is quite slow.

So, I decided to try your fix. I pulled the latest docker image with:

docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}

I think this issue is not resolved yet.

Loading was fine:

INFO 03-17 07:03:00 [__init__.py:256] Automatically detected platform cuda.
INFO 03-17 07:03:04 [api_server.py:972] vLLM API server version 0.7.4.dev497+ga73e183e
INFO 03-17 07:03:04 [api_server.py:973] args: Namespace(host='0.0.0.0', port=5025, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='-', lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/data/Qwen/Qwen2.5-VL-7B-Instruct-V1', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=7000, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=7, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=None, max_logprobs=20, disable_log_stats=True, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs={'max_pixels': 3920000}, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='cuda', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['Qwen/Qwen2.5-VL-7B-Instruct-V1'], qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', 
override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_log_requests=True, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False)
INFO 03-17 07:03:14 [config.py:583] This model supports multiple tasks: {'score', 'classify', 'generate', 'reward', 'embed'}. Defaulting to 'generate'.
INFO 03-17 07:03:14 [config.py:1677] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 03-17 07:03:15 [core.py:53] Initializing a V1 LLM engine (v0.7.4.dev497+ga73e183e) with config: model='/data/Qwen/Qwen2.5-VL-7B-Instruct-V1', speculative_config=None, tokenizer='/data/Qwen/Qwen2.5-VL-7B-Instruct-V1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=7000, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=7, served_model_name=Qwen/Qwen2.5-VL-7B-Instruct-V1, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs={'max_pixels': 3920000}, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 03-17 07:03:18 [utils.py:2282] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7e088b849280>
INFO 03-17 07:03:18 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 03-17 07:03:18 [cuda.py:215] Using Flash Attention backend on V1 engine.
WARNING 03-17 07:03:18 [registry.py:337] `mm_limits` has already been set for model=/data/Qwen/Qwen2.5-VL-7B-Instruct-V1, and will be overwritten by the new values.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO 03-17 07:03:19 [gpu_model_runner.py:1118] Starting to load model /data/Qwen/Qwen2.5-VL-7B-Instruct-V1...
WARNING 03-17 07:03:19 [vision.py:94] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
INFO 03-17 07:03:19 [config.py:3206] cudagraph sizes specified by model runner [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312, 320, 328, 336, 344, 352, 360, 368, 376, 384, 392, 400, 408, 416, 424, 432, 440, 448, 456, 464, 472, 480, 488, 496, 504, 512] is overridden by config [512, 384, 256, 128, 4, 2, 1, 392, 264, 136, 8, 400, 272, 144, 16, 408, 280, 152, 24, 416, 288, 160, 32, 424, 296, 168, 40, 432, 304, 176, 48, 440, 312, 184, 56, 448, 320, 192, 64, 456, 328, 200, 72, 464, 336, 208, 80, 472, 344, 216, 88, 120, 480, 352, 248, 224, 96, 488, 504, 360, 232, 104, 496, 368, 240, 112, 376]
INFO 03-17 07:03:19 [topk_topp_sampler.py:53] Using FlashInfer for top-p & top-k sampling.
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:29<01:56, 29.01s/it]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:58<01:27, 29.32s/it]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [01:06<00:39, 19.72s/it]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [01:36<00:23, 23.75s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [02:06<00:00, 25.87s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [02:06<00:00, 25.28s/it]

INFO 03-17 07:05:25 [loader.py:429] Loading weights took 126.59 seconds
INFO 03-17 07:05:26 [gpu_model_runner.py:1130] Model loading took 15.6270 GB and 126.968260 seconds
INFO 03-17 07:05:26 [gpu_model_runner.py:1348] Encoder cache will be initialized with a budget of 4900 tokens, and profiled with 1 image items of the maximum feature size.
Keyword argument `max_pixels` is not a valid argument for this processor and will be ignored.
It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
INFO 03-17 07:05:46 [backends.py:409] Using cache directory: /root/.cache/vllm/torch_compile_cache/dc8904a923/rank_0_0 for vLLM's torch.compile
INFO 03-17 07:05:46 [backends.py:419] Dynamo bytecode transform time: 11.98 s
INFO 03-17 07:05:50 [backends.py:132] Cache the graph of shape None for later use
INFO 03-17 07:06:19 [backends.py:144] Compiling a graph for general shape takes 32.61 s
INFO 03-17 07:06:32 [monitor.py:33] torch.compile takes 44.59 s in total
INFO 03-17 07:06:33 [kv_cache_utils.py:537] GPU KV cache size: 31,760 tokens
INFO 03-17 07:06:33 [kv_cache_utils.py:540] Maximum concurrency for 7,000 tokens per request: 4.54x
INFO 03-17 07:07:05 [gpu_model_runner.py:1440] Graph capturing finished in 32 secs, took 0.50 GiB
INFO 03-17 07:07:05 [core.py:138] init engine (profile, create kv cache, warmup model) took 99.25 seconds
INFO 03-17 07:07:05 [serving_chat.py:115] Using default chat sampling params from model: {'repetition_penalty': 1.05, 'temperature': 0.1, 'top_k': 1, 'top_p': 0.001}
INFO 03-17 07:07:05 [serving_completion.py:61] Using default completion sampling params from model: {'repetition_penalty': 1.05, 'temperature': 0.1, 'top_k': 1, 'top_p': 0.001}
INFO 03-17 07:07:05 [api_server.py:1019] Starting vLLM API server on http://0.0.0.0:5025
INFO 03-17 07:07:05 [launcher.py:26] Available routes are:
INFO 03-17 07:07:05 [launcher.py:34] Route: /openapi.json, Methods: GET, HEAD
INFO 03-17 07:07:05 [launcher.py:34] Route: /docs, Methods: GET, HEAD
INFO 03-17 07:07:05 [launcher.py:34] Route: /docs/oauth2-redirect, Methods: GET, HEAD
INFO 03-17 07:07:05 [launcher.py:34] Route: /redoc, Methods: GET, HEAD
INFO 03-17 07:07:05 [launcher.py:34] Route: /health, Methods: GET
INFO 03-17 07:07:05 [launcher.py:34] Route: /load, Methods: GET
INFO 03-17 07:07:05 [launcher.py:34] Route: /ping, Methods: GET, POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /tokenize, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /detokenize, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/models, Methods: GET
INFO 03-17 07:07:05 [launcher.py:34] Route: /version, Methods: GET
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/chat/completions, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/completions, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/embeddings, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /pooling, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /score, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/score, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/audio/transcriptions, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /rerank, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v1/rerank, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /v2/rerank, Methods: POST
INFO 03-17 07:07:05 [launcher.py:34] Route: /invocations, Methods: POST
INFO:     Started server process [1]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO 03-17 07:09:01 [chat_utils.py:346] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Keyword argument `max_pixels` is not a valid argument for this processor and will be ignored.
/opt/venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2059: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].

Then I ran 3 tests with 100 images each. The following are the RAM usage results from the instance (EC2 g5.xlarge):

process 1: 15.9% -- (1st test [100 images] 20.1%) -- (2nd test [same 100 images] 20.1%) -- (3rd test [different 100 images] 40%) [didn't finish and all RAM was used, instance crashed]
process 2: 04.5% -- (1st test [100 images] 09.2%) -- (2nd test [same 100 images] 09.2%) -- (3rd test [different 100 images] 35%) [didn't finish and all RAM was used, instance crashed]

It doesn't seem to stabilize after some number of calls, but rather grows with every unseen image. Or am I missing something?

[screenshot]

@woshiwanlei1 Could you please elaborate on this:

I installed additional memory modules, which worked around this problem.

@DarkLight1337
Member

DarkLight1337 commented Mar 17, 2025

So, I decided to try your fix. I pulled the latest docker image with:
docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}
I think this issue is not resolved yet.

I assume you mean that vLLM may still crash if you don't set disable_mm_preprocessor_cache=True?

If you are using latest commit, #14805 has been merged already so you can try setting VLLM_MM_INPUT_CACHE_GIB to explicitly limit the amount of memory the cache can take up without disabling it entirely.

@ktobah

ktobah commented Mar 19, 2025

Thank you. With VLLM_MM_INPUT_CACHE_GiB it is more controllable.

@cchadowitz

A note for anyone else encountering something like this, the env var is actually VLLM_MM_INPUT_CACHE_GIB (note the capital I).

@kuladeephx

@cchadowitz / @ktobah, can you please let me know how I can set the environment variable VLLM_MM_INPUT_CACHE_GIB?

@DarkLight1337
Member

DarkLight1337 commented Mar 23, 2025

You just set it like any other environment variable. For example, to limit the cache size to 4 GiB when running vllm serve:

VLLM_MM_INPUT_CACHE_GIB=4 vllm serve <args>

@kuladeephx

os.environ["VLLM_MM_INPUT_CACHE_GIB"] = "4"

I have tried setting it this way since I am using the vLLM wrapper class, but it is still consuming a lot of CPU RAM.

@DarkLight1337
Member

DarkLight1337 commented Mar 23, 2025

You should set it before importing vLLM; it's preferable to set it on the command line. If CPU memory usage is still high, you can try --disable-mm-preprocessor-cache to disable this caching entirely.
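
For example, in a script (a minimal sketch; the important part is that the assignment happens before vllm is imported):

import os

# Limit the multimodal input cache to 4 GiB; set before importing vllm, as suggested above.
os.environ["VLLM_MM_INPUT_CACHE_GIB"] = "4"

from vllm import LLM, SamplingParams  # imported only after the variable is set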

@kuladeephx

I have tried --disable-mm-preprocessor-cache; it is still consuming more and more RAM with each request.

@kuladeephx

I have tried setting the variable on the command line before execution: export VLLM_MM_INPUT_CACHE_GIB=4
I also tried setting the variable before importing vllm.
Still the same.

@DarkLight1337
Member

cc @ywang96 I think you're working on this issue with the frontend?

@ywang96
Member

ywang96 commented Mar 23, 2025

Hello @kuladeephx! This issue has been fixed on the main branch. Please see a post-mortem/analysis in the comment here

@kuladeephx

kuladeephx commented Mar 24, 2025

@ywang96, I have tried the latest main branch with:

  1. disable_mm_preprocessor_cache=True
  2. export VLLM_MM_INPUT_CACHE_GIB=4

I am still facing the same issue.

@kuladeephx

I have tried vllm serve (the OpenAI-compatible server) and did not face the memory issue when using --disable-mm-preprocessor-cache.
I'm not sure why it consumes so much CPU RAM when using the vLLM wrapper class.

@Mitix-EPI

Using v0.8.2 seems to fix the issue. I only set VLLM_MM_INPUT_CACHE_GIB and it limits the CPU RAM consumption.

@kuladeephx

@Mitix-EPI, is that with the OpenAI-compatible server or the vLLM wrapper?

@zhangyuygss

@ywang96, I have tried the latest main branch with:

  1. disable_mm_preprocessor_cache=True
  2. export VLLM_MM_INPUT_CACHE_GIB=4

I am still facing the same issue.

Have you solved the problem? I'm using 0.7.x and facing a similar issue.
My situation is:
When I set n=1 in SamplingParams, the CPU memory stops growing at a certain point;
when I set n>1, the CPU memory keeps growing (>300 GB) and OOMs.
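
For reference, a minimal sketch of the two configurations being compared (the other sampling values are placeholders):

from vllm import SamplingParams

# n=1: CPU memory plateaus at some point
params_n1 = SamplingParams(n=1, temperature=0.7, max_tokens=512)

# n>1: CPU memory keeps growing in this setup
params_n4 = SamplingParams(n=4, temperature=0.7, max_tokens=512)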

@Mitix-EPI

@kuladeephx The vLLM wrapper with the V1 API.
Hope you'll fix your issue!
