Enable interleaved sliding window attention models for Transformers backend #18494

Merged

2 changes: 1 addition & 1 deletion docs/source/contributing/model/basic.md
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m

To support a model with interleaving sliding windows, we need to take care of the following details:

- Make sure [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/config.py#L308) evaluates `has_interleaved_attention` to `True` for this model, and set `self.hf_text_config.interleaved_sliding_window` to the format of interleaving sliding windows the model can understand. Then, `self.hf_text_config.sliding_window` will be deleted, and the model will be treated as a full-attention model.
- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).

With these two steps, interleaved sliding windows should work with the model; a rough sketch of the per-layer parsing is shown below.
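A rough sketch of the per-layer parsing described in the modeling-code bullet (illustrative only, not part of this PR; `layer_idx`, `self.scaling`, and the surrounding constructor arguments are assumed names — the linked Llama code is the authoritative reference):

```python
from vllm.attention import Attention

# Illustrative sketch: choose the sliding window for one decoder layer.
# Every `sliding_window_pattern`-th layer uses full attention; the rest
# use the interleaved sliding window.
sliding_window = None
if (layer_idx + 1) % config.sliding_window_pattern != 0:
    sliding_window = config.interleaved_sliding_window

self.attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=self.scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    per_layer_sliding_window=sliding_window,
    prefix=f"{prefix}.attn",
)
```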
56 changes: 42 additions & 14 deletions tests/models/test_transformers.py
@@ -1,37 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
"""Test the functionality of the Transformers backend."""
from typing import Any, Optional, Union

import pytest

from vllm.platforms import current_platform

from ..conftest import HfRunner, VllmRunner
from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
from ..utils import multi_gpu_test
from .utils import check_logprobs_close


def check_implementation(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
runner_ref: type[Union[HfRunner, VllmRunner]],
runner_test: type[VllmRunner],
example_prompts: list[str],
model: str,
kwargs_ref: Optional[dict[str, Any]] = None,
kwargs_test: Optional[dict[str, Any]] = None,
**kwargs,
):
if kwargs_ref is None:
kwargs_ref = {}
if kwargs_test is None:
kwargs_test = {}

max_tokens = 32
num_logprobs = 5

with vllm_runner(model, **kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
args = (example_prompts, max_tokens, num_logprobs)

with runner_test(model, **kwargs_test, **kwargs) as model_test:
outputs_test = model_test.generate_greedy_logprobs(*args)

with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with runner_ref(model, **kwargs_ref) as model_ref:
if isinstance(model_ref, VllmRunner):
outputs_ref = model_ref.generate_greedy_logprobs(*args)
else:
outputs_ref = model_ref.generate_greedy_logprobs_limit(*args)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
outputs_0_lst=outputs_ref,
outputs_1_lst=outputs_test,
name_0="ref",
name_1="test",
)


@@ -58,15 +71,30 @@ def test_models(
model_impl=model_impl)


def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
prompts, _, _ = prep_prompts(4, (800, 801))
kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
kwargs_test = {"model_impl": "transformers", **kwargs_ref}
check_implementation(vllm_runner,
vllm_runner,
prompts,
model="hmellor/tiny-random-Gemma2ForCausalLM",
kwargs_ref=kwargs_ref,
kwargs_test=kwargs_test)


@multi_gpu_test(num_gpus=2)
def test_distributed(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
example_prompts,
):
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
check_implementation(hf_runner, vllm_runner, example_prompts,
"meta-llama/Llama-3.2-1B-Instruct", **kwargs)
check_implementation(hf_runner,
vllm_runner,
example_prompts,
"meta-llama/Llama-3.2-1B-Instruct",
kwargs_test=kwargs)


@pytest.mark.skipif(
19 changes: 11 additions & 8 deletions vllm/config.py
@@ -536,13 +536,17 @@ def __post_init__(self) -> None:
self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config. TODO: remove this
# when Gemma 2 is fixed in Transformers.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2

sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
Member:

Oh, just remembered that this line is for Mistral models. cc @patrickvonplaten do any of your models still use sliding_window as a list?

Member Author:

Using a script to download the config of every model in mistralai would suggest no:

| Repository | Sliding Window | Sliding Window Pattern |
|---|---|---|
| mistralai/Mistral-7B-v0.1 | 4096 | None |
| mistralai/Mistral-7B-Instruct-v0.1 | 4096 | None |
| mistralai/Ministral-8B-Instruct-2410 | 32768 | None |
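
For reference, a minimal sketch of what such a config check could look like (the exact script isn't shown in this thread; the `huggingface_hub`/`transformers` calls below are assumptions):

```python
# Sketch: print sliding-window settings for every model under the
# mistralai organization on the Hugging Face Hub.
from huggingface_hub import list_models
from transformers import AutoConfig

for model in list_models(author="mistralai"):
    try:
        config = AutoConfig.from_pretrained(model.id)
    except Exception:
        continue  # skip gated repos or repos without a loadable config
    print(
        model.id,
        getattr(config, "sliding_window", None),
        getattr(config, "sliding_window_pattern", None),
    )
```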

Member Author:

Plus, according to the docstrings provided in Mistral and Mixtral, setting sliding_window as list[int] is not supported anyway.

Member:

Alright, then it should be fine to merge. Thanks for looking into this!

Member:

Oh, just realized that the sliding_window is actually set inside params.json... opening a PR to handle the case where it's a list

Member Author:

Wouldn't that have been identified from my script? It checked the instantiated config classes rather than examining config.json directly

(self.hf_text_config.model_type in interleaved_attn_models))
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)

if (not self.disable_sliding_window and has_interleaved_attention):
if not (self.disable_sliding_window or sliding_window_pattern is None):
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
@@ -1040,8 +1044,7 @@ def verify_with_parallel_config(
if self.use_async_output_proc:
self.use_async_output_proc = False

def get_hf_config_sliding_window(
self) -> Union[Optional[int], list[Optional[int]]]:
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""

# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1052,7 +1055,7 @@ def get_hf_config_sliding_window(
return None
return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
57 changes: 51 additions & 6 deletions vllm/model_executor/models/transformers.py
@@ -16,6 +16,7 @@
"""Wrapper around `transformers` models"""
import re
from collections.abc import Iterable
from contextlib import nullcontext
from typing import Literal, Optional, Union

import torch
@@ -110,6 +111,33 @@ def replace_linear_class(
)


class ConfigOverride:
"""Context manager to temporarily override config attributes."""

def __init__(self, config: PretrainedConfig, **kwargs):
self.config = config
self.kwargs = kwargs
self.kwargs_original = {}
self.kwargs_delete = set()

def __enter__(self):
"""Override config attributes."""
for key, value in self.kwargs.items():
if not hasattr(self.config, key):
self.kwargs_delete.add(key)
self.kwargs_original[key] = getattr(self.config, key, None)
setattr(self.config, key, value)
return self.config

def __exit__(self, exc_type, exc_value, traceback):
"""Restore original config attributes."""
for key, value in self.kwargs_original.items():
if key in self.kwargs_delete:
delattr(self.config, key)
else:
setattr(self.config, key, value)
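
# Illustrative usage of ConfigOverride (a sketch, not part of this diff):
# if the attribute was absent before entering, it is deleted again on exit,
# so the override leaves no trace on the config object.
from transformers import PretrainedConfig

config = PretrainedConfig()
assert not hasattr(config, "sliding_window")
with ConfigOverride(config, sliding_window=4096):
    assert config.sliding_window == 4096
assert not hasattr(config, "sliding_window")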


class TransformersModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -135,8 +163,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()

# vLLM handles interleaved sliding window attention by creating a new
# interleaved_sliding_window attribute and deleting the sliding_window
# attribute. This breaks the constructors in Transformers so we
# temporarily add the attribute back to construct the model.
config_override = nullcontext()
if hasattr(config, "interleaved_sliding_window"):
config_override = ConfigOverride(
config, sliding_window=config.interleaved_sliding_window)

# Use meta device to delay allocating GPU tensors
with torch.device("meta"):
with torch.device("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
@@ -262,9 +299,17 @@ def create_attention_instances(self) -> dict[int, Attention]:
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
return {
i:
Attention(

attention_instances = {}
for i in range(start, end):
# Handle interleaved sliding window attention
sliding_window = None
if (hasattr(self.config, "interleaved_sliding_window")
and hasattr(self.config, "sliding_window_pattern")
and ((i + 1) % self.config.sliding_window_pattern > 0)):
sliding_window = self.config.interleaved_sliding_window

attention_instances[i] = Attention(
num_heads=num_heads,
head_size=head_size,
# NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ def create_attention_instances(self) -> dict[int, Attention]:
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{i}.attn")
for i in range(start, end)
}
return attention_instances

def init_buffers(self, module: nn.Module):
"""