diff --git a/examples/offline_inference/disaggregated-prefill-v1/README.md b/examples/offline_inference/disaggregated-prefill-v1/README.md
index f708eb25383..9cbdb19820f 100644
--- a/examples/offline_inference/disaggregated-prefill-v1/README.md
+++ b/examples/offline_inference/disaggregated-prefill-v1/README.md
@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl
 ## Files
 
 - `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
+  - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
 - `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
 - `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py
index 11918f72fee..531c96f176a 100644
--- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py
+++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py
@@ -3,35 +3,47 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-# Read prompts from output.txt
-prompts = []
-try:
-    with open("output.txt") as f:
-        for line in f:
-            prompts.append(line.strip())
-    print(f"Loaded {len(prompts)} prompts from output.txt")
-except FileNotFoundError:
-    print("Error: output.txt file not found")
-    exit(-1)
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
-
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          max_num_batched_tokens=64,
-          max_num_seqs=16,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              })) #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(prompts, sampling_params)
-
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+def read_prompts():
+    """Read prompts from output.txt"""
+    prompts = []
+    try:
+        with open("output.txt") as f:
+            for line in f:
+                prompts.append(line.strip())
+        print(f"Loaded {len(prompts)} prompts from output.txt")
+        return prompts
+    except FileNotFoundError:
+        print("Error: output.txt file not found")
+        exit(-1)
+
+
+def main():
+    prompts = read_prompts()
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              max_num_batched_tokens=64,
+              max_num_seqs=16,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  })) #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(prompts, sampling_params)
+
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
index 798128301e0..24b7b1d8fdb 100644
--- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
+++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
@@ -3,42 +3,54 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-context = "Hi " * 1000
-context2 = "Hey " * 500
-prompts = [
-    context + "Hello, my name is",
-    context + "The capital of France is",
-    context2 + "Your name is",
-    context2 + "The capital of China is",
-]
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
-
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              })) #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(
-    prompts,
-    sampling_params,
-)
-
-new_prompts = []
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    new_prompts.append(prompt + generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-# Write new_prompts to output.txt
-with open("output.txt", "w") as f:
-    for prompt in new_prompts:
-        f.write(prompt + "\n")
-print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+def read_prompts():
+    context = "Hi " * 1000
+    context2 = "Hey " * 500
+    return [
+        context + "Hello, my name is",
+        context + "The capital of France is",
+        context2 + "Your name is",
+        context2 + "The capital of China is",
+    ]
+
+
+def main():
+    prompts = read_prompts()
+
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  })) #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+    )
+
+    new_prompts = []
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        new_prompts.append(prompt + generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+    # Write new_prompts to output.txt
+    with open("output.txt", "w") as f:
+        for prompt in new_prompts:
+            f.write(prompt + "\n")
+    print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+
+if __name__ == "__main__":
+    main()
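For reference, the README hunk above says `run.sh` simply runs the two scripts back to back from the example directory. The contents of `run.sh` are not included in this diff, so the commands below are only a sketch of that described workflow, not the actual helper script:

```bash
# Sketch of the workflow described in the README (assumed; run.sh itself is
# not shown in this diff). Prefill writes the KV state to ./local_storage and
# the prompts to ./output.txt, then decode reads both back.
cd examples/offline_inference/disaggregated-prefill-v1
python3 prefill_example.py
python3 decode_example.py
```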