-
-
Notifications
You must be signed in to change notification settings - Fork 7.7k
Add GLM-4-0414 support #16338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add GLM-4-0414 support #16338
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
f9988ba
[Misc] Fix test_sharded_state_loader.py(#16004) (#16005)
Accelerator1996 b86e411
[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#…
DarkLight1337 53ab38a
Update label-tpu mergify and remove removal bot (#16298)
mgoin 32e49dd
[BugFix] logger is not callable (#16312)
yihong0618 638e304
[BugFix] llama4 qknorm should be not shared across head (#16311)
luccafong 77569cb
GLM-4-0414 Support
zRzRzRzRzRzRzR 3bbac32
update neuron config (#16289)
ajayvohra2005 26dd582
[BugFix] fix some typos found by typos. (#16314)
yihong0618 987c9a4
update with pre-commit
zRzRzRzRzRzRzR c8e2b50
remove useless code
zRzRzRzRzRzRzR 2a91bc1
[Model] Add `SupportsMultiModal.get_language_model` interface (#16007)
NickLucche 9c7468f
[Bugfix][Frontend] respect provided default guided decoding backend (…
gcalmettes 9c78bcf
Merge branch 'vllm-project:main' into main
zRzRzRzRzRzRzR 4da7a28
GLM-4-0414 Support
zRzRzRzRzRzRzR db95cbc
update with gate_up_proj
zRzRzRzRzRzRzR File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,313 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Copyright 2025 The Zhipu AI team. | ||
# Copyright 2023 The vLLM team. | ||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. | ||
# | ||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | ||
# and OPT implementations in this library. It has been modified from its | ||
# original forms to accommodate minor architectural differences compared | ||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" | ||
from typing import Iterable, Optional, Set, Tuple, Union | ||
|
||
import torch | ||
from torch import nn | ||
from transformers import Glm4Config | ||
|
||
from vllm.attention import Attention, AttentionType | ||
from vllm.compilation.decorators import support_torch_compile | ||
from vllm.config import CacheConfig, VllmConfig | ||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size | ||
from vllm.model_executor.layers.layernorm import RMSNorm | ||
from vllm.model_executor.layers.linear import (QKVParallelLinear, | ||
RowParallelLinear) | ||
from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
from vllm.model_executor.layers.quantization import QuantizationConfig | ||
from vllm.model_executor.layers.rotary_embedding import get_rope | ||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler | ||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead | ||
from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
from vllm.sequence import IntermediateTensors | ||
|
||
from .interfaces import SupportsLoRA, SupportsPP | ||
from .llama import LlamaMLP as Glm4MLP | ||
from .llama import LlamaModel | ||
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix | ||
|
||
|
||
class Glm4Attention(nn.Module): | ||
|
||
def __init__(self, | ||
config: Glm4Config, | ||
hidden_size: int, | ||
num_heads: int, | ||
num_kv_heads: int, | ||
max_position: int = 4096 * 32, | ||
head_dim: Optional[int] = None, | ||
qkv_bias: bool = False, | ||
rope_theta: float = 10000, | ||
cache_config: Optional[CacheConfig] = None, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
rope_scaling: Optional[Tuple] = None, | ||
prefix: str = "", | ||
attn_type: str = AttentionType.DECODER) -> None: | ||
super().__init__() | ||
self.hidden_size = hidden_size | ||
tp_size = get_tensor_model_parallel_world_size() | ||
self.total_num_heads = num_heads | ||
assert self.total_num_heads % tp_size == 0 | ||
self.num_heads = self.total_num_heads // tp_size | ||
self.total_num_kv_heads = num_kv_heads | ||
if self.total_num_kv_heads >= tp_size: | ||
# Number of KV heads is greater than TP size, so we partition | ||
# the KV heads across multiple tensor parallel GPUs. | ||
assert self.total_num_kv_heads % tp_size == 0 | ||
else: | ||
# Number of KV heads is less than TP size, so we replicate | ||
# the KV heads across multiple tensor parallel GPUs. | ||
assert tp_size % self.total_num_kv_heads == 0 | ||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) | ||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) | ||
self.head_dim = head_dim or hidden_size // self.total_num_heads | ||
self.rotary_dim = int(partial_rotary_factor * self.head_dim) | ||
self.q_size = self.num_heads * self.head_dim | ||
self.kv_size = self.num_kv_heads * self.head_dim | ||
self.scaling = self.head_dim**-0.5 | ||
self.rope_theta = rope_theta | ||
self.qkv_proj = QKVParallelLinear( | ||
hidden_size, | ||
self.head_dim, | ||
self.total_num_heads, | ||
self.total_num_kv_heads, | ||
bias=qkv_bias, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.qkv_proj", | ||
) | ||
self.o_proj = RowParallelLinear( | ||
self.total_num_heads * self.head_dim, | ||
hidden_size, | ||
bias=False, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.o_proj", | ||
) | ||
self.rotary_emb = get_rope( | ||
self.head_dim, | ||
rotary_dim=self.rotary_dim, | ||
max_position=max_position, | ||
base=self.rope_theta, | ||
rope_scaling=rope_scaling, | ||
partial_rotary_factor=partial_rotary_factor, | ||
) | ||
self.attn = Attention(self.num_heads, | ||
self.head_dim, | ||
self.scaling, | ||
num_kv_heads=self.num_kv_heads, | ||
cache_config=cache_config, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.attn", | ||
attn_type=attn_type) | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
) -> torch.Tensor: | ||
qkv, _ = self.qkv_proj(hidden_states) | ||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) | ||
q, k = self.rotary_emb(positions, q, k) | ||
attn_output = self.attn(q, k, v) | ||
output, _ = self.o_proj(attn_output) | ||
return output | ||
|
||
|
||
class Glm4DecoderLayer(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
config: Glm4Config, | ||
cache_config: Optional[CacheConfig] = None, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
prefix: str = "", | ||
) -> None: | ||
super().__init__() | ||
self.hidden_size = config.hidden_size | ||
rope_theta = getattr(config, "rope_theta", 1000000) | ||
rope_scaling = getattr(config, "rope_scaling", None) | ||
|
||
self.self_attn = Glm4Attention( | ||
config=config, | ||
hidden_size=self.hidden_size, | ||
num_heads=config.num_attention_heads, | ||
max_position=config.max_position_embeddings, | ||
num_kv_heads=config.num_key_value_heads, | ||
rope_theta=rope_theta, | ||
qkv_bias=getattr(config, 'attention_bias', False), | ||
head_dim=getattr(config, 'head_dim', None), | ||
cache_config=cache_config, | ||
quant_config=quant_config, | ||
rope_scaling=rope_scaling, | ||
prefix=f"{prefix}.self_attn", | ||
attn_type=AttentionType.DECODER, | ||
) | ||
self.mlp = Glm4MLP( | ||
hidden_size=self.hidden_size, | ||
intermediate_size=config.intermediate_size, | ||
hidden_act=config.hidden_act, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.mlp", | ||
) | ||
self.input_layernorm = RMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
self.post_attention_layernorm = RMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
self.post_self_attn_layernorm = RMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
self.post_mlp_layernorm = RMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
residual: Optional[torch.Tensor], | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
# Self Attention | ||
if residual is None: | ||
residual = hidden_states | ||
hidden_states = self.input_layernorm(hidden_states) | ||
else: | ||
hidden_states, residual = self.input_layernorm( | ||
hidden_states, residual) | ||
hidden_states = self.self_attn( | ||
positions=positions, | ||
hidden_states=hidden_states, | ||
) | ||
|
||
hidden_states = self.post_self_attn_layernorm(hidden_states) | ||
hidden_states = residual + hidden_states | ||
|
||
# Fully Connected | ||
hidden_states = self.post_attention_layernorm(hidden_states, residual) | ||
hidden_states = self.mlp(hidden_states) | ||
hidden_states = self.post_mlp_layernorm(hidden_states) | ||
hidden_states = residual + hidden_states | ||
|
||
return hidden_states, residual | ||
|
||
|
||
ALL_DECODER_LAYER_TYPES = { | ||
"attention": Glm4DecoderLayer, | ||
} | ||
|
||
|
||
@support_torch_compile( | ||
dynamic_arg_dims={ | ||
"input_ids": 0, | ||
"positions": -1, | ||
"intermediate_tensors": 0, | ||
"inputs_embeds": 0, | ||
}) | ||
class Glm4Model(LlamaModel): | ||
|
||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
super().__init__(vllm_config=vllm_config, | ||
prefix=prefix, | ||
layer_type=Glm4DecoderLayer) | ||
|
||
|
||
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): | ||
packed_modules_mapping = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we also need to add |
||
"qkv_proj": [ | ||
"q_proj", | ||
"k_proj", | ||
"v_proj", | ||
], | ||
"gate_up_proj": [ | ||
"gate_proj", | ||
"up_proj", | ||
], | ||
} | ||
|
||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
super().__init__() | ||
config = vllm_config.model_config.hf_config | ||
quant_config = vllm_config.quant_config | ||
lora_config = vllm_config.lora_config | ||
|
||
self.config = config | ||
self.lora_config = lora_config | ||
|
||
self.quant_config = quant_config | ||
self.model = Glm4Model(vllm_config=vllm_config, | ||
prefix=maybe_prefix(prefix, "model")) | ||
|
||
if get_pp_group().is_last_rank: | ||
if config.tie_word_embeddings: | ||
self.lm_head = self.model.embed_tokens | ||
else: | ||
self.lm_head = ParallelLMHead(config.vocab_size, | ||
config.hidden_size, | ||
quant_config=quant_config, | ||
prefix=maybe_prefix( | ||
prefix, "lm_head")) | ||
else: | ||
self.lm_head = PPMissingLayer() | ||
|
||
self.logits_processor = LogitsProcessor(config.vocab_size) | ||
self.sampler = get_sampler() | ||
|
||
self.make_empty_intermediate_tensors = ( | ||
self.model.make_empty_intermediate_tensors) | ||
|
||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: | ||
return self.model.get_input_embeddings(input_ids) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
intermediate_tensors: Optional[IntermediateTensors] = None, | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
) -> Union[torch.Tensor, IntermediateTensors]: | ||
hidden_states = self.model(input_ids, positions, intermediate_tensors, | ||
inputs_embeds) | ||
return hidden_states | ||
|
||
def compute_logits( | ||
self, | ||
hidden_states: torch.Tensor, | ||
sampling_metadata: SamplingMetadata, | ||
) -> Optional[torch.Tensor]: | ||
logits = self.logits_processor(self.lm_head, hidden_states, | ||
sampling_metadata) | ||
return logits | ||
|
||
def sample( | ||
self, | ||
logits: torch.Tensor, | ||
sampling_metadata: SamplingMetadata, | ||
) -> Optional[SamplerOutput]: | ||
next_tokens = self.sampler(logits, sampling_metadata) | ||
return next_tokens | ||
|
||
def load_weights(self, weights: Iterable[Tuple[str, | ||
torch.Tensor]]) -> Set[str]: | ||
loader = AutoWeightsLoader( | ||
self, | ||
skip_prefixes=(["lm_head."] | ||
if self.config.tie_word_embeddings else None), | ||
) | ||
return loader.load_weights(weights) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we also need to add
"gate_up_proj": ["gate_proj", "up_proj"]
, right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the function of this, I found that the model works normally whether it is added or not.