@@ -95,6 +95,7 @@ def __init__(
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+        self._disable_sampler = envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
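A note on the new flag: `envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG` goes through vLLM's environment-variable plumbing, and `__init__` snapshots it into `self._disable_sampler`, so the value is fixed at runner construction. A minimal smoke test, as a sketch (the model name is arbitrary; any TPU-supported checkpoint works, and with the flag set all sampling params are effectively ignored in favor of greedy argmax):

    import os
    # Must be set before the engine (and thus the TPU model runner) is
    # built, since __init__ reads the value once.
    os.environ["VLLM_TPU_DISABLE_SAMPLER_DEBUG"] = "1"

    from vllm import LLM, SamplingParams
    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)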
@@ -639,7 +640,7 @@ def execute_model(
         num_reqs = self.input_batch.num_reqs
 
         # Temporary debug pathway.
-        if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
+        if self._disable_sampler:
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
@@ -648,7 +649,7 @@ def execute_model(
                     inputs_embeds=inputs_embeds,
                 )
             selected_token_ids = self.model.compute_logits_no_sampler(
-                hidden_states, logits_indices, None)
+                hidden_states, logits_indices)
             selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         else:
             # NOTE (NickLucche) here we sync with TPU: sampling params tensors
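For reference, the debug branch amounts to a plain greedy decode with no dependency on per-request sampling tensors. A self-contained sketch of its semantics with toy shapes (`greedy_from_hidden` is a hypothetical stand-in; the real implementation is `compute_logits_no_sampler`, shown in the last hunk below):

    import torch

    def greedy_from_hidden(hidden_states: torch.Tensor,
                           logits_indices: torch.Tensor,
                           lm_head: torch.nn.Linear) -> torch.Tensor:
        # Gather the last-token hidden state of each request, project to
        # vocab logits, take the argmax. With no sampling-params tensors,
        # the compiled graph is independent of per-request sampling state.
        hidden = hidden_states[logits_indices]   # [num_reqs, hidden_size]
        logits = lm_head(hidden)                 # [num_reqs, vocab_size]
        return torch.argmax(logits, dim=-1, keepdim=True)

    lm_head = torch.nn.Linear(16, 32, bias=False)
    hidden_states = torch.randn(8, 16)           # padded token hidden states
    logits_indices = torch.tensor([3, 7])        # last-token index per request
    print(greedy_from_hidden(hidden_states, logits_indices, lm_head).shape)
    # torch.Size([2, 1])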
@@ -857,18 +858,23 @@ def capture_model(self) -> None:
                                    dtype=self._hidden_states_dtype)
         # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
         while True:
+            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
+                        num_reqs_to_sample)
             indices = torch.zeros(
                 num_reqs_to_sample,
                 dtype=torch.int32,
                 device=device,
             )
             xm.mark_step()
-            sampling_meta = TPUSupportedSamplingMetadata.\
-                from_input_batch(self.input_batch, indices)
-            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
-                        num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            if self._disable_sampler:
+                # Compile no sampler path for debugging performance
+                out = self.model.compute_logits_no_sampler(
+                    dummy_hidden, indices)
+            else:
+                sampling_meta = TPUSupportedSamplingMetadata.\
+                    from_input_batch(self.input_batch, indices)
+                out = self.model.sample_from_hidden(
+                    dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
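Both branches have to be traced here because `capture_model` must pre-compile exactly the graph `execute_model` will later run; otherwise the first real batch pays the XLA compile cost at serving time. A rough standalone sketch of the bucket schedule the `[8, 16, .., 128,.., self.max_num_reqs]` comment describes (assuming MIN_NUM_SEQS is 8 and a max of 128 purely for illustration; the real loop also couples `num_reqs_to_sample` to `num_tokens`):

    min_num_seqs, max_num_reqs = 8, 128  # assumed values for illustration
    sizes, n = [], min_num_seqs
    while n < max_num_reqs:
        sizes.append(n)
        n *= 2                           # power-of-two padding buckets
    sizes.append(max_num_reqs)
    print(sizes)  # [8, 16, 32, 64, 128]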
@@ -991,13 +997,10 @@ def compute_logits(self,
 
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits_no_sampler(
-        self,
-        hidden_states: torch.Tensor,
-        logits_indices: torch.Tensor,
-        sampling_metadata,
-    ) -> Optional[torch.Tensor]:
+            self, hidden_states: torch.Tensor,
+            logits_indices: torch.Tensor) -> Optional[torch.Tensor]:
         hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        logits = self.model.compute_logits(hidden_states, None)
         selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
         return selected_token_ids
 
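Since the helper is compiled with `fullgraph=True, dynamic=False`, each distinct input shape gets its own specialized XLA graph, which is precisely what the capture loop above pre-warms. A hedged usage sketch, assuming a runner instance named `runner` with inputs already padded to one of the compiled buckets:

    tokens = runner.model.compute_logits_no_sampler(hidden_states,
                                                    logits_indices)
    tokens = tokens.cpu()[:num_reqs]  # drop padding rows, as execute_model does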