
Commit 8d93b25

jonghyunchoe authored and Mu Huai committed
[Model] use AutoWeightsLoader for phi, gemma, deepseek (vllm-project#16088)
Signed-off-by: Jonghyun Choe <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent cd8369f commit 8d93b25

File tree

vllm/model_executor/models/deepseek.py
vllm/model_executor/models/gemma.py
vllm/model_executor/models/phi.py

3 files changed: +147 additions, -131 deletions

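All three diffs below apply the same refactor: the checkpoint-key remapping loop moves from the *ForCausalLM wrapper into the inner *Model class, and the wrapper's load_weights shrinks to a two-line delegation through AutoWeightsLoader (imported from the models' shared utils module in each file). As a rough mental model only, the helper below is a simplified, self-contained sketch of that delegation idea, not vllm's actual AutoWeightsLoader; load_module_weights, _Inner, and _Outer are made-up names for illustration.

    from typing import Dict, Iterable, List, Set, Tuple

    import torch
    import torch.nn as nn


    def load_module_weights(module: nn.Module,
                            weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Toy recursive loader: a child that defines load_weights handles its
        own slice of the checkpoint; otherwise tensors are copied directly."""
        loaded: Set[str] = set()
        direct_params = dict(module.named_parameters(recurse=False))
        # Bucket checkpoint entries by the first component of their name.
        buckets: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            head, _, rest = name.partition(".")
            buckets.setdefault(head, []).append((rest, tensor))
        for child_name, child in module.named_children():
            child_weights = buckets.pop(child_name, [])
            if hasattr(child, "load_weights"):
                # e.g. the wrapper handing its "model.*" keys to the inner model
                sub_loaded = child.load_weights(iter(child_weights))
            else:
                sub_loaded = load_module_weights(child, child_weights)
            loaded.update(f"{child_name}.{n}" for n in sub_loaded)
        # Whatever is left addresses this module's own (direct) parameters.
        for head, items in buckets.items():
            for rest, tensor in items:
                if not rest and head in direct_params:
                    direct_params[head].data.copy_(tensor)
                    loaded.add(head)
        return loaded


    # Smoke test with throwaway modules (hypothetical names, not vllm classes).
    class _Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4, bias=False)

        def load_weights(self, weights):
            return load_module_weights(self, weights)


    class _Outer(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = _Inner()


    ckpt = [("model.proj.weight", torch.randn(4, 4))]
    assert load_module_weights(_Outer(), ckpt) == {"model.proj.weight"}

The actual AutoWeightsLoader also accepts a skip_prefixes argument, visible in the Gemma wrapper below; everything else in this sketch only illustrates the dispatch shape and the Set[str] return contract.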

vllm/model_executor/models/deepseek.py

Lines changed: 54 additions & 48 deletions
@@ -51,7 +51,8 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (extract_layer_index, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -385,6 +386,56 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if (("mlp.experts." in name or "mlp.shared_experts." in name)
+                        and name not in params_dict):
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if (("mlp.experts." in name or "mlp.shared_experts." in name)
+                        and name not in params_dict):
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class DeepseekForCausalLM(nn.Module, SupportsPP):
 
@@ -439,50 +490,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
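
Everything that used to live in DeepseekForCausalLM.load_weights, including the MoE-specific skips for mlp.experts. and mlp.shared_experts. keys that belong to other workers, now lives in DeepseekModel.load_weights above, and the wrapper simply delegates. The core of that loop is the rename from per-shard checkpoint keys (q_proj/k_proj/v_proj, gate_proj/up_proj) to the fused qkv_proj and gate_up_proj parameters. A tiny standalone illustration of that renaming step follows; the remap helper and the sample key strings are made up for the example, only the mapping table is taken from the diff.

    # Mapping copied from DeepseekModel.load_weights above.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]


    def remap(name):
        """Return (parameter_name, shard_id); shard_id is None for plain params."""
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, param_name), shard_id
        return name, None


    assert remap("layers.0.self_attn.q_proj.weight") == (
        "layers.0.self_attn.qkv_proj.weight", "q")
    assert remap("layers.0.mlp.gate_proj.weight") == (
        "layers.0.mlp.gate_up_proj.weight", 0)
    assert remap("layers.0.input_layernorm.weight") == (
        "layers.0.input_layernorm.weight", None)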

vllm/model_executor/models/gemma.py

Lines changed: 47 additions & 42 deletions
@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -319,6 +319,46 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, shard_name, shard_id) in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
+
 
 class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -385,44 +425,9 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            for (param_name, shard_name, shard_id) in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
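
Gemma is the one model here that needs an argument to the loader: its output projection is tied to the token embedding, so the checkpoint's lm_head.* entries have no matching parameter and must be ignored. The old loop special-cased "lm_head.weight" unconditionally; the new wrapper states the same intent declaratively, passing skip_prefixes=["lm_head."] only when config.tie_word_embeddings is set. The snippet below is a minimal, self-contained sketch of what prefix skipping amounts to, not vllm's implementation; drop_skipped and the sample keys are made up for illustration.

    from typing import Iterable, Optional, Sequence, Tuple

    import torch


    def drop_skipped(weights: Iterable[Tuple[str, torch.Tensor]],
                     skip_prefixes: Optional[Sequence[str]] = None):
        """Yield only checkpoint entries whose names match no skipped prefix."""
        prefixes = tuple(skip_prefixes or ())
        for name, tensor in weights:
            if prefixes and name.startswith(prefixes):
                continue  # e.g. drop "lm_head.weight" when embeddings are tied
            yield name, tensor


    ckpt = [("lm_head.weight", torch.zeros(2, 2)),
            ("model.embed_tokens.weight", torch.zeros(2, 2))]
    kept = [name for name, _ in drop_skipped(ckpt, skip_prefixes=["lm_head."])]
    assert kept == ["model.embed_tokens.weight"]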

vllm/model_executor/models/phi.py

Lines changed: 46 additions & 41 deletions
@@ -61,7 +61,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -249,6 +249,49 @@ def forward(
 
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v")
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # pylint: disable=E1136
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -317,43 +360,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v")
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # pylint: disable=E1136
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
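
Phi's version is the simplest of the three: its stacked-parameter mapping only fuses q_proj/k_proj/v_proj into qkv_proj, and there is no gate/up fusion or expert filtering, so the wrapper delegation takes no extra arguments, just AutoWeightsLoader(self). One detail shared by all three load_weights bodies above is the for/else idiom: the else branch runs only when the inner loop finishes without a break, that is, when a checkpoint key matched none of the stacked entries and should be loaded as a plain parameter. A minimal self-contained demonstration follows; classify and the key strings are made up for the example.

    def classify(name, shard_names=("q_proj", "k_proj", "v_proj")):
        for shard in shard_names:
            if shard in name:
                result = "fused shard"
                break
        else:
            # No break happened: the key matched none of the stacked shards.
            result = "plain parameter"
        return result


    assert classify("layers.0.self_attn.q_proj.weight") == "fused shard"
    assert classify("layers.0.mlp.fc1.weight") == "plain parameter"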

0 commit comments
