@@ -688,20 +688,19 @@ def execute_model(
        hidden_states = self.model(
            input_ids=input_ids,
            positions=self.position_ids,
-            kv_caches=self.kv_caches,
            inputs_embeds=inputs_embeds,
        )
        # Temporary debug pathway for sampling.
        if self._disable_sampler:
-            selected_token_ids = self.model.compute_logits_no_sampler(
+            selected_token_ids = self.compute_logits_no_sampler(
                hidden_states, logits_indices)
        else:
            # NOTE (NickLucche) here we sync with TPU: sampling params tensors
            # are copied to device in chunks of pre-compiled padded shape to
            # avoid recompilations.
            tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
                from_input_batch(self.input_batch, logits_indices)
-            selected_token_ids = self.model.sample_from_hidden(
+            selected_token_ids = self.sample_from_hidden(
                hidden_states, tpu_sampling_metadata)
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
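For context on the NOTE in this hunk: XLA recompiles whenever a traced graph sees a new tensor shape, so host-side sampling params are padded to one of a few pre-compiled sizes before the device copy. A minimal sketch of that idea, with illustrative bucket sizes and a hypothetical `pad_to_precompiled` helper (not vLLM's API):

```python
import torch

# Illustrative bucket sizes; the real runner derives these from its
# pre-compiled padded shapes.
PADDED_BUCKETS = [16, 32, 64, 128]

def pad_to_precompiled(params: torch.Tensor,
                       device: torch.device) -> torch.Tensor:
    """Pad a 1D sampling-params tensor (e.g. per-request temperatures)
    to the next pre-compiled bucket, so the device copy and every
    downstream op keep a fixed shape and XLA never recompiles."""
    n = params.shape[0]
    bucket = next(b for b in PADDED_BUCKETS if b >= n)
    padded = params.new_zeros(bucket)
    padded[:n] = params
    return padded.to(device)

# Usage: temperatures for 20 requests land in the 32-wide bucket.
temps = pad_to_precompiled(torch.rand(20), torch.device("cpu"))
assert temps.shape == (32,)
```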
@@ -902,13 +901,11 @@ def capture_model(self) -> None:
            xm.mark_step()
        if self._disable_sampler:
            # Compile no sampler path for debugging performance
-            out = self.model.compute_logits_no_sampler(
-                dummy_hidden, indices)
+            out = self.compute_logits_no_sampler(dummy_hidden, indices)
        else:
            sampling_meta = TPUSupportedSamplingMetadata.\
                from_input_batch(self.input_batch, indices)
-            out = self.model.sample_from_hidden(
-                dummy_hidden, sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
        out = out.cpu()
        # Requests can't be more than tokens. But do compile for the
        # next bigger value in case num_tokens uses bucketed padding.
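`xm.mark_step()` is what makes each warm-up iteration do real work: torch_xla records ops lazily, and the step boundary hands the accumulated graph to XLA to compile and execute. A standalone illustration of that behavior (assumes a torch_xla install; runs on XLA:CPU without a TPU):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # TPU if available, otherwise XLA:CPU
x = torch.ones(4, 4, device=device)
y = x @ x + 1             # recorded lazily; nothing has executed yet
xm.mark_step()            # cut the graph: XLA compiles and runs it here
print(y.cpu())            # read back the materialized result
```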
@@ -980,9 +977,9 @@ def sample_from_hidden(
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """
-        Sample with xla-friendly function. This function is to be traced
-        separately for lighter compilation overhead.
-        """
+        Sample with xla-friendly function. This function is to be traced
+        separately for lighter compilation overhead.
+        """
        # Tensor `sample_hidden_states` is of fixed pre-compiled size.
        sample_hidden_states = \
            hidden_states[sampling_metadata.indices_do_sample]
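The gather on `indices_do_sample` follows the same static-shape discipline: the indices are padded to a pre-compiled length, so `hidden_states[...]` always yields the same shape regardless of how many requests actually need a token sampled. A toy sketch with made-up sizes:

```python
import torch

hidden_states = torch.randn(64, 8)  # [padded_num_tokens, hidden_size]

# Only 3 requests really sample; pad the index list to the compiled
# size of 8 (repeating index 0), then discard padding rows on the host.
indices_do_sample = torch.tensor([5, 17, 42, 0, 0, 0, 0, 0])
sample_hidden = hidden_states[indices_do_sample]  # fixed [8, 8] shape
valid = sample_hidden[:3]                         # host-side slice
```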
@@ -1004,7 +1001,6 @@ def sample(
            sample(logits, sampling_metadata).sampled_token_ids)
        return out_tokens

-
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def compute_logits_no_sampler(
            self, hidden_states: torch.Tensor,
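On the decorator: `fullgraph=True` makes compilation fail loudly on a graph break instead of silently splitting the graph, and `dynamic=False` specializes on static shapes, which is exactly what the padded pre-compiled buckets guarantee. A sketch applying the same flags to a stand-in for the greedy debug path (hypothetical function, not the PR's code; the `openxla` backend needs torch_xla installed, so swap in `backend="inductor"` to try it without one):

```python
import torch

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def greedy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Pick argmax token ids; a toy stand-in for the no-sampler path."""
    return torch.argmax(logits, dim=-1, keepdim=True)
```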