File tree 1 file changed +4
-2
lines changed
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -787,7 +787,8 @@ def execute_model(
787
787
logits = self .structured_decode (require_struct_decoding ,
788
788
grammar_bitmask_padded , logits ,
789
789
arange )
790
- selected_token_ids , logprobs = self .sample_from_logits (logits , tpu_sampling_metadata )
790
+ selected_token_ids , logprobs = self .sample_from_logits (
791
+ logits , tpu_sampling_metadata )
791
792
# Remove padding on cpu and keep dynamic op outside of xla graph.
792
793
selected_token_ids = selected_token_ids .cpu ()[:num_reqs ]
793
794
logprobs_lists = logprobs .tolists () \
@@ -1254,7 +1255,8 @@ def compute_logits(self,
1254
1255
@torch .compile (backend = "openxla" , fullgraph = True , dynamic = False )
1255
1256
def sample_from_logits (
1256
1257
self , logits : torch .Tensor ,
1257
- sampling_metadata : TPUSupportedSamplingMetadata ) -> torch .Tensor :
1258
+ sampling_metadata : TPUSupportedSamplingMetadata ) -> \
1259
+ tuple [torch .Tensor , Optional [LogprobsTensors ]]:
1258
1260
"""
1259
1261
Sample with xla-friendly function. This function is to be traced
1260
1262
separately from `forward` for lighter compilation overhead.
You can’t perform that action at this time.
0 commit comments