 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -448,6 +448,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                        (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
+        self.config = config
+        self.quant_config = quant_config

         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
@@ -504,85 +506,6 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states

-
-class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
-    fall_back_to_pt_during_load = False
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        lora_config = vllm_config.lora_config
-        self.config = config
-        self.lora_config = lora_config
-        self.quant_config = vllm_config.quant_config
-
-        self.model = PhiMoEModel(vllm_config=vllm_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=(
-                DEFAULT_VOCAB_PADDING_SIZE
-                # We need bigger padding if using lora for kernel
-                # compatibility
-                if not lora_config else lora_config.lora_vocab_padding_size),
-            quant_config=None,
-            bias=True,
-        )
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -601,9 +524,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
@@ -667,3 +587,90 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    fall_back_to_pt_during_load = False
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        lora_config = vllm_config.lora_config
+        self.config = config
+        self.lora_config = lora_config
+        self.quant_config = vllm_config.quant_config
+
+        self.model = PhiMoEModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=(
+                DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size),
+            quant_config=None,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)
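For context on the new `load_weights` at the end of the diff: the manual `if "rotary_emb.inv_freq" in name: continue` check removed from the inner model's loop is now handled by the `skip_prefixes` argument of `AutoWeightsLoader`, which walks the module tree, dispatches each weight to the matching submodule's loader, and drops names that start with a listed prefix. Below is a minimal sketch of that delegation pattern, assuming the file lives under `vllm/model_executor/models/` so the relative `.utils` import resolves to `vllm.model_executor.models.utils`; the class name here is a hypothetical stand-in, not the real model.

```python
# Sketch of the delegation pattern this diff adopts; the real class is
# PhiMoEForCausalLM in vllm/model_executor/models/phimoe.py.
from typing import Iterable, Set, Tuple

import torch

from vllm.model_executor.models.utils import AutoWeightsLoader


class CausalLMSketch:  # hypothetical stand-in with a `self.model` submodule
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # AutoWeightsLoader iterates the named weights, routes each one to the
        # owning submodule's weight loader, and skips any weight whose name
        # starts with an entry in skip_prefixes (here the unused rotary
        # inv_freq buffers), replacing the hand-written `continue` check.
        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
        return loader.load_weights(weights)
```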