
Commit 0814f64

MartinGleize authored and mgleize user committed
[Model] Support for fairseq2 Llama (vllm-project#11442)
Signed-off-by: Martin Gleize <[email protected]>
Co-authored-by: mgleize user <[email protected]>
1 parent d4d77ca commit 0814f64

File tree

7 files changed: +197 -21 lines changed

  tests/models/registry.py
  tests/weight_loading/models.txt
  tests/weight_loading/test_weight_loading.py
  vllm/model_executor/layers/linear.py
  vllm/model_executor/model_loader/loader.py
  vllm/model_executor/models/fairseq2_llama.py (new)
  vllm/model_executor/models/registry.py
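
For context, here is a minimal usage sketch (not part of the diff) showing how the new architecture is exercised end to end. The model repo is the one registered in the tests below; the snippet assumes a standard vLLM installation with this commit applied and is illustrative only:

    # Illustrative only: load a fairseq2-format Llama checkpoint through vLLM.
    from vllm import LLM, SamplingParams

    llm = LLM(model="mgleize/fairseq2-dummy-Llama-3.2-1B")
    params = SamplingParams(temperature=0.0, max_tokens=20)
    outputs = llm.generate(["Hello world!"], params)
    print(outputs[0].outputs[0].text)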

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ class _HfExamplesInfo:
     "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3",  # noqa: E501
                                              trust_remote_code=True),
     "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
+    "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),  # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),

tests/weight_loading/models.txt

Lines changed: 2 additions & 1 deletion
@@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
 marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
 qqq, HandH1998/QQQ-Llama-3-8b-g128, main
 qqq, HandH1998/QQQ-Llama-3-8b, main
-hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
+hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
+None, mgleize/fairseq2-dummy-Llama-3.2-1B, main

tests/weight_loading/test_weight_loading.py

Lines changed: 7 additions & 6 deletions
@@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
     """
     Test parameter weight loading with tp>1.
     """
-    with vllm_runner(model_name=MODEL_NAME,
-                     revision=REVISION,
-                     dtype=torch.half if QUANTIZATION == "gptq" else "auto",
-                     quantization=QUANTIZATION,
-                     max_model_len=MAX_MODEL_LEN,
-                     tensor_parallel_size=2) as model:
+    with vllm_runner(
+            model_name=MODEL_NAME,
+            revision=REVISION,
+            dtype=torch.half if QUANTIZATION == "gptq" else "auto",
+            quantization=None if QUANTIZATION == "None" else QUANTIZATION,
+            max_model_len=MAX_MODEL_LEN,
+            tensor_parallel_size=2) as model:
 
         output = model.generate_greedy("Hello world!", max_tokens=20)
         print(output)
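
The new `None` row in models.txt above relies on this change: the weight-loading test receives its parameters as strings (the harness is assumed to export each models.txt column as the MODEL_NAME, REVISION, and QUANTIZATION variables the test reads), so the literal string "None" has to be mapped back to Python's None before it reaches vllm_runner. A rough sketch of that flow, with the export step assumed rather than taken from the diff:

    import os

    # One row of models.txt: "<quantization>, <model>, <revision>"
    row = "None, mgleize/fairseq2-dummy-Llama-3.2-1B, main"
    quantization, model_name, revision = [col.strip() for col in row.split(",")]

    # Assumed: the harness exports the columns before invoking the test.
    os.environ["QUANTIZATION"] = quantization
    os.environ["MODEL_NAME"] = model_name
    os.environ["REVISION"] = revision

    # Inside the test, the sentinel string becomes a real None (as in the diff).
    QUANTIZATION = os.environ["QUANTIZATION"]
    quant_arg = None if QUANTIZATION == "None" else QUANTIZATION
    assert quant_arg is None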

vllm/model_executor/layers/linear.py

Lines changed: 22 additions & 12 deletions
@@ -344,11 +344,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
 
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
 
         param_data = param.data
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow here
-        if output_dim is not None and not use_bitsandbytes_4bit:
+        if output_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -546,6 +548,11 @@ def weight_loader(self,
 
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
+            is_sharded_weight = getattr(param, "is_sharded_weight", False)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow
+            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
             if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
@@ -554,9 +561,7 @@ def weight_loader(self,
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
         # Special case for AQLM codebooks.
@@ -941,6 +946,11 @@ def weight_loader(self,
 
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
+            is_sharded_weight = getattr(param, "is_sharded_weight", False)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow
+            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
             if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
@@ -964,9 +974,7 @@ def weight_loader(self,
                 shard_id = tp_rank // self.num_kv_head_replicas
                 start_idx = shard_id * shard_size
 
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
 
@@ -1070,6 +1078,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
 
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1085,9 +1097,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
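
The common thread in these linear.py hunks is a new per-parameter flag: when the checkpoint on disk is already split per tensor-parallel rank, the parameter is tagged with is_sharded_weight and the weight loaders skip the narrow() they would otherwise apply (the same path bitsandbytes 4-bit already used). A condensed sketch of that contract, simplified from the hunks above rather than a drop-in replacement for any of them:

    import torch
    from vllm.model_executor.layers.linear import set_weight_attrs

    def load_tp_weight(param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor,
                       output_dim: int, tp_rank: int) -> None:
        # Pre-sharded checkpoints (and bitsandbytes 4-bit) already contain only
        # this rank's slice, so narrowing a second time would be wrong.
        use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded = getattr(param, "is_sharded_weight", False) or use_bnb_4bit
        if not is_sharded:
            shard_size = param.data.shape[output_dim]
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 tp_rank * shard_size,
                                                 shard_size)
        param.data.copy_(loaded_weight)

    # A model opts in per parameter (see fairseq2_llama.py below):
    w = torch.nn.Parameter(torch.empty(8, 8))
    set_weight_attrs(w, {"is_sharded_weight": True})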

vllm/model_executor/model_loader/loader.py

Lines changed: 13 additions & 2 deletions
@@ -182,6 +182,9 @@ class Source:
         fall_back_to_pt: bool = True
         """Whether .pt weights can be used."""
 
+        allow_patterns_overrides: Optional[list[str]] = None
+        """If defined, weights will load exclusively using these patterns."""
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
@@ -218,6 +221,7 @@ def _prepare_weights(
         model_name_or_path: str,
         revision: Optional[str],
         fall_back_to_pt: bool,
+        allow_patterns_overrides: Optional[list[str]],
     ) -> Tuple[str, List[str], bool]:
         """Prepare weights for the model.
 
@@ -249,6 +253,9 @@
         if fall_back_to_pt:
             allow_patterns += ["*.pt"]
 
+        if allow_patterns_overrides is not None:
+            allow_patterns = allow_patterns_overrides
+
         if not is_local:
             hf_folder = download_weights_from_hf(
                 model_name_or_path,
@@ -298,7 +305,8 @@ def _get_weights_iterator(
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-            source.model_or_path, source.revision, source.fall_back_to_pt)
+            source.model_or_path, source.revision, source.fall_back_to_pt,
+            source.allow_patterns_overrides)
         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
@@ -340,6 +348,8 @@ def _get_all_weights(
             prefix="",
             fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                     True),
+            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
+                                             None),
         )
         yield from self._get_weights_iterator(primary_weights)
 
@@ -353,7 +363,8 @@
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model,
                               model_config.revision,
-                              fall_back_to_pt=True)
+                              fall_back_to_pt=True,
+                              allow_patterns_overrides=None)
 
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
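
The loader change lets a model instance expose an allow_patterns_overrides attribute that restricts which checkpoint files _prepare_weights downloads and iterates. A hedged sketch of how a model opts in, mirroring the fairseq2 model below (MyShardedModel is a made-up name; the attribute is the one added in this commit):

    import torch

    class MyShardedModel(torch.nn.Module):
        """Hypothetical model whose checkpoint ships as rank-sharded .pt files."""

        def __init__(self, tp_rank: int):
            super().__init__()
            # DefaultModelLoader._get_all_weights reads this attribute via
            # getattr(model, "allow_patterns_overrides", None); when it is set,
            # only these glob patterns are downloaded and iterated.
            self.allow_patterns_overrides = [
                "model.pt",             # the full (unsharded) checkpoint, or
                f"model.{tp_rank}.pt",  # the shard for this tensor-parallel rank
            ]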
vllm/model_executor/models/fairseq2_llama.py (new file)

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
+# Copyright 2024 The vLLM team.
+# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Llama model for fairseq2 weights."""
+
+from typing import Iterable, Set, Tuple
+
+import torch
+from torch.nn import Parameter
+
+from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
+from vllm.model_executor.layers.linear import set_weight_attrs
+from vllm.model_executor.models.llama import LlamaForCausalLM
+
+from .utils import AutoWeightsLoader, WeightsMapper
+
+
+class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # For the model loader to read only the relevant checkpoint files
+        self.allow_patterns_overrides = [
+            # either the full checkpoint
+            "model.pt",
+            # or the tp-sharded checkpoint of the current rank
+            f"model.{self.tp_rank}.pt",
+        ]
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        # fairseq2's serialization adds a wrapper to usual .pt state_dict's:
+        # { "model_key": my_model_name, "my_model_name": state_dict }
+        # which we first need to unpack
+        weights_wrapped = dict(weights)
+        weights = weights_wrapped[
+            weights_wrapped["model_key"]].items()  # type: ignore
+
+        # remap keys
+        fs2_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={
+                "decoder_frontend.embed.": "model.embed_tokens.",
+                "decoder.": "model.",
+                "final_proj.": "lm_head.",
+            },
+            orig_to_new_substr={
+                ".self_attn_layer_norm.": ".input_layernorm.",
+                ".ffn_layer_norm.": ".post_attention_layernorm.",
+                ".self_attn.output_proj.": ".self_attn.o_proj.",
+                ".ffn.gate_proj.": ".mlp.gate_proj.",
+                ".ffn.inner_proj.": ".mlp.up_proj.",
+                ".ffn.output_proj.": ".mlp.down_proj.",
+                ".layer_norm.": ".norm.",
+            },
+        )
+        weights = fs2_to_vllm_mapper.apply(weights)
+
+        params = dict(self.named_parameters())
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(
+            (self.reshape_fairseq2_weights(name, loaded_weight, params)
+             for name, loaded_weight in weights))
+
+    def flag_sharded_weights(self, params: dict[str, Parameter]):
+        """Sets the `is_sharded_weight` flag to True for all sharded weights"""
+        for name, param in params.items():
+            modules = name.split(".")
+            if "norm" in name and len(param.size()) < 2:
+                # layer norms are not sharded
+                continue
+            elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
+                # for now we repeat embedding layers for compatibility
+                continue
+            else:
+                # all other layers are sharded
+                set_weight_attrs(param, {"is_sharded_weight": True})
+
+    def reshape_fairseq2_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params: dict[str, Parameter],
+    ) -> Tuple[str, torch.Tensor]:
+        """Reshape fairseq2's weights."""
+
+        def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
+            attn_in = self.config.head_dim * n_heads
+            # check for a sharded weight on dim 0
+            if attn_in // self.tp_size == w.size()[0]:
+                attn_in //= self.tp_size
+                n_heads //= self.tp_size
+            attn_out = self.config.hidden_size
+            return (w.view(n_heads, attn_in // n_heads // 2, 2,
+                           attn_out).transpose(1,
+                                               2).reshape(attn_in, attn_out))
+
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "k_proj" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+
+        elif "q_proj" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        # We make the loaded weights compatible with both
+        # full checkpoints and tp sharded checkpoints.
+        # Embeddings are repeated to fit the vocab size.
+        # Other weights are flagged for the weight_loader calls.
+        if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
+            # Embeddings are sharded on dim 0
+            dim = 0
+            # In fairseq2, vocab size has to be divisible by tp_size
+            # so we don't worry about padding
+            if self.tp_size > 1 and loaded_weight.shape[
+                    dim] < self.config.vocab_size:
+                assert loaded_weight.shape[
+                    dim] * self.tp_size == self.config.vocab_size, \
+                    "vocab_size should be divisible by tp_size."
+                repeats = [1] * len(loaded_weight.size())
+                repeats[dim] = self.tp_size
+                # repeat to match vocab size and to be easily 'narrow'able
+                loaded_weight = loaded_weight.repeat(repeats)
+                set_weight_attrs(params[name], {"is_sharded_weight": False})
+                # if embeddings are sharded, the rest is too
+                if "embed_tokens" in modules:
+                    self.flag_sharded_weights(params)
+
+        return name, loaded_weight
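
For orientation, a small sketch of the checkpoint layout this loader expects and of the renaming it applies. The tensors, shapes, and the "model" name below are made up; only the wrapper structure ("model_key" plus a nested state_dict) and the prefix/substring rules come from the code above:

    import torch

    # fairseq2 wraps the usual .pt state_dict under a named key:
    ckpt = {
        "model_key": "model",
        "model": {
            "decoder_frontend.embed.weight": torch.zeros(128, 16),
            "decoder.layers.0.self_attn.output_proj.weight": torch.zeros(16, 16),
            "final_proj.weight": torch.zeros(128, 16),
        },
    }
    state_dict = ckpt[ckpt["model_key"]]  # what load_weights unwraps first

    # After the WeightsMapper above, the keys become the vLLM-native names:
    #   decoder_frontend.embed.weight                 -> model.embed_tokens.weight
    #   decoder.layers.0.self_attn.output_proj.weight -> model.layers.0.self_attn.o_proj.weight
    #   final_proj.weight                             -> lm_head.weight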

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@
     "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
     "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
