Skip to content

Commit 31ef1dd

Browse files
Shaoting-Fenglulmer
authored and committed
[Misc] Add offline test for disaggregated prefill (vllm-project#12418)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 0666250 commit 31ef1dd

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
This file demonstrates the example usage of disaggregated prefilling
4+
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
5+
and then transfer the KV cache between them.
6+
"""
7+
import os
8+
import time
9+
from multiprocessing import Event, Process
10+
11+
from vllm import LLM, SamplingParams
12+
from vllm.config import KVTransferConfig
13+
14+
15+
def run_prefill(prefill_done):
    """Prefill node: compute KV caches on GPU 0, then signal the decode node.

    Args:
        prefill_done: multiprocessing.Event set once prefilling has finished,
            telling the decode process it may start generating.
    """
    # Pin this process to GPU 0 (the prefill node).
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # The prefill node handles two requests while the decode node handles
    # three. The decode node therefore only receives KV caches for the two
    # prompts below and must prefill the remaining request itself.
    prompts = [
        "Hello, my name is",
        # "Hi, your name is",
        # The decode node will actually "prefill" this request.
        "Tell me a very long story",
    ]
    # max_tokens=1: this node only needs to run prefill, not full decoding.
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    # PyNcclConnector transmits KV caches between the two vLLM instances.
    # This side is the producer (kv_rank 0); kv_parallel_size=2 is required
    # for PyNcclConnector.
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
    )

    # gpu_memory_utilization=0.8 was tuned for an A6000-class GPU;
    # adjust the value to fit your hardware.
    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        kv_transfer_config=ktc,
        max_model_len=2000,
        gpu_memory_utilization=0.8,
    )

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # Keep this process alive until it is terminated externally; exiting
    # before the decode node finishes could leave decoding incomplete.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
57+
58+
59+
def run_decode(prefill_done):
    """Decode node: wait for the prefill node, then generate on GPU 1.

    Args:
        prefill_done: multiprocessing.Event the prefill node sets once the
            KV caches have been produced and transfer can proceed.
    """
    # Pin this process to GPU 1 (the decode node).
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    prompts = [
        "Hello, my name is",
        "Hi, your name is",
        "Tell me a very long story",
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    # PyNcclConnector transmits KV caches between the two vLLM instances.
    # This side is the consumer (kv_rank 1); kv_parallel_size=2 is required
    # for PyNcclConnector.
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
    )

    # gpu_memory_utilization=0.8 was tuned for an A6000-class GPU;
    # adjust the value to fit your hardware.
    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        kv_transfer_config=ktc,
        max_model_len=2000,
        gpu_memory_utilization=0.8,
    )

    # Block until the producer has finished prefilling and opened the pipe.
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # Once prefill_done is set the KV caches should have been transferred
    # to this node, so decoding can begin.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}, "
              f"Generated text: {output.outputs[0].text!r}")
96+
97+
98+
if __name__ == "__main__":
99+
prefill_done = Event()
100+
prefill_process = Process(target=run_prefill, args=(prefill_done, ))
101+
decode_process = Process(target=run_decode, args=(prefill_done, ))
102+
103+
# Start prefill node
104+
prefill_process.start()
105+
106+
# Start decode node
107+
decode_process.start()
108+
109+
# Terminate the prefill node when decode is finished
110+
decode_process.join()
111+
prefill_process.terminate()

0 commit comments

Comments
 (0)