@@ -305,9 +305,12 @@ def __init__(
         self.enforce_eager = False
 
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
+        sliding_window_layers = getattr(self.hf_text_config,
+                                        "sliding_window_layers", None)
         has_interleaved_attention = (sliding_window is not None) and (
             isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
+            (self.hf_text_config.model_type in ["gemma2", "cohere2"])
+            or sliding_window_layers is not None)
 
         if (not self.disable_sliding_window and has_interleaved_attention):
             if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
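As a reviewer aid (not part of the patch), here is a minimal sketch of how the extended predicate behaves; the `SimpleNamespace` stands in for `hf_text_config`, and the attribute values are made up:

```python
from types import SimpleNamespace

# Hypothetical config for a model that is neither gemma2 nor cohere2 and
# has a scalar sliding_window, but marks individual layers as SWA layers.
cfg = SimpleNamespace(model_type="custom",
                      sliding_window=4096,
                      sliding_window_layers=[0, 2, 4])

sliding_window = getattr(cfg, "sliding_window", None)
sliding_window_layers = getattr(cfg, "sliding_window_layers", None)
has_interleaved_attention = (sliding_window is not None) and (
    isinstance(sliding_window, list) or
    (cfg.model_type in ["gemma2", "cohere2"])
    or sliding_window_layers is not None)

assert has_interleaved_attention  # triggered by the new per-layer SWA list
```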
@@ -713,6 +716,9 @@ def get_hf_config_sliding_window(
         if (hasattr(self.hf_text_config, "use_sliding_window")
                 and not self.hf_text_config.use_sliding_window):
             return None
+        if hasattr(self.hf_text_config, 'sliding_window_layers'):
+            return None
+
         return getattr(self.hf_text_config, "sliding_window", None)
 
     def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
@@ -724,6 +730,10 @@ def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
         # Otherwise get the value from the hf config.
         return self.get_hf_config_sliding_window()
 
+    def get_sliding_window_layers(self,
+                                  parallel_config) -> Optional[List[int]]:
+        return getattr(self.hf_text_config, "sliding_window_layers", [])
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size
 
@@ -751,6 +761,12 @@ def get_head_size(self) -> int:
         return (self.hf_text_config.hidden_size //
                 self.hf_text_config.num_attention_heads)
 
+    def get_head_size_swa(self) -> int:
+        if hasattr(self.hf_text_config, "num_swa_attention_heads"):
+            return (self.hf_text_config.hidden_size //
+                    self.hf_text_config.num_swa_attention_heads)
+        return self.get_head_size()
+
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
         # For GPTBigCode & Falcon:
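A quick sketch of the fallback behavior of `get_head_size_swa` (not part of the patch; the class below keeps only the two relevant methods and all numbers are made up):

```python
from types import SimpleNamespace

class _Cfg:  # stand-in for ModelConfig with only the relevant bits
    def __init__(self, hf_text_config):
        self.hf_text_config = hf_text_config

    def get_head_size(self):
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_head_size_swa(self):
        if hasattr(self.hf_text_config, "num_swa_attention_heads"):
            return (self.hf_text_config.hidden_size //
                    self.hf_text_config.num_swa_attention_heads)
        return self.get_head_size()

# Config that overrides the SWA head count:
swa_cfg = _Cfg(SimpleNamespace(hidden_size=4096,
                               num_attention_heads=32,
                               num_swa_attention_heads=64))
assert swa_cfg.get_head_size_swa() == 64   # 4096 // 64
assert swa_cfg.get_head_size() == 128      # 4096 // 32

# Config without the field falls back to the regular head size:
plain_cfg = _Cfg(SimpleNamespace(hidden_size=4096, num_attention_heads=32))
assert plain_cfg.get_head_size_swa() == 128
```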
@@ -797,6 +813,22 @@ def get_total_num_kv_heads(self) -> int:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads
 
+    def get_total_num_kv_heads_swa(self) -> int:
+        if hasattr(self.hf_text_config, "num_swa_key_value_heads"):
+            return self.hf_text_config.num_swa_key_value_heads
+        return self.get_total_num_kv_heads()
+
+    def get_num_swa_key_value_heads(self,
+                                    parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads_swa = self.get_total_num_kv_heads_swa()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(
+            1, total_num_kv_heads_swa // parallel_config.tensor_parallel_size)
+
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
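The SWA replication rule mirrors the existing `get_num_kv_heads` logic. A worked example of the arithmetic (head count and TP sizes are made up):

```python
total_num_kv_heads_swa = 4

for tp_size in (1, 2, 4, 8):
    per_gpu = max(1, total_num_kv_heads_swa // tp_size)
    print(f"tp={tp_size}: {per_gpu} SWA KV head(s) per GPU")
# tp=1: 4, tp=2: 2, tp=4: 1, tp=8: 1 -- once tp_size exceeds the head
# count, heads are replicated so every GPU keeps at least one.
```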
@@ -839,7 +871,18 @@ def get_num_layers_by_block_type(
 
         if is_transformer:
             # Handle the basic case first
-            return end - start if attn_block_type else 0
+            swa_layers = self.get_sliding_window_layers(parallel_config)
+            num_layers = 0
+            if not swa_layers:
+                num_layers = end - start if attn_block_type else 0
+            else:
+                for layer_id in range(start, end):
+                    if (block_type == LayerBlockType.attention
+                            and layer_id not in swa_layers) or (
+                                block_type == LayerBlockType.swa
+                                and layer_id in swa_layers):
+                        num_layers += 1
+            return num_layers
         elif self.is_attention_free:
             # Attention free
             # Note that this code assumes there
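To see how the new counting loop partitions layers, here is a self-contained sketch (not part of the patch) that mirrors the loop above; the enum assumes the PR adds a `swa` member to `LayerBlockType`, and the 8-layer model is hypothetical:

```python
from enum import Enum

class LayerBlockType(Enum):  # mirrors the vLLM enum, with the new member
    attention = "attention"
    swa = "swa"

def count_layers(block_type, start, end, swa_layers):
    # Same logic as the diff: SWA layer ids count toward `swa`,
    # all remaining ids count toward `attention`.
    num_layers = 0
    for layer_id in range(start, end):
        if (block_type == LayerBlockType.attention
                and layer_id not in swa_layers) or (
                    block_type == LayerBlockType.swa
                    and layer_id in swa_layers):
            num_layers += 1
    return num_layers

# Hypothetical 8-layer model with SWA on the even layers:
swa_layers = [0, 2, 4, 6]
assert count_layers(LayerBlockType.attention, 0, 8, swa_layers) == 4
assert count_layers(LayerBlockType.swa, 0, 8, swa_layers) == 4
```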
@@ -2360,7 +2403,6 @@ def _get_and_verify_max_len(
             max_len_key = key if max_len < derived_max_model_len \
                 else max_len_key
             derived_max_model_len = min(derived_max_model_len, max_len)
-
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
     if disable_sliding_window and sliding_window_len is not None: