
[Sampler] Vectorized sampling (simplified) #1048


Merged: 27 commits merged into main on Sep 23, 2023

Conversation

zhuohan123 (Member) commented Sep 15, 2023

A simplified version of #820. Unlike the original, this version does not rely on any complicated torch operations.
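For readers skimming the thread, here is a rough sketch of what vectorized sampling means in practice (an illustration with assumed shapes and a hypothetical helper name `sample_batched`, not this PR's actual code): per-sequence sampling parameters are packed into tensors so temperature scaling, top-p filtering, and the final draw happen as a handful of batched torch ops instead of a Python loop over sequences.

```python
import torch

def sample_batched(logits: torch.Tensor, temperatures: torch.Tensor,
                   top_p: torch.Tensor) -> torch.Tensor:
    """Illustrative batched sampling over a whole batch of sequences.

    logits:       [num_seqs, vocab_size]
    temperatures: [num_seqs]
    top_p:        [num_seqs]
    """
    # Broadcast per-sequence temperatures in a single op.
    scaled = logits / temperatures.clamp_min(1e-5).unsqueeze(-1)
    probs = torch.softmax(scaled, dim=-1)

    # Batched top-p (nucleus) filtering: zero out the tail of each row.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    mask = (cumsum - sorted_probs) > top_p.unsqueeze(-1)
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # One multinomial call draws a token for every sequence at once.
    sampled = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, sampled).squeeze(-1)
```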

cc @Yard1

TODOs:

WoosukKwon (Collaborator)

@zhuohan123 Could you compare the performance against the original PR? Thanks.

zhuohan123 (Member, Author)

LLaMA-7B throughput on an 80GB A100 before this PR:

~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-15 06:02:38 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:03:23 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-15 06:03:23 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:03:29 llm_engine.py:201] # GPU blocks: 7448, # CPU blocks: 512
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [05:24<00:00,  6.16it/s]
Throughput: 6.16 requests/s, 2978.86 tokens/s

After this PR:

~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-15 05:54:21 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 05:55:04 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-15 05:55:04 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 05:55:10 llm_engine.py:201] # GPU blocks: 7448, # CPU blocks: 512
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [04:05<00:00,  8.15it/s]
Throughput: 8.15 requests/s, 3942.01 tokens/s

zhuohan123 (Member, Author) commented Sep 15, 2023

For the case in #820, this PR:

(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model meta-llama/Llama-2-13b-chat-hf --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-15 06:28:41 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:29:27 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-15 06:29:27 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:29:38 llm_engine.py:201] # GPU blocks: 3804, # CPU blocks: 327
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:13<00:00,  5.18it/s]
Throughput: 5.17 requests/s, 2474.69 tokens/s

Current main on my machine:

~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model meta-llama/Llama-2-13b-chat-hf --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-15 06:22:39 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
Downloading (…)okenizer_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 776/776 [00:00<00:00, 5.22MB/s]
Downloading tokenizer.model: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 500k/500k [00:00<00:00, 82.4MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1.84M/1.84M [00:00<00:00, 55.9MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 414/414 [00:00<00:00, 3.06MB/s]
Downloading (…)lve/main/config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 587/587 [00:00<00:00, 3.98MB/s]
INFO 09-15 06:23:22 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-15 06:23:22 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
Downloading (…)of-00003.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 6.18G/6.18G [00:13<00:00, 456MB/s]
Downloading (…)of-00003.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 9.90G/9.90G [00:22<00:00, 447MB/s]
Downloading (…)of-00003.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 9.95G/9.95G [00:22<00:00, 435MB/s]
INFO 09-15 06:23:55 llm_engine.py:201] # GPU blocks: 3804, # CPU blocks: 327
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:54<00:00,  4.27it/s]
Throughput: 4.27 requests/s, 2040.46 tokens/s

#820 (merged with main, with main's throughput benchmark script) on my machine:

(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model meta-llama/Llama-2-13b-chat-hf --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-15 06:49:42 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:50:27 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-15 06:50:27 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:50:36 llm_engine.py:201] # GPU blocks: 3804, # CPU blocks: 327
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:05<00:00,  5.40it/s]
Throughput: 5.39 requests/s, 2579.86 tokens/s

#820 with the benchmark script from #820:

(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model meta-llama/Llama-2-13b-chat-hf --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-15 06:33:56 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:34:38 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-15 06:34:38 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-15 06:34:47 llm_engine.py:201] # GPU blocks: 3804, # CPU blocks: 327
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:31<00:00,  4.73it/s]
Throughput: 4.73 requests/s, 2259.64 tokens/s

zhuohan123 (Member, Author)

@WoosukKwon This PR is ready for review.

WoosukKwon (Collaborator)

@zhuohan123 I've run python examples/llm_engine_example.py and found that the outputs do not look good.

Current main:

RequestOutput(request_id=2, prompt='What is the meaning of life?', prompt_token_ids=[2, 2264, 16, 5, 3099, 9, 301, 116], outputs=[CompletionOutput(index=1, text='\nThe story of the origin of life has been a mystery for centuries. It', token_ids=[50118, 133, 527, 9, 5, 9813, 9, 301, 34, 57, 10, 9001, 13, 11505, 4, 85], cumulative_logprob=-28.09819699637592, logprobs={}, finish_reason=length), CompletionOutput(index=0, text='\n\nLife can be a gift. It can provide you with a gift and', token_ids=[50118, 50118, 12116, 64, 28, 10, 4085, 4, 85, 64, 694, 47, 19, 10, 4085, 8], cumulative_logprob=-29.639200404286385, logprobs={}, finish_reason=length)], finished=True)
RequestOutput(request_id=3, prompt='It is only with the heart that one can see rightly', prompt_token_ids=[2, 243, 16, 129, 19, 5, 1144, 14, 65, 64, 192, 19765], outputs=[CompletionOutput(index=1, text=' what is happening in the world.\n\nIt is only with the heart that', token_ids=[99, 16, 2909, 11, 5, 232, 4, 50118, 50118, 243, 16, 129, 19, 5, 1144, 14], cumulative_logprob=-17.742405578494072, logprobs={}, finish_reason=length), CompletionOutput(index=2, text=' what is happening in the world.\n\nIt is only with the hearts that', token_ids=[99, 16, 2909, 11, 5, 232, 4, 50118, 50118, 243, 16, 129, 19, 5, 7754, 14], cumulative_logprob=-20.42438167333603, logprobs={}, finish_reason=length), CompletionOutput(index=0, text=' what is happening in the world.\n\nIt is only with the soul that', token_ids=[99, 16, 2909, 11, 5, 232, 4, 50118, 50118, 243, 16, 129, 19, 5, 7047, 14], cumulative_logprob=-20.607836961746216, logprobs={}, finish_reason=length)], finished=True)

This PR:

RequestOutput(request_id=2, prompt='What is the meaning of life?', prompt_token_ids=[2, 2264, 16, 5, 3099, 9, 301, 116], outputs=[CompletionOutput(index=6, text="I don't know. I don't know. I don't know. I", token_ids=[100, 218, 75, 216, 4, 38, 218, 75, 216, 4, 38, 218, 75, 216, 4, 38], cumulative_logprob=-75585.00610077381, logprobs={}, finish_reason=length), CompletionOutput(index=7, text="I don't know. I don't know. I don't know.\n", token_ids=[100, 218, 75, 216, 4, 38, 218, 75, 216, 4, 38, 218, 75, 216, 4, 50118], cumulative_logprob=-75586.06078827381, logprobs={}, finish_reason=length)], finished=True)
RequestOutput(request_id=3, prompt='It is only with the heart that one can see rightly', prompt_token_ids=[2, 243, 16, 129, 19, 5, 1144, 14, 65, 64, 192, 19765], outputs=[CompletionOutput(index=0, text=' rightly', token_ids=[2], cumulative_logprob=-3.3043150901794434, logprobs={}, finish_reason=stop), CompletionOutput(index=1, text='I am the one who is the one who is the one who is the one', token_ids=[100, 524, 5, 65, 54, 16, 5, 65, 54, 16, 5, 65, 54, 16, 5, 65], cumulative_logprob=-111459.22226893902, logprobs={}, finish_reason=length), CompletionOutput(index=2, text='I am the one who is the one who is the one who is the One', token_ids=[100, 524, 5, 65, 54, 16, 5, 65, 54, 16, 5, 65, 54, 16, 5, 509], cumulative_logprob=-111461.84726893902, logprobs={}, finish_reason=length)], finished=True)

Could you take a look at this problem?

WoosukKwon (Collaborator)

@zhuohan123 Please let me know if the PR is ready for review.

zhuohan123 (Member, Author) commented Sep 17, 2023

@WoosukKwon Just fixed the correctness issue. This also slightly boosts the performance:

~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model meta-llama/Llama-2-13b-chat-hf --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-17 03:11:30 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-17 03:12:13 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-17 03:12:13 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-17 03:12:22 llm_engine.py:201] # GPU blocks: 3804, # CPU blocks: 327
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:09<00:00,  5.27it/s]
Throughput: 5.27 requests/s, 2521.92 tokens/s

The PR is ready for review now. @WoosukKwon

WoosukKwon (Collaborator)

I've tried llm_engine_example.py again and got:

RequestOutput(request_id=2, prompt='What is the meaning of life?', ...
CompletionOutput(index=4, text='\n\nLife is a journey. It is a journey of life. It is', ...)
CompletionOutput(index=5, text='\n\nLife is a journey. It is a journey of life. It is', ...)

Is this just coincidence? The two outputs were different from each other when I used current main.

zhuohan123 (Member, Author)

> Is this just coincidence? The two outputs were different from each other when I used current main.

Thanks for catching this! Just fixed. The result should match main now.

zhuohan123 (Member, Author)

Further optimized performance. Now this PR is faster than #820:

~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model meta-llama/Llama-2-13b-chat-hf --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-17 06:05:34 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-17 06:06:14 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-13b-chat-hf', tokenizer='meta-llama/Llama-2-13b-chat-hf', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-17 06:06:14 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-17 06:06:23 llm_engine.py:201] # GPU blocks: 3804, # CPU blocks: 327
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:03<00:00,  5.45it/s]
Throughput: 5.45 requests/s, 2604.18 tokens/s

zhuohan123 (Member, Author)

> (Quoting the earlier benchmark: LLaMA-7B on an 80GB A100 ran at 6.16 requests/s before this PR and 8.15 requests/s with it.)

Same benchmark on the latest commit:

~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False)
INFO 09-17 06:19:00 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-17 06:19:42 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, trust_remote_code=False, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
INFO 09-17 06:19:42 tokenizer.py:30] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 09-17 06:19:47 llm_engine.py:201] # GPU blocks: 7448, # CPU blocks: 512
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00,  8.76it/s]
Throughput: 8.76 requests/s, 4236.86 tokens/s

esmeetu (Member) commented Sep 17, 2023

@zhuohan123 Sent. Yes, but I currently mostly use WizardCoder, and it just works without changing the kernel precision. CodeLlama should indeed apply that change to improve performance.

Yard1 (Collaborator) left a comment

Looks great! I would also advise carrying over the vectorized apply_penalties code from my PR, which is faster than the current numpy-based solution.
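For context on the penalty step being discussed, here is a rough sketch of what a vectorized `apply_penalties` can look like in PyTorch (an illustration with assumed tensor shapes, not the actual code from either PR): token counts for the whole batch are built with one `scatter_add_`, and presence/frequency penalties are then applied with broadcasted tensor ops rather than a per-sequence loop.

```python
import torch

def apply_penalties(logits: torch.Tensor, output_tokens: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor) -> torch.Tensor:
    """Illustrative batched presence/frequency penalties.

    logits:        [num_seqs, vocab_size]
    output_tokens: [num_seqs, max_output_len], padded with -1
    *_penalties:   [num_seqs]
    """
    # Count how often each token has been generated, with a single scatter_add_.
    counts = torch.zeros_like(logits)
    valid = (output_tokens >= 0)
    counts.scatter_add_(1, output_tokens.clamp_min(0), valid.to(logits.dtype))

    # Frequency penalty scales with the count; presence penalty is a flat hit.
    logits = logits - frequency_penalties.unsqueeze(-1) * counts
    logits = logits - presence_penalties.unsqueeze(-1) * (counts > 0).to(logits.dtype)
    return logits
```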

WoosukKwon (Collaborator) left a comment

@zhuohan123 Thanks for the awesome work! The code looks very clean to me. Left minor comments.

zhuohan123 (Member, Author)

@WoosukKwon @Yard1 Thanks for your detailed review! All review comments have been addressed. Please take a look.

WoosukKwon (Collaborator) left a comment

LGTM. Left minor comments.

Yard1 (Collaborator) left a comment

LGTM!

zhuohan123 merged commit 947b794 into main on Sep 23, 2023
zhuohan123 deleted the simpler-batched-sampling branch on September 23, 2023 at 00:48
masahi commented Oct 2, 2023

Hi, does this PR change the output of a model given the same input and random seed? Comparing the output of examples/offline_inference.py, I'm seeing:

Old

Prompt: 'Hello, my name is', Generated text: " Joel. I'm from Massachusetts and live in Melbourne, Australia.\nI'm"
Prompt: 'The president of the United States is', Generated text: ' about to be arrested in Europe for allegedly meddling in the 2016 election.\n\n'
Prompt: 'The capital of France is', Generated text: ' becoming a state of chaos with a significant urban and industrial boom. France’'
Prompt: 'The future of AI is', Generated text: ' not as simple as you think, and you have to understand it in order to'

New

Prompt: 'Hello, my name is', Generated text: ' Joel, my dad is my friend and we are in a relationship. I am'
Prompt: 'The president of the United States is', Generated text: ' speaking out against the release of some State Department documents which show the Russians were involved'
Prompt: 'The capital of France is', Generated text: ' known as the “Pale Oasis”, a German-run'
Prompt: 'The future of AI is', Generated text: " going to be a product of the aging society.\nThat's why I think"

WoosukKwon (Collaborator)

@masahi Yes, the PR has changed the outputs of random sampling, because now we use batched operations for sampling.
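A small self-contained illustration of the effect (assuming the default global torch generator and an identical seed in both runs): a per-sequence loop and a single batched call consume the random stream differently, so the sampled tokens generally differ even though both are valid draws from the same distributions.

```python
import torch

probs = torch.rand(4, 32000).softmax(dim=-1)  # 4 sequences over a toy vocab

torch.manual_seed(0)
per_seq = torch.stack([torch.multinomial(p, 1) for p in probs]).squeeze(-1)

torch.manual_seed(0)
batched = torch.multinomial(probs, 1).squeeze(-1)

# Same seed and same distributions, but the RNG stream is consumed in a
# different order, so the two schemes generally pick different tokens.
print(per_seq.tolist(), batched.tolist())
```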
