[Sampler] Vectorized sampler #820

Closed
Yard1 wants to merge 33 commits

Conversation

Yard1 (Collaborator) commented Aug 22, 2023

Ready for initial review.

Using the modified throughput benchmark on one A100-80:

python benchmarks/benchmark_throughput.py --backend vllm --dataset "./ShareGPT_V3_unfiltered_cleaned_split.json" --model meta-llama/Llama-2-13b-chat-hf --tokenizer meta-llama/Llama-2-13b-chat-hf --num-prompts=1000

I get the following results:

baseline: Throughput: 3.70 requests/s, 1771.58 tokens/s
this PR: Throughput: 4.77 requests/s, 2280.04 tokens/s

which would be a 1.28x improvement.

The argmax outputs between this PR and master match. Random sampling is different, but that is expected (the text still makes sense and doesn't differ much from master).

Yard1 added 9 commits August 19, 2023 16:11

Yard1 (Collaborator, Author) commented Aug 23, 2023

Updated benchmark to also apply top_p.

baseline: Throughput: 3.58 requests/s, 1710.60 tokens/s
PR: Throughput: 4.69 requests/s, 2243.01 tokens/s

1.31x improvement
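
For reference, this is roughly how a non-trivial top_p can be attached to each benchmark request; a minimal sketch using vLLM's public SamplingParams API, with illustrative values rather than the exact code from this PR:

```python
from vllm import SamplingParams

# Illustrative values: exercise the top-p path of the sampler
# instead of pure greedy/temperature-only decoding.
sampling_params = SamplingParams(
    n=1,
    temperature=1.0,  # stochastic sampling
    top_p=0.9,        # nucleus (top-p) threshold applied per request
    max_tokens=128,
)
```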

Yard1 (Collaborator, Author) commented Aug 23, 2023

I have also compared argmax outputs of the benchmark on the ShareGPT dataset, and they are identical between master and this PR (meaning that there is no impact on correctness).
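
A minimal sketch of such a comparison, using the output.txt dumps written by the modified benchmark; the two file names below are hypothetical (one dump per branch):

```python
# Compare per-request outputs dumped by two benchmark runs
# (e.g. one from master, one from this branch).
with open("output_master.txt") as f_master, open("output_pr.txt") as f_pr:
    for i, (line_master, line_pr) in enumerate(zip(f_master, f_pr)):
        assert line_master == line_pr, f"Mismatch at request {i}"
print("All greedy (argmax) outputs are identical.")
```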

Yard1 marked this pull request as ready for review August 23, 2023 17:39
Yard1 changed the title from [WIP][Sampler] Vectorized sampler to [Sampler] Vectorized sampler Aug 23, 2023

Yard1 (Collaborator, Author) commented Aug 23, 2023

Using the modified benchmark and llama-2-7b on an A100-80:

baseline: Throughput: 4.67 requests/s, 2233.30 tokens/s
PR: Throughput: 6.84 requests/s, 3272.51 tokens/s

1.47x improvement.

As expected, the improvement is bigger for smaller models (since a larger share of the time is spent in sampling).

scv119 (Contributor) commented Aug 23, 2023

Benchmarked llama-70b on 4x A100-40 with 512 context length, 128 generated tokens, and 500 requests:

baseline: dur_s 133.22 tokens_per_s 1807.24 qps 3.75 successful_responses 500
PR: dur_s 114.26 tokens_per_s 2108.07 qps 4.38 successful_responses 500 

WoosukKwon self-requested a review September 4, 2023 16:00

WoosukKwon (Collaborator) left a comment:

@Yard1, thanks for submitting the PR. Before I go deeper into the PR, could you please update the branch with the latest commit?

@@ -78,15 +73,19 @@ def run_vllm(
)

# Add the requests to the engine.
do_sample = False

Collaborator:

A dumb question: What is this variable used for?

Yard1 (Collaborator, Author) replied:

Allows us to alternate between greedy and stochastic sampling for requests.
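
A minimal sketch of that idea (illustrative, not the exact benchmark code): flip the flag once per request so half the requests are greedy and half are sampled. The `requests` list of (prompt, prompt_len, output_len) tuples is assumed to come from the benchmark script.

```python
from vllm import SamplingParams

do_sample = False
sampling_params_list = []
for _prompt, _prompt_len, output_len in requests:
    sampling_params_list.append(
        SamplingParams(
            n=1,
            # temperature=0.0 means greedy (argmax) decoding in vLLM;
            # a positive temperature enables stochastic sampling.
            temperature=1.0 if do_sample else 0.0,
            top_p=1.0,
            max_tokens=output_len,
        ))
    do_sample = not do_sample  # alternate for the next request
```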

Comment on lines +98 to +102
outputs = llm._run_engine(use_tqdm=True)
end = time.time()
with open("output.txt", "w") as f:
    for output in outputs:
        f.write(output.__repr__() + "\n")

Collaborator:

I guess this is to check if vLLM generated valid outputs, right? If so, I believe this ought to be done by our test code, rather than in the benchmarking code.

Yard1 (Collaborator, Author) replied:

Yeah, I will revert the changes to the benchmarking script later.

scv119 (Contributor) left a comment:

Also consider adding a test?

Yard1 requested a review from scv119 September 8, 2023 00:52

scv119 (Contributor) commented Sep 8, 2023

lgtm

WoosukKwon (Collaborator) commented:

@Yard1 I got an error when running examples/llm_engine_example.py:

File "/home/workspace/vllm/vllm/model_executor/layers/sampler.py", line 98, in forward
    assert len(temperatures) == non_greedy_logits.shape[0]
AssertionError

Yard1 (Collaborator, Author) commented Sep 12, 2023

@WoosukKwon Good catch! Just pushed a fix.

WoosukKwon (Collaborator) left a comment:

@Yard1 Thanks for the PR. It seems the performance improvement is quite significant! However, I'm a bit worried that this PR 1) includes unnecessary changes, and 2) complicates the sampler logic.

For 1), could you revert the unnecessary changes? I pointed out a few places in the code.
For 2), do you have any idea to simplify the code? I think the complexity mainly comes from a lot of tensor/list indexing operations, and sort-gathers. Can you somehow reduce the use of them?

Comment on lines 114 to 118
p = torch.tensor(top_ps,
                 dtype=logits.dtype,
                 device=logits.device)
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
_apply_top_p_top_k_in_place(non_greedy_logits, p, k)

Collaborator:

What is the benefit of creating p and k outside the function? It seems the change has no effect.

Yard1 (Collaborator, Author) replied:

Right now it's necessary for the p's and k's to be indexed correctly (so they line up with the rows of the logits).
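
To make the indexing and the sort-gather concrete, here is a hypothetical sketch of what a vectorized, in-place top-p/top-k filter can look like; it is a stand-in for `_apply_top_p_top_k_in_place` (the real helper in this PR may differ) and assumes the caller builds `p` and `k` so that row i of `logits` is filtered with `p[i]` and `k[i]`:

```python
import torch


def apply_top_p_top_k_in_place(logits: torch.Tensor,
                               p: torch.Tensor,
                               k: torch.Tensor) -> None:
    """Hypothetical sketch: mask logits outside each row's top-p/top-k set.

    logits: [num_seqs, vocab_size]; p, k: [num_seqs], aligned to the rows
    of `logits` by the caller (e.g. only the non-greedy rows).
    """
    # Sort-gather: sort each row once, build masks in sorted order,
    # then scatter the masked values back to the original token order.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Top-p: drop a token once the cumulative probability of the tokens
    # ranked above it already exceeds p for that row.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(-1)

    # Top-k: drop everything past the k-th largest logit in each row.
    ranks = torch.arange(logits.shape[-1], device=logits.device)
    top_k_mask = ranks.unsqueeze(0) >= k.unsqueeze(-1)

    logits_sort[top_p_mask | top_k_mask] = -float("inf")

    # Undo the sort in place so downstream code sees the original order.
    logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
```

Building `p` and `k` in the caller is what lets them be sliced and reordered together with the logits (e.g. restricting the masking to the non-greedy rows) before this in-place step.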

Yard1 (Collaborator, Author) commented Sep 12, 2023

@WoosukKwon Thanks for the feedback, let me see if I can reduce unnecessary changes.

As for the complexity, the current code is optimized for memory and speed, which sacrifices simplicity. I am not sure if it's possible to reduce the complexity without impacting performance. I am happy to add more comments and change the code structure for readability.

Yard1 (Collaborator, Author) commented Sep 22, 2023

Closed in favor of #1048

Yard1 closed this Sep 22, 2023
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Feb 20, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Mar 14, 2025