Commit a9bd832

[Model] use AutoWeightsLoader for deepseek_v2, internlm2 (vllm-project#16383)

Signed-off-by: Aaron Ang <[email protected]>
1 parent 417bcef

2 files changed: +112 -107 lines

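Both diffs below follow the same pattern: the detailed per-name weight-mapping loop is kept (or moved) on the inner *Model class, while the outer *ForCausalLM class hands weight loading off to AutoWeightsLoader from .utils (vllm.model_executor.models.utils). A minimal sketch of the delegating method, with a placeholder class name; only the AutoWeightsLoader call and its skip_prefixes argument are taken from the diffs themselves:

# Sketch of the delegation this commit introduces. "ExampleForCausalLM" is a
# placeholder; AutoWeightsLoader and skip_prefixes appear in the diffs below.
from typing import Iterable, Set, Tuple

import torch
from torch import nn

from .utils import AutoWeightsLoader  # vllm.model_executor.models.utils


class ExampleForCausalLM(nn.Module):

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Unwanted entries (e.g. rotary_emb.inv_freq buffers) are declared as
        # skip prefixes instead of being filtered inside a hand-written loop;
        # everything else is routed to the submodules' own load_weights.
        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
        return loader.load_weights(weights)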

vllm/model_executor/models/deepseek_v2.py

Lines changed: 73 additions & 71 deletions
@@ -53,7 +53,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -668,73 +668,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

-
-class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = DeepseekV2Model(vllm_config=vllm_config,
-                                     prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -754,9 +687,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
@@ -824,6 +754,78 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params


+class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = DeepseekV2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
+        return loader.load_weights(weights)
+
+
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass

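Net effect for deepseek_v2.py: the stacked-parameter, expert and spec-decode handling now lives on DeepseekV2Model.load_weights, and DeepseekV2ForCausalLM (plus DeepseekV3ForCausalLM, which inherits from it) only wraps it via AutoWeightsLoader. The toy below sketches the general idea of such a loader, routing each (name, tensor) pair to the child module that owns it and skipping configured names; it is a self-contained illustration only, not vLLM's AutoWeightsLoader, whose matching and error handling differ.

# Toy prefix-based weight dispatch, for illustration only (NOT vLLM's
# AutoWeightsLoader): route each (name, tensor) pair to the child module whose
# own load_weights knows how to handle it, skipping configured names.
from typing import Iterable, Set, Tuple

import torch
from torch import nn


def toy_dispatch(module: nn.Module,
                 weights: Iterable[Tuple[str, torch.Tensor]],
                 skip_prefixes: Tuple[str, ...] = ()) -> Set[str]:
    loaded: Set[str] = set()
    params = dict(module.named_parameters())
    for name, tensor in weights:
        if any(skip in name for skip in skip_prefixes):
            continue  # e.g. "rotary_emb.inv_freq" (substring test in this toy)
        child_name, _, rest = name.partition(".")
        child = getattr(module, child_name, None)
        if child is not None and hasattr(child, "load_weights"):
            # Delegate to the submodule's hand-written load_weights.
            loaded |= {f"{child_name}.{sub}"
                       for sub in child.load_weights([(rest, tensor)])}
        elif name in params:
            params[name].data.copy_(tensor)  # plain parameter copy
            loaded.add(name)
    return loaded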

vllm/model_executor/models/internlm2.py

Lines changed: 39 additions & 36 deletions
@@ -32,7 +32,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -306,6 +306,42 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w1", 0),
+            ("gate_up_proj", "w3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     packed_modules_mapping = {
@@ -373,41 +409,8 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "w1", 0),
-            ("gate_up_proj", "w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
+        return loader.load_weights(weights)


 class InternLM2ForRewardModel(InternLM2ForCausalLM):
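For internlm2.py the effect is the same: the gate_up_proj (w1/w3) shard mapping and the GPTQ/pipeline-parallel checks move onto InternLM2Model, the explicit rotary_emb.inv_freq filter becomes a skip_prefixes entry, and InternLM2ForCausalLM and, through inheritance, InternLM2ForRewardModel now simply delegate. The moved loop relies on Python's for/else: the default-loading branch runs only when no stacked-parameter mapping matched (no break). A self-contained toy of that control flow, with the mapping taken from the diff and example weight names made up:

# Toy of the for/else control flow used by the moved load_weights: remap a
# checkpoint name onto a fused parameter when a mapping matches, otherwise
# fall through to the "load as-is" branch. Example names are illustrative.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "w1", 0),
    ("gate_up_proj", "w3", 1),
]


def route(name: str) -> str:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # A mapping matched: the checkpoint shard loads into one slice of the
        # fused parameter (shard_id selects the slice).
        routed = f"{name.replace(weight_name, param_name)} (shard {shard_id})"
        break
    else:
        # Runs only if the loop finished without break: load the name as-is.
        routed = name
    return routed


print(route("feed_forward.w1.weight"))   # -> feed_forward.gate_up_proj.weight (shard 0)
print(route("attention.wqkv.weight"))    # -> attention.wqkv.weight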
