@@ -639,34 +639,26 @@ def execute_model(
         inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
 
-        # Temporary debug pathway.
+        with set_forward_context(attn_metadata, self.vllm_config):
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=self.position_ids,
+                kv_caches=self.kv_caches,
+                inputs_embeds=inputs_embeds,
+            )
+        # Temporary debug pathway for sampling.
         if self._disable_sampler:
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=self.position_ids,
-                    kv_caches=self.kv_caches,
-                    inputs_embeds=inputs_embeds,
-                )
             selected_token_ids = self.model.compute_logits_no_sampler(
                 hidden_states, logits_indices)
-            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         else:
             # NOTE (NickLucche) here we sync with TPU: sampling params tensors
             # are copied to device in chunks of pre-compiled padded shape to
             # avoid recompilations.
             tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
                 from_input_batch(self.input_batch, logits_indices)
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=self.position_ids,
-                    kv_caches=self.kv_caches,
-                    inputs_embeds=inputs_embeds,
-                )
-            selected_token_ids = self.model.sample_from_hidden(
-                hidden_states, tpu_sampling_metadata)
-            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+            selected_token_ids = self.model.sample_from_hidden(
+                hidden_states, tpu_sampling_metadata)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes