53
53
from vllm .sequence import IntermediateTensors
54
54
55
55
from .interfaces import SupportsPP
56
- from .utils import (PPMissingLayer , is_pp_missing_parameter ,
56
+ from .utils import (AutoWeightsLoader , PPMissingLayer , is_pp_missing_parameter ,
57
57
make_empty_intermediate_tensors_factory , make_layers ,
58
58
maybe_prefix )
59
59
@@ -668,73 +668,6 @@ def forward(
668
668
hidden_states , _ = self .norm (hidden_states , residual )
669
669
return hidden_states
670
670
671
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
    """DeepSeek-V2 causal language model: a ``DeepseekV2Model`` backbone
    plus LM head, logits processing and sampling, with pipeline-parallel
    (PP) support via ``SupportsPP``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        # Only the last PP rank materializes the LM head; earlier ranks
        # carry a placeholder so the module tree stays consistent.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                          hf_config.hidden_size,
                                          quant_config=self.quant_config)
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()
        # NOTE(review): this instance attribute shadows the class-level
        # make_empty_intermediate_tensors method defined below — on
        # instances, the backbone's factory is the one that runs.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; per the annotation, the result is either
        final hidden states or intermediate tensors for PP hand-off."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states through the LM head to vocab logits."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from *logits* using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Allocate zeroed hidden_states/residual buffers for PP transfer.

        NOTE(review): shadowed by the attribute set in ``__init__``, so
        this definition is not reachable on instances — confirm intent.
        """
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            "hidden_states": torch.zeros(shape, dtype=dtype, device=device),
            "residual": torch.zeros(shape, dtype=dtype, device=device),
        })
738
671
def load_weights (self , weights : Iterable [Tuple [str ,
739
672
torch .Tensor ]]) -> Set [str ]:
740
673
stacked_params_mapping = [
@@ -754,9 +687,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
754
687
params_dict = dict (self .named_parameters ())
755
688
loaded_params : Set [str ] = set ()
756
689
for name , loaded_weight in weights :
757
- if "rotary_emb.inv_freq" in name :
758
- continue
759
-
760
690
spec_layer = get_spec_layer_idx_from_weight_name (self .config , name )
761
691
if spec_layer is not None :
762
692
continue # skip spec decode layers for main model
@@ -824,6 +754,78 @@ def load_weights(self, weights: Iterable[Tuple[str,
824
754
return loaded_params
825
755
826
756
757
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
    """DeepSeek-V2 causal language model: a ``DeepseekV2Model`` backbone
    plus LM head, logits processing, sampling and checkpoint loading,
    with pipeline-parallel (PP) support via ``SupportsPP``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        # Only the last PP rank materializes the LM head; earlier ranks
        # carry a placeholder so the module tree stays consistent.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                          hf_config.hidden_size,
                                          quant_config=self.quant_config)
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()
        # NOTE(review): this instance attribute shadows the class-level
        # make_empty_intermediate_tensors method defined below — on
        # instances, the backbone's factory is the one that runs.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; per the annotation, the result is either
        final hidden states or intermediate tensors for PP hand-off."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states through the LM head to vocab logits."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from *logits* using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Allocate zeroed hidden_states/residual buffers for PP transfer.

        NOTE(review): shadowed by the attribute set in ``__init__``, so
        this definition is not reachable on instances — confirm intent.
        """
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            "hidden_states": torch.zeros(shape, dtype=dtype, device=device),
            "residual": torch.zeros(shape, dtype=dtype, device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights via AutoWeightsLoader, skipping the
        precomputed ``rotary_emb.inv_freq`` buffers; returns the set of
        parameter names that were loaded."""
        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
        return loader.load_weights(weights)
827
829
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 reuses the DeepSeek-V2 implementation unchanged;
    the class exists so the architecture name resolves to this model."""
829
831
0 commit comments