Skip to content

[Misc] refactor disaggregated-prefill-v1 example #18474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl
## Files

- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
- Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,47 @@
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Read prompts from output.txt
prompts = []
try:
with open("output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from output.txt")
except FileNotFoundError:
print("Error: output.txt file not found")
exit(-1)

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage"
})) #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

def read_prompts():
"""Read prompts from output.txt"""
prompts = []
try:
with open("output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from output.txt")
return prompts
except FileNotFoundError:
print("Error: output.txt file not found")
exit(-1)


def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage"
})) #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)

print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,54 @@
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

context = "Hi " * 1000
context2 = "Hey " * 500
prompts = [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage"
})) #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)

new_prompts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write new_prompts to output.txt
with open("output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to output.txt")

def read_prompts():
context = "Hi " * 1000
context2 = "Hey " * 500
return [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]


def main():
prompts = read_prompts()

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage"
})) #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)

new_prompts = []
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)

# Write new_prompts to output.txt
with open("output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to output.txt")


if __name__ == "__main__":
main()