File tree 1 file changed +4
-2
lines changed
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -786,7 +786,8 @@ def execute_model(
786
786
logits = self .structured_decode (require_struct_decoding ,
787
787
grammar_bitmask_padded , logits ,
788
788
arange )
789
- selected_token_ids , logprobs = self .sample_from_logits (logits , tpu_sampling_metadata )
789
+ selected_token_ids , logprobs = self .sample_from_logits (
790
+ logits , tpu_sampling_metadata )
790
791
# Remove padding on cpu and keep dynamic op outside of xla graph.
791
792
selected_token_ids = selected_token_ids .cpu ()[:num_reqs ]
792
793
logprobs_lists = logprobs .tolists () \
@@ -1253,7 +1254,8 @@ def compute_logits(self,
1253
1254
@torch .compile (backend = "openxla" , fullgraph = True , dynamic = False )
1254
1255
def sample_from_logits (
1255
1256
self , logits : torch .Tensor ,
1256
- sampling_metadata : TPUSupportedSamplingMetadata ) -> torch .Tensor :
1257
+ sampling_metadata : TPUSupportedSamplingMetadata ) -> \
1258
+ tuple [torch .Tensor , Optional [LogprobsTensors ]]:
1257
1259
"""
1258
1260
Sample with xla-friendly function. This function is to be traced
1259
1261
separately from `forward` for lighter compilation overhead.
You can’t perform that action at this time.
0 commit comments