
[BugFix] Fix Llama4 - Index Error When Single Request Near Max Context #16209


Merged
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
     np.arange(pages_per_local_batch, dtype=np.int32),
     (virtual_batches, pages_per_local_batch)) \
         + np.expand_dims(block_starts, axis=1)
-block_indices = block_indices.flatten()
+block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
Collaborator:

why does block_indices contain OOB items without the clip?

Collaborator Author:

It happens when np.arange(pages_per_local_batch, dtype=np.int32) runs off the end of the block table, i.e. when max_model_len is not a multiple of attention_chunk_size. In that case we need to clip to simulate a partial attention chunk at the end of the context.
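
For illustration, a minimal numpy sketch of the failure mode; the sizes below (block_size, attn_chunk_size, max_model_len) are made up for the example, not taken from the PR:

```python
import numpy as np

# Hypothetical sizes chosen for illustration only:
block_size = 2            # tokens per KV-cache page
attn_chunk_size = 8       # local attention chunk, in tokens
max_model_len = 10        # NOT a multiple of attn_chunk_size

pages_per_local_batch = attn_chunk_size // block_size    # 4 pages per chunk
num_pages = -(-max_model_len // block_size)              # ceil(10 / 2) = 5

# One local batch starts every attn_chunk_size tokens; the last is partial.
block_starts = np.arange(0, num_pages, pages_per_local_batch)   # [0, 4]

block_indices = np.broadcast_to(
    np.arange(pages_per_local_batch, dtype=np.int32),
    (len(block_starts), pages_per_local_batch)) \
    + np.expand_dims(block_starts, axis=1)
print(block_indices.flatten())
# [0 1 2 3 4 5 6 7]  -> 5, 6, 7 run past the 5-column block table (valid 0..4)

clipped = block_indices.flatten().clip(max=num_pages - 1)
print(clipped)
# [0 1 2 3 4 4 4 4]  -> in bounds; the repeated last page stands in for the
# partial attention chunk at the end of the context
```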

Member:

I'm not sure what the correct number of elements is here, but np.minimum seems to be a little faster:

| Size (elements) | np.clip (µs) | np.minimum (µs) | Speedup |
|----------------:|-------------:|----------------:|--------:|
| 100 | 2.629 | 1.775 | 1.48x |
| 500 | 3.065 | 2.270 | 1.35x |
| 1,000 | 3.647 | 2.900 | 1.26x |
| 2,000 | 4.469 | 3.637 | 1.23x |
| 4,000 | 5.873 | 5.121 | 1.15x |
| 6,000 | 7.392 | 6.604 | 1.12x |
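
For reference, a rough sketch of how such a micro-benchmark might be run; the original setup isn't shown, so the array contents and timing methodology here are assumptions:

```python
import timeit
import numpy as np

# Compare np.clip vs np.minimum on int32 index arrays of varying size.
# Sizes mirror the table above; absolute timings will differ by machine.
for n in (100, 500, 1_000, 2_000, 4_000, 6_000):
    idx = np.random.randint(0, 2 * n, size=n, dtype=np.int32)
    reps = 10_000
    t_clip = timeit.timeit(lambda: idx.clip(max=n - 1), number=reps)
    t_min = timeit.timeit(lambda: np.minimum(idx, n - 1), number=reps)
    # total seconds / reps -> seconds per call; * 1e6 -> microseconds
    print(f"{n:>6}: clip {t_clip / reps * 1e6:.3f} µs, "
          f"minimum {t_min / reps * 1e6:.3f} µs")
```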

Collaborator Author:

Cool, ya, I can open a separate PR to use that. I'm a bit hesitant to update this PR, since this one has been vetted by @LagPixelLOL (I'm not set up to repro it locally), and it would be nice to get something in.

     batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                               local_blocks * pages_per_local_batch)
     block_table_local = block_table[batch_indices, block_indices]\
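
For context, the gather on the last line above relies on numpy fancy indexing: paired index arrays select one page ID per (virtual batch, page) slot. A minimal standalone sketch with made-up page IDs:

```python
import numpy as np

# Toy block table: 2 requests x 4 pages; entries are arbitrary page IDs.
block_table = np.array([[10, 11, 12, 13],
                        [20, 21, 22, 23]], dtype=np.int32)

# Element i of the gather is block_table[batch_indices[i], block_indices[i]].
batch_indices = np.repeat(np.arange(2, dtype=np.int32), 2)  # [0 0 1 1]
block_indices = np.array([0, 1, 0, 1], dtype=np.int32)

block_table_local = block_table[batch_indices, block_indices]
print(block_table_local)  # [10 11 20 21]
```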