[Model] Enable Inference Support for the New Baichuan-M1 Model #12251

Status: Open · wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -96,6 +96,11 @@ See [this page](#generative-models) for more information on how to use generative
- `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.
- ✅︎
- ✅︎
* - `BaichuanM1ForCausalLM`
- Baichuan-M1
- `baichuan-inc/Baichuan-M1-14B-Instruct`, `baichuan-inc/Baichuan-M1-14B-Base`, etc.
- ✅︎
- ✅︎
* - `BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- `bigscience/bloom`, `bigscience/bloomz`, etc.
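For reference, once this entry lands the model should be usable through vLLM's standard offline API. A minimal sketch (the prompt and sampling values are illustrative, not from this PR):

```python
from vllm import LLM, SamplingParams

# Checkpoint name taken from the table entry above; Baichuan models ship
# custom modeling code, hence trust_remote_code=True.
llm = LLM(model="baichuan-inc/Baichuan-M1-14B-Instruct",
          trust_remote_code=True)
outputs = llm.generate(["How are bacterial infections usually treated?"],
                       SamplingParams(temperature=0.7, max_tokens=128))
print(outputs[0].outputs[0].text)
```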
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -100,6 +100,8 @@ def check_available_online(
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True),
"BaichuanM1ForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-M1-14B-Instruct", # noqa: E501
trust_remote_code=True),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
48 changes: 45 additions & 3 deletions vllm/config.py
@@ -305,9 +305,12 @@ def __init__(
self.enforce_eager = False

sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_layers = getattr(self.hf_text_config,
"sliding_window_layers", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2", "cohere2"]))
(self.hf_text_config.model_type in ["gemma2", "cohere2"])
or sliding_window_layers is not None)

if (not self.disable_sliding_window and has_interleaved_attention):
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
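The new `sliding_window_layers` branch above is what flags Baichuan-M1 as an interleaved-attention model. A standalone sketch of the check, with assumed config values rather than Baichuan-M1's real ones:

```python
from types import SimpleNamespace

def has_interleaved_attention(cfg) -> bool:
    # Mirrors the condition added in __init__: any per-layer sliding-window
    # list marks the model as interleaved, even with a scalar sliding_window.
    sliding_window = getattr(cfg, "sliding_window", None)
    sliding_window_layers = getattr(cfg, "sliding_window_layers", None)
    return (sliding_window is not None) and (
        isinstance(sliding_window, list)
        or cfg.model_type in ["gemma2", "cohere2"]
        or sliding_window_layers is not None)

# Assumed values for illustration only.
cfg = SimpleNamespace(model_type="baichuan_m1", sliding_window=2048,
                      sliding_window_layers=[1, 3, 5])
print(has_interleaved_attention(cfg))  # True
```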
@@ -713,6 +716,9 @@ def get_hf_config_sliding_window(
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
if hasattr(self.hf_text_config, 'sliding_window_layers'):
return None

return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
@@ -724,6 +730,10 @@ def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()

def get_sliding_window_layers(self,
parallel_config) -> Optional[List[int]]:
return getattr(self.hf_text_config, "sliding_window_layers", [])
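Taken together, the two accessors above mean that a config exposing `sliding_window_layers` reports no model-wide sliding window and hands out the per-layer list instead. A small sketch with assumed values:

```python
from types import SimpleNamespace

def get_hf_config_sliding_window(cfg):
    # Mirrors the accessor above: the presence of a per-layer list disables
    # the model-wide sliding_window value.
    if hasattr(cfg, "use_sliding_window") and not cfg.use_sliding_window:
        return None
    if hasattr(cfg, "sliding_window_layers"):
        return None
    return getattr(cfg, "sliding_window", None)

# Assumed, Baichuan-M1-style config values for illustration only.
cfg = SimpleNamespace(sliding_window=2048, sliding_window_layers=[1, 3, 5])
print(get_hf_config_sliding_window(cfg))          # None
print(getattr(cfg, "sliding_window_layers", []))  # [1, 3, 5]
```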

def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size

@@ -751,6 +761,12 @@ def get_head_size(self) -> int:
return (self.hf_text_config.hidden_size //
self.hf_text_config.num_attention_heads)

def get_head_size_swa(self) -> int:
if hasattr(self.hf_text_config, "num_swa_attention_heads"):
return (self.hf_text_config.hidden_size //
self.hf_text_config.num_swa_attention_heads)
return self.get_head_size()
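`get_head_size_swa()` lets the sliding-window layers use a different head count, and therefore a different head size, than the global-attention layers. A toy calculation with assumed numbers, not Baichuan-M1's actual config:

```python
# All values below are assumptions chosen to make the arithmetic obvious.
hidden_size = 4096
num_attention_heads = 32      # global-attention layers
num_swa_attention_heads = 16  # sliding-window layers

head_size = hidden_size // num_attention_heads          # 128
head_size_swa = hidden_size // num_swa_attention_heads  # 256
print(head_size, head_size_swa)
```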

def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
# For GPTBigCode & Falcon:
@@ -797,6 +813,22 @@ def get_total_num_kv_heads(self) -> int:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads

def get_total_num_kv_heads_swa(self) -> int:
if hasattr(self.hf_text_config, "num_swa_key_value_heads"):
return self.hf_text_config.num_swa_key_value_heads
return self.get_total_num_kv_heads()

def get_num_swa_key_value_heads(self,
parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads_swa = self.get_total_num_kv_heads_swa()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(
1, total_num_kv_heads_swa // parallel_config.tensor_parallel_size)
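`get_num_swa_key_value_heads()` applies the same tensor-parallel split as the regular KV-head accessor: divide by the TP size, but never go below one head per GPU (heads are replicated instead). A minimal sketch:

```python
def num_swa_kv_heads_per_gpu(total_num_kv_heads_swa: int,
                             tensor_parallel_size: int) -> int:
    # Divide across tensor-parallel ranks, replicating when there are fewer
    # SWA KV heads than ranks so every GPU keeps at least one head.
    return max(1, total_num_kv_heads_swa // tensor_parallel_size)

# Assumed head counts for illustration only.
print(num_swa_kv_heads_per_gpu(8, 4))  # 2
print(num_swa_kv_heads_per_gpu(2, 8))  # 1 -> replicated across GPUs
```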

def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
@@ -839,7 +871,18 @@ def get_num_layers_by_block_type(

if is_transformer:
# Handle the basic case first
return end - start if attn_block_type else 0
swa_layers = self.get_sliding_window_layers(parallel_config)
num_layers = 0
if not swa_layers:
num_layers = end - start if attn_block_type else 0
else:
for layer_id in range(start, end):
if (block_type == LayerBlockType.attention
and layer_id not in swa_layers) or (
block_type == LayerBlockType.swa
and layer_id in swa_layers):
num_layers += 1
return num_layers
elif self.is_attention_free:
# Attention free
# Note that this code assumes there
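The new branch in `get_num_layers_by_block_type()` counts a layer as `swa` when its index appears in `sliding_window_layers` and as regular attention otherwise. A standalone sketch with an assumed 8-layer layout:

```python
def count_layers(start: int, end: int, swa_layers: list[int],
                 block_type: str) -> int:
    # Counts layers in [start, end) by type: "attention" for global-attention
    # layers, "swa" for layers listed in sliding_window_layers.
    count = 0
    for layer_id in range(start, end):
        is_swa = layer_id in swa_layers
        if (block_type == "attention" and not is_swa) or (
                block_type == "swa" and is_swa):
            count += 1
    return count

# Hypothetical model: 8 layers, even-indexed layers use sliding-window attention.
swa_layers = [0, 2, 4, 6]
print(count_layers(0, 8, swa_layers, "attention"))  # 4
print(count_layers(0, 8, swa_layers, "swa"))        # 4
```

(vLLM uses a `LayerBlockType` enum for the block types, as in the hunk above; plain strings stand in for it here.)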
@@ -2360,7 +2403,6 @@ def _get_and_verify_max_len(
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)

# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:
13 changes: 10 additions & 3 deletions vllm/engine/llm_engine.py
@@ -341,6 +341,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None

# Used to evict finished requests from the Mamba cache and the
# Baichuan-M1 cache; we keep finished_req_ids here so they are not
# lost when the scheduler is empty.
self.finished_requests_ids: List[str] = list()

# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
@@ -1323,6 +1328,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:

finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
self.finished_requests_ids.extend(finished_requests_ids)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
@@ -1335,8 +1341,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
else:
finished_requests_ids = list()

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
@@ -1357,11 +1361,14 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
finished_requests_ids=self.finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)

# Clear finished_requests_ids list.
self.finished_requests_ids = list()

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
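The engine-side change boils down to an accumulate-and-drain buffer: finished request ids from every scheduler poll are appended to `self.finished_requests_ids`, passed along with the next `ExecuteModelRequest`, and only then cleared, so ids reported while cached scheduler outputs are in use are no longer dropped. A simplified sketch of the pattern, not the actual `LLMEngine` code:

```python
class FinishedIdsBuffer:
    """Accumulates finished request ids until they are handed to the executor."""

    def __init__(self) -> None:
        self.finished_requests_ids: list[str] = []

    def collect(self, ids_from_scheduler: list[str]) -> None:
        # Called once per step with whatever the scheduler reports.
        self.finished_requests_ids.extend(ids_from_scheduler)

    def drain(self) -> list[str]:
        # Called when building the ExecuteModelRequest; clears the buffer.
        ids, self.finished_requests_ids = self.finished_requests_ids, []
        return ids


buf = FinishedIdsBuffer()
buf.collect(["req-1", "req-2"])  # step N: two requests finished
buf.collect([])                  # step N+1: cached outputs, nothing new
print(buf.drain())               # ['req-1', 'req-2'] go to the model executor
print(buf.drain())               # [] -- already cleared
```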