Commit 107f5fc

Authored by reidliu41
[Misc] refactor disaggregated-prefill-v1 example (#18474)
Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
1 parent 907f935 commit 107f5fc

File tree

3 files changed: 96 additions, 71 deletions

3 files changed

+96
-71
lines changed

examples/offline_inference/disaggregated-prefill-v1/README.md

Lines changed: 1 addition & 0 deletions

@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl
 ## Files
 
 - `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
+- Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
 - `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
 - `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
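The contents of `run.sh` are not part of this diff, but the flow it is described as automating is simply the two scripts run back to back from this directory. A rough Python equivalent of that sequence (hypothetical, for illustration only):

```python
# Hypothetical sketch of the sequential flow run.sh is described as
# automating; run.sh's actual contents are not shown in this commit.
import subprocess
import sys

# Run prefill first so local_storage/ and output.txt exist before decode.
for script in ("prefill_example.py", "decode_example.py"):
    result = subprocess.run([sys.executable, script])
    if result.returncode != 0:
        sys.exit(result.returncode)
```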

examples/offline_inference/disaggregated-prefill-v1/decode_example.py

Lines changed: 44 additions & 32 deletions

@@ -3,35 +3,47 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-# Read prompts from output.txt
-prompts = []
-try:
-    with open("output.txt") as f:
-        for line in f:
-            prompts.append(line.strip())
-    print(f"Loaded {len(prompts)} prompts from output.txt")
-except FileNotFoundError:
-    print("Error: output.txt file not found")
-    exit(-1)
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
-
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          max_num_batched_tokens=64,
-          max_num_seqs=16,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              }))  #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(prompts, sampling_params)
-
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+def read_prompts():
+    """Read prompts from output.txt"""
+    prompts = []
+    try:
+        with open("output.txt") as f:
+            for line in f:
+                prompts.append(line.strip())
+        print(f"Loaded {len(prompts)} prompts from output.txt")
+        return prompts
+    except FileNotFoundError:
+        print("Error: output.txt file not found")
+        exit(-1)
+
+
+def main():
+    prompts = read_prompts()
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              max_num_batched_tokens=64,
+              max_num_seqs=16,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  }))  #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(prompts, sampling_params)
+
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+
+if __name__ == "__main__":
+    main()
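Since `decode_example.py` exits when `output.txt` is missing, a pre-flight check before launching the decode step can fail fast on both artifacts the prefill step is expected to leave behind. This snippet is not part of the commit, just a hypothetical guard using the paths named above:

```python
# Hypothetical guard (not in the commit): decode_example.py needs the
# artifacts produced by prefill_example.py in the current directory.
from pathlib import Path

for artifact in (Path("local_storage"), Path("output.txt")):
    if not artifact.exists():
        raise SystemExit(f"Missing {artifact}; run prefill_example.py first")
```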

examples/offline_inference/disaggregated-prefill-v1/prefill_example.py

Lines changed: 51 additions & 39 deletions

@@ -3,42 +3,54 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-context = "Hi " * 1000
-context2 = "Hey " * 500
-prompts = [
-    context + "Hello, my name is",
-    context + "The capital of France is",
-    context2 + "Your name is",
-    context2 + "The capital of China is",
-]
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
-
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              }))  #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(
-    prompts,
-    sampling_params,
-)
-
-new_prompts = []
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    new_prompts.append(prompt + generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-# Write new_prompts to output.txt
-with open("output.txt", "w") as f:
-    for prompt in new_prompts:
-        f.write(prompt + "\n")
-print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+def read_prompts():
+    context = "Hi " * 1000
+    context2 = "Hey " * 500
+    return [
+        context + "Hello, my name is",
+        context + "The capital of France is",
+        context2 + "Your name is",
+        context2 + "The capital of China is",
+    ]
+
+
+def main():
+    prompts = read_prompts()
+
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  }))  #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+    )
+
+    new_prompts = []
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        new_prompts.append(prompt + generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+    # Write new_prompts to output.txt
+    with open("output.txt", "w") as f:
+        for prompt in new_prompts:
+            f.write(prompt + "\n")
+    print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+
+if __name__ == "__main__":
+    main()
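After the prefill step finishes, its output can be sanity-checked before moving on to decode. A small, hypothetical inspection snippet (file names taken from the configuration above; the exact layout under `local_storage` is determined by the SharedStorageConnector):

```python
# Hypothetical inspection (not in the commit): list the KV state written
# under local_storage/ and count the prompts saved for decode_example.py.
import os

print(os.listdir("local_storage"))
with open("output.txt") as f:
    print(sum(1 for _ in f), "prompts saved for the decode step")
```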
