[Sampler] Vectorized sampler #820
Conversation
Updated benchmark to also apply top_p.
1.31x improvement.
I have also compared argmax outputs of the benchmark on the ShareGPT dataset, and they are identical between master and this PR (meaning that there is no impact on correctness).
Using the modified benchmark and Llama-2-7B on an A100-80GB:
1.47x improvement. As expected, the improvement is bigger the smaller the model (since relatively more time is spent in sampling).
Benchmarked Llama-70B with 4x A100-40GB, 512 context length, 128 generated tokens, 500 requests.
@Yard1, thanks for submitting the PR. Before I go deeper into the PR, could you please update the branch with the latest commit?
@@ -78,15 +73,19 @@ def run_vllm(
)

# Add the requests to the engine.
do_sample = False
A dumb question: What is this variable used for?
It allows us to alternate between greedy and stochastic sampling across requests; see the sketch below.
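For illustration, here is a minimal sketch of how such a toggle could drive per-request sampling parameters in the benchmark script. This is not the PR's exact code; the request tuples, model-agnostic parameter values, and the `requests` list are placeholders, and only the public `SamplingParams` API is assumed.

```python
from vllm import SamplingParams  # assumes the vLLM Python API

# Placeholder request list: (prompt, prompt_len, output_len) tuples,
# similar to what the dataset sampling in benchmark_throughput.py produces.
requests = [("Hello, my name is", 5, 16), ("The capital of France is", 6, 16)]

do_sample = False
sampling_params_per_request = []
for _prompt, _prompt_len, output_len in requests:
    sampling_params_per_request.append(
        SamplingParams(
            n=1,
            # temperature=0.0 means greedy decoding; 1.0 samples stochastically.
            temperature=1.0 if do_sample else 0.0,
            top_p=1.0,
            max_tokens=output_len,
        ))
    do_sample = not do_sample  # alternate greedy / stochastic per request
```

The alternation ensures that a single benchmark run exercises both the greedy and the stochastic code paths of the sampler.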
outputs = llm._run_engine(use_tqdm=True)
end = time.time()
with open("output.txt", "w") as f:
    for output in outputs:
        f.write(output.__repr__() + "\n")
I guess this is to check if vLLM generated valid outputs, right? If so, I believe this ought to be done by our test code, rather than in the benchmarking code.
Yeah, I will revert the changes to the benchmarking script later.
Also, consider adding a test?
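For reference, a rough sketch of what such a correctness test could look like, assuming the public `LLM`/`SamplingParams` API; the model, prompts, and the `load_reference_outputs` helper are hypothetical stand-ins rather than anything from this PR.

```python
from vllm import LLM, SamplingParams

def test_greedy_sampling_matches_reference():
    # Greedy decoding (temperature=0) is deterministic, so the vectorized
    # sampler can be compared token-for-token against outputs from master.
    llm = LLM(model="facebook/opt-125m")
    prompts = ["Hello, my name is", "The capital of France is"]
    greedy = SamplingParams(temperature=0.0, max_tokens=32)

    outputs = llm.generate(prompts, greedy)
    texts = [output.outputs[0].text for output in outputs]

    # Hypothetical helper: loads texts previously generated on master.
    reference_texts = load_reference_outputs()
    assert texts == reference_texts
```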
lgtm
@Yard1 I got an error when running:
@WoosukKwon Good catch! Just pushed a fix.
@Yard1 Thanks for the PR. It seems the performance improvement is quite significant! However, I'm a bit worried that this PR 1) includes unnecessary changes, and 2) complicates the sampler logic.
For 1), could you revert the unnecessary changes? I pointed out a few places in the code.
For 2), do you have any idea how to simplify the code? I think the complexity mainly comes from the many tensor/list indexing operations and sort-gathers. Can you somehow reduce their use?
p = torch.tensor(top_ps,
                 dtype=logits.dtype,
                 device=logits.device)
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
_apply_top_p_top_k_in_place(non_greedy_logits, p, k)
What is the benefit of creating p and k outside the function? It seems the change has no effect.
Right now it's necessary for the p's and k's to be indexed correctly, i.e. each row of p and k has to line up with the corresponding row of the logits.
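For context, here is an illustrative batched top-p/top-k masking function; it is not the PR's `_apply_top_p_top_k_in_place` (in particular, it returns a new tensor rather than masking in place), but it shows why row i of p and k must correspond to row i of the logits when a single per-batch sort is reused for both filters.

```python
import torch

def apply_top_p_top_k(logits: torch.Tensor,
                      p: torch.Tensor,
                      k: torch.Tensor) -> torch.Tensor:
    """Illustrative batched top-p/top-k masking (sketch, not the PR's code).

    logits: [num_seqs, vocab_size]; p, k: [num_seqs]. Row i of p/k must
    correspond to row i of logits, which is why the caller builds the
    tensors where that correspondence is known.
    """
    probs = logits.softmax(dim=-1)
    probs_sorted, idx_sorted = probs.sort(dim=-1, descending=True)

    # Top-p: drop a token once the cumulative probability *before* it
    # already exceeds p (this always keeps at least the top token).
    cum_before = probs_sorted.cumsum(dim=-1) - probs_sorted
    top_p_mask = cum_before > p.unsqueeze(-1)

    # Top-k: drop everything past the k-th sorted token in each row.
    ranks = torch.arange(logits.size(-1), device=logits.device)
    top_k_mask = ranks.unsqueeze(0) >= k.unsqueeze(-1)

    # Mask in sorted order, then scatter back to the original token order.
    logits_sorted = logits.gather(-1, idx_sorted)
    logits_sorted[top_p_mask | top_k_mask] = float("-inf")
    return torch.full_like(logits, float("-inf")).scatter(-1, idx_sorted, logits_sorted)
```

Because the same sorted indices are reused for both filters and for the scatter back, any mismatch between the row order of p/k and the row order of the logits would silently apply the wrong thresholds to the wrong sequences.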
@WoosukKwon Thanks for the feedback; let me see if I can reduce the unnecessary changes. As for the complexity, the current code is optimized for memory and speed, which sacrifices simplicity. I am not sure it's possible to reduce the complexity without impacting performance, but I am happy to add more comments and restructure the code for readability.
Closed in favor of #1048.
Ready for initial review.
Using the modified throughput benchmark with
python benchmarks/benchmark_throughput.py --backend vllm --dataset "./ShareGPT_V3_unfiltered_cleaned_split.json" --model meta-llama/Llama-2-13b-chat-hf --tokenizer meta-llama/Llama-2-13b-chat-hf --num-prompts=1000
and one A100-80GB, I get results amounting to a 1.28x improvement.
The argmax outputs between this PR and master match. Random sampling is different, but that is expected (the text still makes sense and doesn't differ much from master).