
[Performance] Enable chunked prefill and prefix caching together #7753


Merged (8 commits, Aug 28, 2024)

Conversation

@comaniac (Collaborator) commented Aug 21, 2024

Reference PRs: #6144, #6819
@sighingnow and @Juelianqvq are added as co-authors of this PR.

This PR allows prefix caching and chunked prefill to be enabled together. Unlike the reference PRs, this PR simplifies the logic for dealing with partial blocks (thanks to @rkooo567 for the suggestion). Here is the execution flow:

  1. In the scheduler, when determining the number of new tokens to schedule and both chunked prefill and prefix caching are enabled:
    1. If all uncomputed tokens can be scheduled (i.e., this is the last chunk of the prompt), schedule them all.
    2. Otherwise, always schedule a number of tokens that is divisible by the block size. For example, if the remaining budget is 133 tokens and the block size is 16, we only schedule (133//16)*16=128 tokens. Although this approach wastes some token budget, it makes the subsequent processing straightforward (see the sketch after this list).
  2. In prepare input, if all scheduled tokens are cached, we only compute the last block. Note that:
    1. We cannot skip all blocks at this point because the model runner doesn't support this case. Currently, when the block manager determines prefix-cached blocks, it also skips the last block for the same reason (e.g., https://github.com/vllm-project/vllm/blob/main/vllm/core/block/prefix_caching_block.py#L556). This can be improved in the future by moving prefix caching into the scheduler so that this case no longer arises.
    2. Since we guarantee the scheduled tokens are divisible by the block size, we don't need to consider partial blocks in prepare input.
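
For illustration, here is a minimal sketch of the rounding rule described above (a hypothetical helper, not the actual scheduler code):

```python
# Minimal sketch of the scheduling rule: round the per-step token budget down
# to a multiple of the block size unless the remaining prompt fits entirely.

def num_new_tokens_to_schedule(remaining_prompt_tokens: int,
                               token_budget: int,
                               block_size: int) -> int:
    if remaining_prompt_tokens <= token_budget:
        # Last chunk of the prompt: schedule everything that is left.
        return remaining_prompt_tokens
    # Otherwise schedule only a block-aligned number of tokens,
    # e.g. (133 // 16) * 16 = 128 for a 133-token budget and block size 16.
    return (token_budget // block_size) * block_size

assert num_new_tokens_to_schedule(500, 133, 16) == 128  # block-aligned chunk
assert num_new_tokens_to_schedule(100, 133, 16) == 100  # last chunk, schedule all
```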

A test case for functional correctness is also added.

Throughput benchmarking results:

  • Model: neuralmagic/Meta-Llama-3-8B-Instruct-FP8
  • GPU: 1xL4
  • Number of requests: 600
  • Average prompt length: 637 (shared prefix ~180, cache hit rate ~20%)
  • Max output length: 200
  • Block manager v1
  • Chunked prefill size 2048
| Branch | ChunkedPrefill | PrefixCaching | Elapsed Time (s) | Throughput (tok/s) |
|--------|----------------|---------------|------------------|--------------------|
| main   | x              | v             | 154.37           | 3631.2             |
| main   | v              | x             | 173.84           | 3215.1             |
| PR     | x              | v             | 155.88           | 3596.2             |
| PR     | v              | x             | 174.18           | 3298.8             |
| PR     | v              | v             | 142.81           | 3929.7             |
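
The exact benchmark script is not included in this thread. The following is a rough sketch of how such a shared-prefix throughput run could look with vLLM's offline API; the model name, request count, chunk size, and output length come from the bullets above, while the prompt construction, sampling settings, and timing code are assumptions.

```python
import time

from vllm import LLM, SamplingParams

llm = LLM(
    model="neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,  # "chunked prefill size 2048" above
)

# 600 prompts sharing a common prefix; the lengths only roughly approximate
# the ~180-token shared prefix / ~637-token average prompt described above.
shared_prefix = "You are a helpful assistant. Answer carefully and concisely. " * 15
body = "Please write a detailed, multi-paragraph summary of the topic below. " * 8
prompts = [f"{shared_prefix}[request {i}] {body}" for i in range(600)]

params = SamplingParams(temperature=0.0, max_tokens=200)

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

total_tokens = sum(
    len(o.prompt_token_ids) + len(o.outputs[0].token_ids) for o in outputs)
print(f"Elapsed: {elapsed:.2f} s, throughput: {total_tokens / elapsed:.1f} tok/s")
```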

cc @rkooo567


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@comaniac changed the title from Prefix cache chunked prefill to [Performance] Enable chunked prefill and prefix caching together on Aug 21, 2024
@rkooo567 (Collaborator)

result seems very good!!

@sighingnow (Contributor) commented Aug 21, 2024

Hi @comaniac @rkooo567, I would like to draw your attention to my last commit on #6144 (a043643).

Without it, this PR is still incorrect, and the error can be reproduced with even a single request:

  • request 1: length 120
  • chunked prefill enabled
  • prefix caching enabled
  • max_num_batched_tokens = 64, max_num_seqs = 64

You will find that with this PR, in the first round tokens[0:64] are prefilled, in the second round tokens[96:119] are prefilled, and the tokens between 64 and 96 are skipped.

This is because the computed block numbers are incorrectly updated to cover the whole block table of the prompt, rather than only the tokens that were prefilled in the first round.
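
To illustrate the described bug, here is a small sketch (a hypothetical helper, not the block manager v1 code) of which blocks should be marked as computed after a chunk is prefilled:

```python
# After a chunk is prefilled, only the blocks fully covered by the computed
# tokens should be marked as cached/computed.

def blocks_to_mark_computed(num_computed_tokens: int, block_size: int) -> list[int]:
    # Correct: block i is computed only if all of its tokens were prefilled.
    return list(range(num_computed_tokens // block_size))

# Request of length 120, block_size 16, first chunk prefills tokens [0:64):
print(blocks_to_mark_computed(64, 16))  # [0, 1, 2, 3] -- correct
# The buggy behaviour instead marked the whole prompt's block table,
# [0, 1, 2, 3, 4, 5, 6, 7], so the second chunk wrongly skipped tokens [64:96).
```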

@comaniac (Collaborator, Author)

IIUC, this PR already guarantees that every sequence has at least one block to compute even if it fully hits the cache, so it shouldn't trigger the issue you mentioned? If I missed anything, could you modify the unit test added in this PR so that the problem is exposed and tested?

@sighingnow (Contributor)

> IIUC, this PR already guarantees that every sequence has at least one block to compute even if it fully hits the cache, so it shouldn't trigger the issue you mentioned?

It is not about fully matched prefixes. In the case described above, there is only one request, the prefill is split into [0:64] and [64:120], and the second part is treated as prefix-matched because the computed_block_nums are updated to [0, 1, 2, 3, 4, 5, 6, 7] after the first chunked prefill.

@sighingnow (Contributor)

> IIUC, this PR already guarantees that every sequence has at least one block to compute even if it fully hits the cache, so it shouldn't trigger the issue you mentioned? If I missed anything, could you modify the unit test added in this PR so that the problem is exposed and tested?

The test case in this PR doesn't fail only because the max_num_batched_tokens (14) is smaller than the block size (16). Try a larger value like 64.

@comaniac (Collaborator, Author)

> IIUC, this PR already guarantees that every sequence has at least one block to compute even if it fully hits the cache, so it shouldn't trigger the issue you mentioned? If I missed anything, could you modify the unit test added in this PR so that the problem is exposed and tested?
>
> The test case in this PR doesn't fail only because the max_num_batched_tokens (14) is smaller than the block size (16). Try a larger value like 64.

The size 14 is used to test an invalid size; the actual size being tested in this case is 16. Meanwhile, I tried 16, 32, and 64, and none of them failed.

@sighingnow (Contributor) commented Aug 21, 2024

> IIUC, this PR already guarantees that every sequence has at least one block to compute even if it fully hits the cache, so it shouldn't trigger the issue you mentioned? If I missed anything, could you modify the unit test added in this PR so that the problem is exposed and tested?
>
> The test case in this PR doesn't fail only because the max_num_batched_tokens (14) is smaller than the block size (16). Try a larger value like 64.
>
> The size 14 is used to test an invalid size; the actual size being tested in this case is 16. Meanwhile, I tried 16, 32, and 64, and none of them failed.

With max_num_batched_tokens=64, you need a sequence length of at least 64 + 2 * block_size to reproduce the problem; 41 is not enough.

max_num_batched_tokens=16/32 cannot reproduce the issue either, as the second block is guaranteed to be recomputed in this PR.
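
For reference, a sketch of how this reproduction setup could be expressed with vLLM's offline API; the model and prompt text are placeholders, while the engine arguments follow the configuration described above:

```python
from vllm import LLM, SamplingParams

# Assumed small model as a stand-in; the thread does not name one.
llm = LLM(
    model="facebook/opt-125m",
    block_size=16,
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
    max_num_batched_tokens=64,
    max_num_seqs=64,
)

# Prompt longer than max_num_batched_tokens + 2 * block_size (= 96) tokens,
# so the second prefill chunk exercises the reported skip of tokens [64:96).
prompt = "word " * 120
out = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=16))
print(out[0].outputs[0].text)  # garbled output expected on the unfixed code
```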

@comaniac (Collaborator, Author)

Ok, I could reproduce the issue you pointed out. It actually only happens with block manager v1, since block manager v2 doesn't use this mechanism to mark computed blocks. This may also explain the suspiciously good speedup I got. I'll apply your fix in this PR and try to make the test cover this case.

@comaniac (Collaborator, Author)

@sighingnow I applied your commit with some modifications. The test is also changed so that it will fail without fixing the issue in block manager v1. PTAL.

@sighingnow (Contributor) commented Aug 22, 2024

> @sighingnow I applied your commit with some modifications. The test is also changed so that it will fail without fixing the issue in block manager v1. PTAL.

Thanks! LGTM.

@rkooo567 (Collaborator) left a comment

Looks good. One question: should we just make the scheduler handle prefix caching + chunked prefill correctly, and simplify the logic in model_runner?

raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got "
Collaborator:

Can you also print the chunk size and block size along with budget.token_budget % block_size?

@comaniac (Collaborator, Author):

It now looks like

ValueError: When enabling chunked prefill and prefix caching, max_num_batched_tokens (chunk size) must be dividable by block size, but got chunk_size (30) % block_size (16) = 14

@sighingnow (Contributor)

> @sighingnow I applied your commit with some modifications. The test is also changed so that it will fail without fixing the issue in block manager v1. PTAL.

Will the fix for the v2 block manager be addressed by this PR as well? The behavior of v2-block-manager looks quite strange, and I'm wondering if #7619 is related.

@comaniac (Collaborator, Author)

> @sighingnow I applied your commit with some modifications. The test is also changed so that it will fail without fixing the issue in block manager v1. PTAL.
>
> Will the fix for the v2 block manager be addressed by this PR as well? The behavior of v2-block-manager looks quite strange, and I'm wondering if #7619 is related.

I have a fix locally, but it will be in a separate PR.

@JaheimLee

Is this for the flash-attn backend only, or for all backends?

@comaniac (Collaborator, Author)

> Is this for the flash-attn backend only, or for all backends?

I've tested flash-attn and FlashInfer, so at least these two backends work. I need to test xformers later.

@Juelianqvq (Contributor)

> I've tested flash-attn and FlashInfer, so at least these two backends work. I need to test xformers later.

@comaniac https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/flashinfer.py#L360 Is it really supported here?

@comaniac (Collaborator, Author)

> I've tested flash-attn and FlashInfer, so at least these two backends work. I need to test xformers later.
>
> @comaniac https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/flashinfer.py#L360 Is it really supported here?

Yeah, I noticed that too, so I'm not entirely sure what's going on. I'll find some time tomorrow to look into it.

@comaniac (Collaborator, Author)

Updates:

  1. More tests are added.
  2. Chunked prefill only supports the flash attention backend for now. My local test passed because it didn't schedule prefill and decode in the same batch. However, there shouldn't be a blocker for FlashInfer to support chunked prefill, so we should add that support in a follow-up PR.

@sighingnow (Contributor)

> Updates:
>
> 1. More tests are added.
> 2. Chunked prefill only supports the flash attention backend for now. My local test passed because it didn't schedule prefill and decode in the same batch. However, there shouldn't be a blocker for FlashInfer to support chunked prefill, so we should add that support in a follow-up PR.

May I ask why you chose to recompute the whole block when it is fully matched? Recomputing only the last token is enough, requires no changes in the scheduler, and would be a bit more efficient.
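
A small sketch contrasting the two options being discussed (a hypothetical helper, not vLLM internals):

```python
# Given a fully cache-hit prefill chunk, how many tokens must be recomputed
# so the model runner still has something to execute?

def tokens_to_recompute(num_scheduled_tokens: int, num_cached_tokens: int,
                        block_size: int, recompute_last_block: bool) -> int:
    if num_cached_tokens < num_scheduled_tokens:
        # Partial hit: only the uncached tail needs computing.
        return num_scheduled_tokens - num_cached_tokens
    # Full hit: either recompute the whole last block (the PR's original
    # approach) or just the last token (the cheaper approach adopted below).
    return block_size if recompute_last_block else 1

print(tokens_to_recompute(112, 112, 16, recompute_last_block=True))   # 16
print(tokens_to_recompute(112, 112, 16, recompute_last_block=False))  # 1
```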

@comaniac (Collaborator, Author)

You're right, it would be a bit more efficient to compute only the last token. Meanwhile, I found that it might not be that hard to deal with prefix matching in the scheduler so that this case never happens in the model runner. I'll give it a try.

@comaniac force-pushed the prefix-cache-chunked-prefill branch from b305e0d to 324fcec on August 26, 2024 19:59
@comaniac (Collaborator, Author) commented Aug 26, 2024

@sighingnow I changed it to recompute only the last token. PTAL.

@rkooo567 I've tried moving prefix caching into the scheduler, and it's actually easy for the default scheduler. For chunked prefill, we would have to refactor the scheduler (e.g., .schedule(), ._schedule_prefill(), .get_new_tokens()) and the block manager (e.g., .can_allocate()). Since we have to be careful with this refactor and it can be decoupled from this PR, I'll put it in a follow-up PR tracked by #7883.

@comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Aug 26, 2024
@rkooo567 (Collaborator) left a comment

Generally looks good. I'd actually also like to add a warning when the block size is big and prefix caching + CP are enabled (because it can waste a lot of tokens). Maybe if block_size > 32, we can print a warning?
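
For reference, a minimal sketch of what such a warning could look like; the threshold and wording here are assumptions, and the actual warning was deferred to a follow-up PR:

```python
import logging

logger = logging.getLogger("vllm")

def maybe_warn_large_block_size(block_size: int,
                                enable_prefix_caching: bool,
                                enable_chunked_prefill: bool) -> None:
    # With both features on, each prefill chunk is rounded down to a multiple
    # of the block size, so a large block size can waste much of the budget.
    if enable_prefix_caching and enable_chunked_prefill and block_size > 32:
        logger.warning(
            "Chunked prefill with prefix caching rounds each chunk down to a "
            "multiple of the block size; with block_size=%d this can waste a "
            "large part of the token budget per step.", block_size)
```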

@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():

# assert all blocks are free now
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks


def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
Collaborator:

Do we have a corresponding test for v2?

@comaniac (Collaborator, Author):

We don't need to test v2 because v2 automatically marks touched blocks as computed.

# to avoid partial block matching.
block_size = self.cache_config.block_size
reminder = budget.token_budget % block_size
if reminder != 0:
Collaborator:

Btw, should we raise this exception at engine start time instead and just add an assert here?

@comaniac (Collaborator, Author):

I feel we could just raise here for now, because this constraint should be removable once we refactor the scheduler to consider prefix caching.

@comaniac (Collaborator, Author)

> Generally looks good. I'd actually also like to add a warning when the block size is big and prefix caching + CP are enabled (because it can waste a lot of tokens). Maybe if block_size > 32, we can print a warning?

Sure, I'll add the warning in a follow-up PR.

@comaniac merged commit e358053 into vllm-project:main on Aug 28, 2024
54 checks passed
@comaniac deleted the prefix-cache-chunked-prefill branch on August 28, 2024 07:36
@Juelianqvq (Contributor)

Since this PR has been merged, both #6144 and #6819 can be closed. Are you willing to add me and @sighingnow as co-authors? @comaniac

@comaniac (Collaborator, Author)

Ah, I intended to do that. I actually listed you two as co-authors in one commit of this PR and thought it would take effect when the PR was merged, but somehow it didn't... let me figure out how to fix that. Also cc @simon-mo

@sighingnow (Contributor) commented Aug 29, 2024

To whom it may concern: after this PR there are still occasional crashes when prefix caching and chunked prefill are enabled at the same time on NVIDIA GPUs (inside the flash_attn_varlen_func function in the prefix-enabled attention branch). I investigated the kernel inputs and found nothing wrong, and I cannot reproduce the crash when running the kernel standalone with the pickled saved inputs. I think there are still overflow bugs inside vllm-flash-attention; setting the block_size to 256 fixed the issue and the crash disappeared under high pressure.

comaniac added a commit to comaniac/vllm that referenced this pull request Sep 3, 2024
Co-authored-by: Tao He <[email protected]>
Co-authored-by: Juelianqvq <[email protected]>
@ashgold commented Sep 3, 2024

> To whom it may concern: after this PR there are still occasional crashes when prefix caching and chunked prefill are enabled at the same time on NVIDIA GPUs (inside the flash_attn_varlen_func function in the prefix-enabled attention branch). I investigated the kernel inputs and found nothing wrong, and I cannot reproduce the crash when running the kernel standalone with the pickled saved inputs. I think there are still overflow bugs inside vllm-flash-attention; setting the block_size to 256 fixed the issue and the crash disappeared under high pressure.

This looks like a serious bug that needs to be fixed before this can go to production. Thanks for sharing the workaround as well.

@hmellor (Member) commented Sep 10, 2024

If you are using a model with max_model_len > 32K (e.g., Llama 3.1), then chunked prefill is enabled by default. However, this PR leaves the `and not self.enable_prefix_caching` condition in this automatic enabling of chunked prefill.

This means that a user relying on the automatic enabling of chunked prefill might not notice it being disabled when they enable prefix caching.

```python
if self.enable_chunked_prefill is None:
    # If not explicitly set, enable chunked prefill by default for
    # long context (> 32K) models. This is to avoid OOM errors in the
    # initial memory profiling phase.
    if use_long_context:
        is_gpu = device_config.device_type == "cuda"
        use_sliding_window = (model_config.get_sliding_window()
                              is not None)
        use_spec_decode = self.speculative_model is not None
        has_seqlen_agnostic_layers = (
            model_config.contains_seqlen_agnostic_layers(
                parallel_config))
        if (is_gpu and not use_sliding_window and not use_spec_decode
                and not self.enable_lora
                and not self.enable_prompt_adapter
                and not self.enable_prefix_caching
                and not has_seqlen_agnostic_layers):
            self.enable_chunked_prefill = True
            logger.warning(
                "Chunked prefill is enabled by default for models with "
                "max_model_len > 32K. Currently, chunked prefill might "
                "not work with some features or models. If you "
                "encounter any issues, please disable chunked prefill "
                "by setting --enable-chunked-prefill=False.")
    if self.enable_chunked_prefill is None:
        self.enable_chunked_prefill = False
```

cc @comaniac

@comaniac (Collaborator, Author)

Good point. I'll file another PR to fix it.

Labels
ready ONLY add when PR is ready to merge/full CI is needed

7 participants