|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +""" |
| 3 | +This file demonstrates the example usage of disaggregated prefilling |
| 4 | +We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), |
| 5 | +and then transfer the KV cache between them. |
| 6 | +""" |
| 7 | +import os |
| 8 | +import time |
| 9 | +from multiprocessing import Event, Process |
| 10 | + |
| 11 | +from vllm import LLM, SamplingParams |
| 12 | +from vllm.config import KVTransferConfig |
| 13 | + |
| 14 | + |
| 15 | +def run_prefill(prefill_done): |
| 16 | + # We use GPU 0 for prefill node. |
| 17 | + os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
| 18 | + |
| 19 | + # The prefill node receives two requests, while the decode node receives |
| 20 | + # three requests. So the decode node will only receive the KV Cache for |
| 21 | + # requests 1 and 3. The decode node will use the KV Cache of requests 1 |
| 22 | + # and 3 and do prefilling on request 2. |
| 23 | + prompts = [ |
| 24 | + "Hello, my name is", |
| 25 | + # "Hi, your name is", |
| 26 | + # The decode node will actually "prefill" this request. |
| 27 | + "Tell me a very long story", |
| 28 | + ] |
| 29 | + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) |
| 30 | + |
| 31 | + # Using PyNcclConnector to transmit KV caches between vLLM instances. |
| 32 | + # This instance is the prefill node (kv_producer, rank 0). |
| 33 | + # The number of parallel instances for KV cache transfer is set to 2, |
| 34 | + # as required for PyNcclConnector. |
| 35 | + ktc = KVTransferConfig.from_cli( |
| 36 | + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' |
| 37 | + ) |
| 38 | + |
| 39 | + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB |
| 40 | + # memory. You may need to adjust the value to fit your GPU. |
| 41 | + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", |
| 42 | + kv_transfer_config=ktc, |
| 43 | + max_model_len=2000, |
| 44 | + gpu_memory_utilization=0.8) |
| 45 | + |
| 46 | + llm.generate(prompts, sampling_params) |
| 47 | + print("Prefill node is finished.") |
| 48 | + prefill_done.set() |
| 49 | + |
| 50 | + # To keep the prefill node running in case the decode node is not done; |
| 51 | + # otherwise, the script might exit prematurely, causing incomplete decoding. |
| 52 | + try: |
| 53 | + while True: |
| 54 | + time.sleep(1) |
| 55 | + except KeyboardInterrupt: |
| 56 | + print("Script stopped by user.") |
| 57 | + |
| 58 | + |
| 59 | +def run_decode(prefill_done): |
| 60 | + # We use GPU 1 for decode node. |
| 61 | + os.environ["CUDA_VISIBLE_DEVICES"] = "1" |
| 62 | + |
| 63 | + prompts = [ |
| 64 | + "Hello, my name is", |
| 65 | + "Hi, your name is", |
| 66 | + "Tell me a very long story", |
| 67 | + ] |
| 68 | + sampling_params = SamplingParams(temperature=0, top_p=0.95) |
| 69 | + |
| 70 | + # Using PyNcclConnector to transmit KV caches between vLLM instances. |
| 71 | + # This instance is the decode node (kv_consumer, rank 1). |
| 72 | + # The number of parallel instances for KV cache transfer is set to 2, |
| 73 | + # as required for PyNcclConnector. |
| 74 | + ktc = KVTransferConfig.from_cli( |
| 75 | + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' |
| 76 | + ) |
| 77 | + |
| 78 | + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB |
| 79 | + # memory. You may need to adjust the value to fit your GPU. |
| 80 | + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", |
| 81 | + kv_transfer_config=ktc, |
| 82 | + max_model_len=2000, |
| 83 | + gpu_memory_utilization=0.8) |
| 84 | + |
| 85 | + # Wait for the producer to start the pipe |
| 86 | + print("Waiting for prefill node to finish...") |
| 87 | + prefill_done.wait() |
| 88 | + |
| 89 | + # At this point when the prefill_done is set, the kv-cache should have been |
| 90 | + # transferred to this decode node, so we can start decoding. |
| 91 | + outputs = llm.generate(prompts, sampling_params) |
| 92 | + for output in outputs: |
| 93 | + prompt = output.prompt |
| 94 | + generated_text = output.outputs[0].text |
| 95 | + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 96 | + |
| 97 | + |
| 98 | +if __name__ == "__main__": |
| 99 | + prefill_done = Event() |
| 100 | + prefill_process = Process(target=run_prefill, args=(prefill_done, )) |
| 101 | + decode_process = Process(target=run_decode, args=(prefill_done, )) |
| 102 | + |
| 103 | + # Start prefill node |
| 104 | + prefill_process.start() |
| 105 | + |
| 106 | + # Start decode node |
| 107 | + decode_process.start() |
| 108 | + |
| 109 | + # Terminate the prefill node when decode is finished |
| 110 | + decode_process.join() |
| 111 | + prefill_process.terminate() |
0 commit comments