@@ -416,8 +416,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             num_scheduled_tokens_per_req)
 
         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(
-            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        padded_total_num_scheduled_tokens = _get_padded_token_len(
+            total_num_scheduled_tokens)
         self.input_ids = self.input_ids_cpu[:
                                             padded_total_num_scheduled_tokens].to(
                                                 self.device)
@@ -428,23 +428,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         slot_mapping = self.slot_mapping_cpu[:
                                              padded_total_num_scheduled_tokens].to(
                                                  self.device)
-        padded_block_table = self.block_table_cpu[:
-                                                   padded_total_num_scheduled_tokens]
+        padded_block_table = self.block_table_cpu[:self.max_num_reqs]
         padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
         padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:
-                                                    padded_total_num_scheduled_tokens
-                                                    + 1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
             self.device)
+        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
 
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=padded_block_table,
             context_lens=seq_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_reqs,
+            num_seqs=torch.tensor([num_reqs],
+                                  dtype=torch.int32,
+                                  device=self.device),
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
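
A rough standalone sketch of the shape discipline this hunk introduces (not the vLLM code: the sizes, the `device` placeholder, and the `build_request_metadata` helper are made up for illustration). Request-level metadata is sliced to the fixed `max_num_reqs` padding rather than to the padded token count, and the live request count travels as a one-element int32 tensor instead of a Python int, so the metadata shapes stay constant from step to step:

```python
import torch

# Illustrative sizes; in the runner these come from the scheduler config.
max_num_reqs = 256
max_num_blocks_per_req = 16
device = "cpu"  # stands in for the TPU device

# CPU staging buffers allocated once at their maximum size.
block_table_cpu = torch.zeros((max_num_reqs, max_num_blocks_per_req),
                              dtype=torch.int32)
query_start_loc_cpu = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
seq_lens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32)


def build_request_metadata(num_reqs: int):
    # Slice to the fixed padded sizes, not to the live request count,
    # so the device tensors always have the same shape.
    block_tables = block_table_cpu[:max_num_reqs].to(device)
    query_start_loc = query_start_loc_cpu[:max_num_reqs + 1].to(device)
    seq_lens = seq_lens_cpu[:max_num_reqs].to(device)
    # The real number of requests travels as data, not as a shape.
    num_seqs = torch.tensor([num_reqs], dtype=torch.int32, device=device)
    return block_tables, query_start_loc, seq_lens, num_seqs


tables, qsl, lens, num_seqs = build_request_metadata(num_reqs=3)
print(tables.shape, qsl.shape, lens.shape, num_seqs)
# torch.Size([256, 16]) torch.Size([257]) torch.Size([256]) tensor([3], dtype=torch.int32)
```

With the shapes fixed, the compiled TPU graph no longer has to specialize on the batch size; only the count inside `num_seqs` changes, and it changes as data.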
@@ -693,29 +692,34 @@ def _dummy_run(
                                     dtype=torch.int32,
                                     device=self.device)
             inputs_embeds = None
+        actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
-                                   dtype=torch.int32,
-                                   device=self.device)
-        query_lens = [1] * num_tokens
+        block_tables = torch.zeros(
+            (self.max_num_reqs, self.block_table_cpu.shape[1]),
+            dtype=torch.int32,
+            device=self.device)
+        query_lens = [1] * self.max_num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((num_tokens, ),
+        context_lens = torch.ones((self.max_num_reqs, ),
                                   dtype=torch.int32,
                                   device=self.device)
+        num_seqs = torch.tensor([actual_num_reqs],
+                                dtype=torch.int32,
+                                device=self.device)
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
             context_lens=context_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_tokens,
+            num_seqs=num_seqs,
         )
 
         if self.is_multimodal_model:
@@ -724,9 +728,6 @@ def _dummy_run(
             torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
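
For context on why the three `mark_dynamic` calls can be dropped: once the request-level tensors are padded to a fixed `max_num_reqs`, only the token-indexed tensors still change shape between calls, so they are the only ones that need a dynamic dimension. A minimal sketch, not the vLLM code (`toy_step`, the sizes, and the eager backend are stand-ins):

```python
import torch


def toy_step(input_ids, positions, block_tables):
    # Stand-in for the model forward; only the shapes matter here.
    return input_ids.sum() + positions.sum() + block_tables.sum()


compiled = torch.compile(toy_step, backend="eager")

# Request-level tensor padded to a fixed size: static shape, no marking needed.
block_tables = torch.zeros((8, 4), dtype=torch.int32)

for num_tokens in (16, 32, 64):  # padded token counts vary per step
    input_ids = torch.zeros(num_tokens, dtype=torch.int32)
    positions = torch.zeros(num_tokens, dtype=torch.int32)
    # Only the token dimension is marked dynamic.
    torch._dynamo.mark_dynamic(input_ids, 0)
    torch._dynamo.mark_dynamic(positions, 0)
    compiled(input_ids, positions, block_tables)
```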
@@ -817,28 +818,6 @@ def forward(
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
-            slot_mapping = attn_metadata.slot_mapping
-            slot_mapping = slot_mapping.flatten()
-            head_indicies = torch.arange(0,
-                                         num_kv_heads,
-                                         device=slot_mapping.device,
-                                         dtype=slot_mapping.dtype)
-            head_indicies *= block_size * num_blocks
-            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-                -1, num_kv_heads)
-            slot_mapping = slot_mapping + head_indicies.view(1, -1)
-            slot_mapping = slot_mapping.flatten()
-            attn_metadata.slot_mapping = slot_mapping
 
         assert self.model is not None
         hidden_states = self.model(
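
The block deleted here was a host-side workaround: it expanded `slot_mapping` once per KV head so that `index_copy_` could address the cache through a flattened [num_kv_heads * num_blocks * block_size] index space. A small sketch of the index arithmetic it performed, with made-up sizes:

```python
import torch

num_kv_heads, num_blocks, block_size = 2, 3, 4
slot_mapping = torch.tensor([0, 5, 11], dtype=torch.int64)  # per-token slots

# Flattening [num_kv_heads, num_blocks, block_size, head_size] over its first
# three dims places head h's region at offset h * num_blocks * block_size.
head_offsets = torch.arange(num_kv_heads, dtype=slot_mapping.dtype) \
               * (num_blocks * block_size)
expanded = slot_mapping.repeat_interleave(num_kv_heads).view(-1, num_kv_heads)
expanded = (expanded + head_offsets.view(1, -1)).flatten()

print(expanded)
# tensor([ 0, 12,  5, 17, 11, 23])  -> each slot repeated once per head,
#                                      shifted into that head's region
```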
@@ -866,3 +845,9 @@ def get_input_embeddings(self, *args, **kwargs):
 
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
+
+
+def _get_padded_token_len(x: int) -> int:
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
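
A quick check of the padding behavior the new helper introduces (standalone copy of the function above so the snippet runs on its own):

```python
def _get_padded_token_len(x: int) -> int:
    # Mirrors the helper added above: power-of-two padding with a floor of 16.
    return 16 if x <= 16 else 1 << (x - 1).bit_length()


for n in (1, 16, 17, 100, 1000, 8192):
    print(n, "->", _get_padded_token_len(n))
# 1 -> 16, 16 -> 16, 17 -> 32, 100 -> 128, 1000 -> 1024, 8192 -> 8192
```

Compared with rounding up to a multiple of `NUM_QUERIES_PER_BLOCK`, power-of-two bucketing keeps the set of padded lengths, and therefore the number of shapes the TPU graph is compiled for, logarithmic in the maximum token count rather than linear.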