
Commit 8939174

Author: dangshunya (committed)
[New Model] support Baichuan-M1
Signed-off-by: dangshunya <[email protected]>
Parent: 6dd94db

8 files changed: +882 −9 lines

docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
@@ -96,6 +96,11 @@ See [this page](#generative-models) for more information on how to use generativ
   - `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.
   - ✅︎
   - ✅︎
+* - `BaichuanM1ForCausalLM`
+  - Baichuan-M1
+  - `baichuan-inc/Baichuan-M1-14B-Instruct`, `baichuan-inc/Baichuan-M1-14B-Base`, etc.
+  - ✅︎
+  - ✅︎
 * - `BloomForCausalLM`
   - BLOOM, BLOOMZ, BLOOMChat
   - `bigscience/bloom`, `bigscience/bloomz`, etc.
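For reference, a minimal usage sketch for the newly listed checkpoint via the vLLM Python API; the model name comes from the table above, while the prompt and sampling parameters are purely illustrative:

from vllm import LLM, SamplingParams

# Baichuan models ship custom modeling code on the Hub, hence trust_remote_code=True.
llm = LLM(model="baichuan-inc/Baichuan-M1-14B-Instruct", trust_remote_code=True)

outputs = llm.generate(["Which organ produces insulin?"],
                       SamplingParams(temperature=0.7, max_tokens=128))
print(outputs[0].outputs[0].text)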

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,8 @@ def check_available_online(
                                            trust_remote_code=True),
     "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
                                            trust_remote_code=True),
+    "BaichuanM1ForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-M1-14B-Instruct",  # noqa: E501
+                                             trust_remote_code=True),
     "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",

vllm/config.py

Lines changed: 45 additions & 3 deletions
@@ -305,9 +305,12 @@ def __init__(
         self.enforce_eager = False
 
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
+        sliding_window_layers = getattr(self.hf_text_config,
+                                        "sliding_window_layers", None)
         has_interleaved_attention = (sliding_window is not None) and (
             isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
+            (self.hf_text_config.model_type in ["gemma2", "cohere2"])
+            or sliding_window_layers is not None)
 
         if (not self.disable_sliding_window and has_interleaved_attention):
             if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
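To make the new detection path concrete, here is a minimal sketch of how a per-layer sliding-window config is picked up as interleaved attention. The field names follow the diff; the concrete values are illustrative stand-ins, not copied from the real Baichuan-M1 checkpoint config:

from types import SimpleNamespace

# Hypothetical text config exposing the fields the diff reads.
hf_text_config = SimpleNamespace(
    model_type="baichuan_m1",         # illustrative model_type value
    sliding_window=2048,              # window size used by the SWA layers
    sliding_window_layers=[1, 3, 5],  # indices of the layers that use SWA
)

sliding_window = getattr(hf_text_config, "sliding_window", None)
sliding_window_layers = getattr(hf_text_config, "sliding_window_layers", None)
has_interleaved_attention = (sliding_window is not None) and (
    isinstance(sliding_window, list)
    or hf_text_config.model_type in ["gemma2", "cohere2"]
    or sliding_window_layers is not None)
print(has_interleaved_attention)  # True: per-layer windows count as interleaved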
@@ -713,6 +716,9 @@ def get_hf_config_sliding_window(
         if (hasattr(self.hf_text_config, "use_sliding_window")
                 and not self.hf_text_config.use_sliding_window):
             return None
+        if hasattr(self.hf_text_config, 'sliding_window_layers'):
+            return None
+
         return getattr(self.hf_text_config, "sliding_window", None)
 
     def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
@@ -724,6 +730,10 @@ def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
         # Otherwise get the value from the hf config.
         return self.get_hf_config_sliding_window()
 
+    def get_sliding_window_layers(self,
+                                  parallel_config) -> Optional[List[int]]:
+        return getattr(self.hf_text_config, "sliding_window_layers", [])
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size
 
@@ -751,6 +761,12 @@ def get_head_size(self) -> int:
         return (self.hf_text_config.hidden_size //
                 self.hf_text_config.num_attention_heads)
 
+    def get_head_size_swa(self) -> int:
+        if hasattr(self.hf_text_config, "num_swa_attention_heads"):
+            return (self.hf_text_config.hidden_size //
+                    self.hf_text_config.num_swa_attention_heads)
+        return self.get_head_size()
+
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
         # For GPTBigCode & Falcon:
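A tiny worked example of the head-size split introduced above, under made-up numbers (num_swa_attention_heads is the field the diff checks; none of the values are taken from the actual checkpoint):

from types import SimpleNamespace

cfg = SimpleNamespace(
    hidden_size=5120,
    num_attention_heads=40,      # heads on full-attention layers
    num_swa_attention_heads=20,  # heads on sliding-window layers (hypothetical)
)

head_size = cfg.hidden_size // cfg.num_attention_heads          # 128
head_size_swa = cfg.hidden_size // cfg.num_swa_attention_heads  # 256
print(head_size, head_size_swa)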
@@ -797,6 +813,22 @@ def get_total_num_kv_heads(self) -> int:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads
 
+    def get_total_num_kv_heads_swa(self) -> int:
+        if hasattr(self.hf_text_config, "num_swa_key_value_heads"):
+            return self.hf_text_config.num_swa_key_value_heads
+        return self.get_total_num_kv_heads()
+
+    def get_num_swa_key_value_heads(self,
+                                    parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads_swa = self.get_total_num_kv_heads_swa()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(
+            1, total_num_kv_heads_swa // parallel_config.tensor_parallel_size)
+
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
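The tensor-parallel handling mirrors the existing get_num_kv_heads logic; a short worked example with hypothetical head counts shows the replication behaviour for SWA layers whose KV head count is below the TP size:

from types import SimpleNamespace

cfg = SimpleNamespace(num_key_value_heads=8,      # full-attention layers
                      num_swa_key_value_heads=2)  # SWA layers (hypothetical)
parallel_config = SimpleNamespace(tensor_parallel_size=4)

# Full attention: 8 KV heads shared across 4 GPUs -> 2 per GPU.
print(max(1, cfg.num_key_value_heads // parallel_config.tensor_parallel_size))
# SWA: only 2 KV heads, fewer than the TP size -> replicated, 1 per GPU.
print(max(1, cfg.num_swa_key_value_heads // parallel_config.tensor_parallel_size))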
@@ -839,7 +871,18 @@ def get_num_layers_by_block_type(
 
         if is_transformer:
             # Handle the basic case first
-            return end - start if attn_block_type else 0
+            swa_layers = self.get_sliding_window_layers(parallel_config)
+            num_layers = 0
+            if not swa_layers:
+                num_layers = end - start if attn_block_type else 0
+            else:
+                for layer_id in range(start, end):
+                    if (block_type == LayerBlockType.attention
+                            and layer_id not in swa_layers) or (
+                                block_type == LayerBlockType.swa
+                                and layer_id in swa_layers):
+                        num_layers += 1
+            return num_layers
         elif self.is_attention_free:
             # Attention free
             # Note that this code assumes there
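The new branch simply counts a layer as `attention` or `swa` depending on whether its index appears in sliding_window_layers. A standalone sketch with a hypothetical 8-layer model (the layer indices are illustrative only):

# Hypothetical: layers 1, 3, 5 and 7 of an 8-layer model use sliding-window
# attention; the remaining layers use full attention.
swa_layers = [1, 3, 5, 7]
start, end = 0, 8

num_attention = sum(1 for layer_id in range(start, end)
                    if layer_id not in swa_layers)
num_swa = sum(1 for layer_id in range(start, end)
              if layer_id in swa_layers)
print(num_attention, num_swa)  # 4 4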
@@ -2360,7 +2403,6 @@ def _get_and_verify_max_len(
             max_len_key = key if max_len < derived_max_model_len \
                 else max_len_key
             derived_max_model_len = min(derived_max_model_len, max_len)
-
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
     if disable_sliding_window and sliding_window_len is not None:

vllm/engine/llm_engine.py

Lines changed: 10 additions & 3 deletions
@@ -341,6 +341,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         # of request outputs to asyncio queues
         self.process_request_outputs_callback: Optional[Callable] = None
 
+        # This is used to evict the finished requests from the Mamba cache and
+        # Baichuan-M1. We keep finished_req_ids here so that they are not lost
+        # when the scheduler output is empty.
+        self.finished_requests_ids: List[str] = list()
+
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
@@ -1323,6 +1328,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
 
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
+            self.finished_requests_ids.extend(finished_requests_ids)
 
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
@@ -1335,8 +1341,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                 self._cache_scheduler_outputs_for_multi_step(
                     virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)
-        else:
-            finished_requests_ids = list()
 
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
@@ -1357,11 +1361,14 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
                 num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                 running_queue_size=scheduler_outputs.running_queue_size,
-                finished_requests_ids=finished_requests_ids,
+                finished_requests_ids=self.finished_requests_ids,
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
                 last_sampled_token_ids=last_sampled_token_ids)
 
+            # Clear finished_requests_ids list.
+            self.finished_requests_ids = list()
+
             if allow_async_output_proc:
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
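Reduced to a standalone sketch, the bookkeeping above accumulates finished request IDs across step() calls and only forwards them to the workers, then clears the list, when a model execution is actually issued, so cache eviction for Mamba-style and Baichuan-M1 per-request state is not dropped on steps where the scheduler has nothing to run. The class below is a stand-in for illustration, not the real engine:

from typing import List


class EngineSketch:
    """Stand-in showing the finished_requests_ids accumulation pattern."""

    def __init__(self) -> None:
        self.finished_requests_ids: List[str] = []

    def step(self, newly_finished: List[str], has_work: bool) -> None:
        # Always record finished requests, even on an empty scheduling step.
        self.finished_requests_ids.extend(newly_finished)
        if not has_work:
            return  # keep the IDs around for the next step
        # Hand the accumulated IDs to the workers (the real engine passes them
        # via ExecuteModelRequest), then reset the list.
        print("evicting:", self.finished_requests_ids)
        self.finished_requests_ids = []


engine = EngineSketch()
engine.step(["req-1"], has_work=False)  # nothing scheduled; ID is retained
engine.step(["req-2"], has_work=True)   # evicting: ['req-1', 'req-2']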
