@@ -401,6 +401,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
+        self.query_start_loc_np[num_reqs + 1:] = 1

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -441,7 +442,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
+            num_reqs, self.max_num_reqs)
+        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
+        logits_indices = logits_indices.to(self.device)
         return attn_metadata, logits_indices

     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
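For intuition, here is a minimal NumPy sketch of what the padded query_start_loc buys: the slots past the real requests are set to 1, so the extra logits indices produced for the padded request slots all resolve to token 0 rather than to stale positions. The concrete sizes (3 requests, a request bucket of 4, an array of length 8) are illustrative assumptions, not values taken from this change.

import numpy as np

# Toy sizes: 3 active requests, padded up to an assumed request bucket of 4.
num_reqs, padded_num_reqs = 3, 4
num_scheduled_tokens_per_req = np.array([5, 2, 7], dtype=np.int32)

query_start_loc = np.zeros(8, dtype=np.int32)
query_start_loc[0] = 0
np.cumsum(num_scheduled_tokens_per_req,
          out=query_start_loc[1:num_reqs + 1])
# Padding slots hold 1, so the padded logits indices fall back to token 0.
query_start_loc[num_reqs + 1:] = 1

logits_indices = query_start_loc[1:padded_num_reqs + 1] - 1
print(query_start_loc)  # [ 0  5  7 14  1  1  1  1]
print(logits_indices)   # [ 4  6 13  0] -> last token of each request, then a dummy 0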
@@ -551,7 +555,6 @@ def execute_model(

         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
@@ -579,12 +582,10 @@ def execute_model(
                 kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
-        logits_indices = logits_indices[:num_reqs]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
-        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        selected_token_ids = self.model.compute_logits(hidden_states,
+                                                       logits_indices, None)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]

         # Then, let's update the cache state.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
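To make the new call shape concrete, here is a toy sketch of the tensors that flow through the compute_logits path, with a plain torch.nn.Linear standing in for the model's real logits head; all of the sizes (a 512-token bucket, a 64-request bucket, 3 real requests, a vocabulary of 32) are made-up illustration values, not numbers from this change.

import torch

hidden_size, vocab_size = 8, 32
padded_num_tokens, padded_num_reqs, num_reqs = 512, 64, 3

hidden_states = torch.randn(padded_num_tokens, hidden_size)
logits_indices = torch.zeros(padded_num_reqs, dtype=torch.int32)
logits_indices[:num_reqs] = torch.tensor([4, 6, 13], dtype=torch.int32)

lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)   # stand-in logits head
gathered = hidden_states[logits_indices]                         # [64, 8]
logits = lm_head(gathered)                                       # [64, 32]
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)  # [64, 1]
print(selected_token_ids.cpu()[:num_reqs].shape)                 # torch.Size([3, 1])

Only the first num_reqs rows correspond to real requests; the rows from the padded slots are discarded by the final slice.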
@@ -726,12 +727,31 @@ def _dummy_run(

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(
+            hidden_states = self.model(
                 input_ids=input_ids,
                 positions=position_ids,
                 kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )
+        num_reqs = _get_padded_num_reqs_with_upper_limit(
+            64, self.max_num_reqs)
+        # NOTE(chengjiyao): In total, the compute_logits function utilizes a
+        # compilation cache size of token_bucket_num multiplied by
+        # req_bucket_num. This is acceptable, given the graph's relatively
+        # small size.
+        while True:
+            logits_indices = torch.zeros(
+                num_reqs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            torch._dynamo.mark_dynamic(hidden_states, 0)
+            torch._dynamo.mark_dynamic(logits_indices, 0)
+            self.model.compute_logits(hidden_states, logits_indices, None)
+            if num_reqs >= self.max_num_reqs:
+                break
+            num_reqs = _get_padded_num_reqs_with_upper_limit(
+                num_reqs + 1, self.max_num_reqs)

     def capture_model(self) -> None:
         """Compile the model."""
@@ -823,13 +843,17 @@ def forward(

         return hidden_states

+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        return logits
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids

     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
@@ -846,3 +870,8 @@ def _get_padded_token_len(x: int) -> int:
     if x <= 16:
         return 16
     return 1 << (x - 1).bit_length()
+
+
+def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
+    return min(res, upper_limit)
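For reference, a stand-alone sketch of the request buckets this helper produces, which is the sequence the precompilation loop added to _dummy_run walks through. The helper is copied from the diff so the snippet runs on its own; request_buckets and the max_num_reqs value of 300 are illustrative only.

def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
    return min(res, upper_limit)


def request_buckets(max_num_reqs):
    # Mirrors the stepping of the warm-up loop: start at the 64 bucket and
    # keep jumping to the next power-of-two bucket, clamped to max_num_reqs.
    buckets, num_reqs = [], _get_padded_num_reqs_with_upper_limit(64, max_num_reqs)
    while True:
        buckets.append(num_reqs)
        if num_reqs >= max_num_reqs:
            return buckets
        num_reqs = _get_padded_num_reqs_with_upper_limit(num_reqs + 1, max_num_reqs)


print(request_buckets(300))  # [64, 128, 256, 300]

Combined with the existing token buckets from _get_padded_token_len, this yields the token_bucket_num times req_bucket_num compilation cache size mentioned in the NOTE added to _dummy_run.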