Commit 38f5fd6

[Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt
Signed-off-by: Jonghyun Choe <[email protected]>
1 parent 2529378 commit 38f5fd6

3 files changed: +115 -100 lines changed

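All three files follow the same pattern: the hand-written per-parameter loading loop moves from the *ForCausalLM wrapper down into the inner *Model class, and the wrapper's load_weights becomes a two-line delegation through AutoWeightsLoader. Below is a minimal sketch of that delegation, assuming only that vLLM's AutoWeightsLoader walks the module's children and dispatches to their load_weights; the wrapper class here is illustrative, not one of the real vLLM classes.

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn

# AutoWeightsLoader lives next to these models in vllm/model_executor/models/utils.py
from vllm.model_executor.models.utils import AutoWeightsLoader


class CausalLMWrapper(nn.Module):
    """Illustrative stand-in for the *ForCausalLM classes after this commit."""

    def __init__(self, inner_model: nn.Module):
        super().__init__()
        # The inner *Model (e.g. BaiChuanModel) now owns the loading loop.
        self.model = inner_model

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Same two lines that replace the hand-written loop in each diff below.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)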

vllm/model_executor/models/baichuan.py

Lines changed: 53 additions & 48 deletions
@@ -47,7 +47,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -321,6 +321,56 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name == "lm_head.weight":
+                # Unlike Baichuan, Baichuan2 normalizes the head weights.
+                # Refer to:
+                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+                # Distinguish between Baichuan and Baichuan2 by checking the
+                # vocab size. This is suggested by
+                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+                is_baichuan2 = self.config.vocab_size == 125696
+                if is_baichuan2:
+                    loaded_weight = torch.nn.functional.normalize(
+                        loaded_weight)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                               SupportsQuant):
@@ -393,53 +443,8 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights.
-                # Refer to:
-                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
-                # Distinguish between Baichuan and Baichuan2 by checking the
-                # vocab size. This is suggested by
-                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
-                is_baichuan2 = self.config.vocab_size == 125696
-                if is_baichuan2:
-                    loaded_weight = torch.nn.functional.normalize(
-                        loaded_weight)
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
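Note on the lm_head branch kept in BaiChuanModel.load_weights above: Baichuan2 checkpoints (detected by vocab_size == 125696) normalize the output-head weights, so the loader applies torch.nn.functional.normalize to lm_head.weight before passing it on. A small sketch of what that call does, with toy shapes standing in for the real head weight:

import torch
import torch.nn.functional as F

config_vocab_size = 125696           # the Baichuan2 check used in the diff above
lm_head_weight = torch.randn(8, 16)  # (vocab rows, hidden) -- illustrative shape only

is_baichuan2 = config_vocab_size == 125696
if is_baichuan2:
    # F.normalize defaults to p=2, dim=1: each row (one token's output
    # embedding) is rescaled to unit L2 norm, matching the upstream
    # Baichuan2 modeling code referenced in the diff.
    lm_head_weight = F.normalize(lm_head_weight)

print(lm_head_weight.norm(dim=1))  # all values ~1.0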

vllm/model_executor/models/gpt_neox.py

Lines changed: 42 additions & 37 deletions
@@ -42,7 +42,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -241,6 +241,45 @@ def forward(
         hidden_states = self.final_layer_norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if ("attention.bias" in name or "attention.masked_bias" in name
+                    or "rotary_emb.inv_freq" in name):
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using OpenRLHF may include
+                # these tensors in the checkpoint. Skip them.
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPTNeoXForCausalLM(nn.Module, SupportsPP):
 
@@ -297,39 +336,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if ("attention.bias" in name or "attention.masked_bias" in name
-                    or "rotary_emb.inv_freq" in name):
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using OpenRLHF may include
-                # these tensors in the checkpoint. Skip them.
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
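The NOTE in GPTNeoXModel.load_weights above is about tensor layout: the HF GPT-NeoX checkpoint stores the fused QKV weight with the heads interleaved, i.e. (num_heads, 3, head_size) along the output dimension, while the fused layer here expects the (3, num_heads, head_size) ordering. A toy-shape sketch of the same view/transpose/reshape round trip (shapes and values are illustrative, not taken from a real checkpoint):

import torch

num_heads, head_size, hidden_size = 4, 8, 32  # toy values
output_dim = 0  # the dimension holding num_heads * 3 * head_size rows

# Checkpoint layout: for each head, its q, k and v rows sit next to each other.
loaded_weight = torch.randn(num_heads * 3 * head_size, hidden_size)

shape = loaded_weight.shape
# (num_heads, 3, head_size, hidden) -> swap the first two axes
# -> (3, num_heads, head_size, hidden) -> flatten back to the original shape.
converted = loaded_weight.view(shape[:output_dim] + (num_heads, 3, -1) +
                               shape[output_dim + 1:])
converted = converted.transpose(output_dim, output_dim + 1)
converted = converted.reshape(shape)

# After conversion all q rows come first, grouped by head: head 1's q rows
# (rows head_size..2*head_size) were originally at rows 3*head_size..4*head_size.
assert torch.equal(converted[head_size:2 * head_size],
                   loaded_weight[3 * head_size:4 * head_size])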

vllm/model_executor/models/mpt.py

Lines changed: 20 additions & 15 deletions
@@ -27,7 +27,7 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -266,6 +266,23 @@ def forward(
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class MPTForCausalLM(nn.Module, SupportsPP):
 
@@ -318,17 +335,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
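One detail worth noting in the MPT loop above: params_dict is built with named_parameters(remove_duplicate=False). With the default remove_duplicate=True, a parameter shared by two submodules is yielded under only one of its names, so the second checkpoint name would miss the lookup; passing False keeps every name, which matters when two modules share the same underlying tensor (for example a tied embedding and output head). A toy module (not MPT itself) showing the difference:

import torch.nn as nn


class TiedToy(nn.Module):
    """Toy module whose output head shares its weight with the embedding."""

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.wte.weight  # tie the two parameters


m = TiedToy()
print(sorted(dict(m.named_parameters()).keys()))
# ['wte.weight'] -- the tied name 'head.weight' is deduplicated away
print(sorted(dict(m.named_parameters(remove_duplicate=False)).keys()))
# ['head.weight', 'wte.weight'] -- both names resolve, as the loop above needs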
