@@ -98,6 +98,7 @@ def __init__(
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+        self._disable_sampler = envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -684,7 +685,7 @@ def execute_model(
         num_reqs = self.input_batch.num_reqs
 
         # Temporary debug pathway.
-        if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
+        if self._disable_sampler:
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
@@ -693,7 +694,7 @@ def execute_model(
                     inputs_embeds=inputs_embeds,
                 )
             selected_token_ids = self.model.compute_logits_no_sampler(
-                hidden_states, logits_indices, None)
+                hidden_states, logits_indices)
             selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         else:
             # NOTE (NickLucche) here we sync with TPU: sampling params tensors
@@ -899,17 +900,23 @@ def capture_model(self) -> None:
                                        dtype=self._hidden_states_dtype)
             # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
             while True:
+                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
+                            num_reqs_to_sample)
                 indices = torch.zeros(
                     num_reqs_to_sample,
                     dtype=torch.int32,
                     device=device,
                 )
                 xm.mark_step()
-                sampling_meta = TPUSupportedSamplingMetadata.\
-                    from_input_batch(self.input_batch, indices)
-                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
-                            num_reqs_to_sample)
-                out = self.sample_from_hidden(dummy_hidden, sampling_meta)
+                if self._disable_sampler:
+                    # Compile no sampler path for debugging performance
+                    out = self.model.compute_logits_no_sampler(
+                        dummy_hidden, indices)
+                else:
+                    sampling_meta = TPUSupportedSamplingMetadata.\
+                        from_input_batch(self.input_batch, indices)
+                    out = self.model.sample_from_hidden(
+                        dummy_hidden, sampling_meta)
                 out = out.cpu()
                 # Requests can't be more than tokens. But do compile for the
                 # next bigger value in case num_tokens uses bucketed padding.
@@ -1006,6 +1013,16 @@ def sample(
         return out_tokens
 
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def compute_logits_no_sampler(
+            self, hidden_states: torch.Tensor,
+            logits_indices: torch.Tensor) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, None)
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
+
+
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
 
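For context, the debug pathway this diff gates behind `VLLM_TPU_DISABLE_SAMPLER_DEBUG` amounts to plain greedy decoding: `compute_logits_no_sampler` gathers the hidden states at the positions that need a next token, projects them to vocabulary logits, and takes an argmax, skipping the sampler entirely. Below is a minimal standalone sketch of that computation; `greedy_token_ids` is a hypothetical helper (not code from this PR), and the `torch.nn.Linear` head merely stands in for `self.model.compute_logits`.

```python
import torch


def greedy_token_ids(hidden_states: torch.Tensor,
                     logits_indices: torch.Tensor,
                     lm_head: torch.nn.Module) -> torch.Tensor:
    # Keep only the positions that need a next token (one per request).
    hidden_states = hidden_states[logits_indices]  # [num_reqs, hidden]
    # Project to vocabulary logits (stand-in for model.compute_logits).
    logits = lm_head(hidden_states)                # [num_reqs, vocab_size]
    # Greedy decode: argmax over the vocab; keepdim=True yields
    # [num_reqs, 1], the shape execute_model slices with .cpu()[:num_reqs].
    return torch.argmax(logits, dim=-1, keepdim=True)


# Toy usage: 16 padded token positions, 2 live requests.
hidden = torch.randn(16, 512)
indices = torch.tensor([3, 9])                 # positions to decode
head = torch.nn.Linear(512, 1000, bias=False)  # hypothetical LM head
assert greedy_token_ids(hidden, indices, head).shape == (2, 1)
```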