Commit 279a0c8

RunningLeon authored and rasmith committed

[Model]: Support internlm3 (vllm-project#12037)

1 parent 602b5d4 commit 279a0c8

File tree: 4 files changed (+28 lines, -15 lines)

docs/source/models/supported_models.md (5 additions, 0 deletions)

@@ -216,6 +216,11 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.
   - ✅︎
   - ✅︎
+* - `InternLM3ForCausalLM`
+  - InternLM3
+  - `internlm/internlm3-8b-instruct`, etc.
+  - ✅︎
+  - ✅︎
 * - `JAISLMHeadModel`
   - Jais
   - `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc.

tests/models/registry.py (2 additions, 0 deletions)

@@ -85,6 +85,8 @@ class _HfExamplesInfo:
                                               trust_remote_code=True),
     "InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B",
                                               trust_remote_code=True),
+    "InternLM3ForCausalLM": _HfExamplesInfo("internlm/internlm3-8b-instruct",
+                                            trust_remote_code=True),
     "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
     "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
     "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),

vllm/model_executor/models/llama.py (20 additions, 15 deletions)

@@ -97,20 +97,19 @@ def forward(self, x):

 class LlamaAttention(nn.Module):

-    def __init__(
-        self,
-        config: LlamaConfig,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        bias: bool = False,
-        cache_config: Optional[CacheConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: LlamaConfig,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 rope_theta: float = 10000,
+                 rope_scaling: Optional[Dict[str, Any]] = None,
+                 max_position_embeddings: int = 8192,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 bias: bool = False,
+                 cache_config: Optional[CacheConfig] = None,
+                 prefix: str = "",
+                 bias_o_proj: bool = False) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size

@@ -150,7 +149,7 @@ def __init__(
         self.o_proj = RowParallelLinear(
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
-            bias=bias,
+            bias=bias_o_proj,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )

@@ -232,6 +231,11 @@ def __init__(
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
+        bias_o_proj = attention_bias
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, 'qkv_bias'):
+            attention_bias = config.qkv_bias
+
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,

@@ -243,6 +247,7 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
+            bias_o_proj=bias_o_proj,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
         )
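The net effect of the llama.py changes is to decouple the QKV-projection bias from the o_proj bias: internlm/internlm-7b sets `bias` in its config (bias on every attention projection), while internlm/internlm3-8b sets `qkv_bias` (bias on the QKV projection only). A standalone sketch of the resolution logic, mirroring the diff above; the function name `resolve_bias_flags` is illustrative, not vLLM code:

# Standalone sketch of the bias resolution in the diff above (hypothetical
# helper, not the vLLM code itself): returns (qkv_bias, o_proj_bias) for a
# given HF config object.
def resolve_bias_flags(config):
    # internlm/internlm-7b style: `attention_bias` or `bias` enables bias
    # on all attention projections.
    attention_bias = getattr(config, "attention_bias", False) or getattr(
        config, "bias", False)
    bias_o_proj = attention_bias
    # internlm/internlm3-8b style: `qkv_bias` enables bias on the QKV
    # projection only, leaving o_proj bias-free.
    if hasattr(config, "qkv_bias"):
        attention_bias = config.qkv_bias
    return attention_bias, bias_o_proj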

vllm/model_executor/models/registry.py (1 addition, 0 deletions)

@@ -60,6 +60,7 @@
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
+    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
