@@ -684,34 +684,26 @@ def execute_model(
         inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
 
-        # Temporary debug pathway.
+        with set_forward_context(attn_metadata, self.vllm_config):
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=self.position_ids,
+                kv_caches=self.kv_caches,
+                inputs_embeds=inputs_embeds,
+            )
+        # Temporary debug pathway for sampling.
         if self._disable_sampler:
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=self.position_ids,
-                    kv_caches=self.kv_caches,
-                    inputs_embeds=inputs_embeds,
-                )
             selected_token_ids = self.model.compute_logits_no_sampler(
                 hidden_states, logits_indices)
-            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         else:
             # NOTE (NickLucche) here we sync with TPU: sampling params tensors
             # are copied to device in chunks of pre-compiled padded shape to
             # avoid recompilations.
             tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
                 from_input_batch(self.input_batch, logits_indices)
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=self.position_ids,
-                    kv_caches=self.kv_caches,
-                    inputs_embeds=inputs_embeds,
-                )
-            selected_token_ids = self.model.sample_from_hidden(
-                hidden_states, tpu_sampling_metadata)
-            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+            selected_token_ids = self.model.sample_from_hidden(
+                hidden_states, tpu_sampling_metadata)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
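The NOTE in the else branch refers to avoiding XLA recompilation: per-request sampling-parameter tensors are padded up to one of a small set of pre-compiled batch shapes before being copied to the TPU, so a new batch size never forces a fresh compile. A minimal sketch of that idea, assuming a hypothetical helper and bucket list (this is not the actual TPUSupportedSamplingMetadata.from_input_batch implementation):

    import torch

    # Hypothetical padded batch sizes assumed to have been pre-compiled by XLA.
    PRECOMPILED_BATCH_SIZES = (8, 16, 32, 64)

    def pad_to_precompiled_shape(params: torch.Tensor,
                                 pad_value: float = 0.0) -> torch.Tensor:
        """Pad a per-request sampling-params tensor to the smallest
        pre-compiled batch size, so the device graph sees a fixed shape."""
        num_reqs = params.shape[0]
        target = next(s for s in PRECOMPILED_BATCH_SIZES if s >= num_reqs)
        padded = torch.full((target, *params.shape[1:]),
                            pad_value,
                            dtype=params.dtype)
        padded[:num_reqs] = params
        return padded  # copy this fixed-shape tensor to device, slice results later

The same reasoning explains the shared `.cpu()[:num_reqs]` slice after the branch: outputs come back in the padded shape, and only the first num_reqs rows correspond to real requests.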