@@ -791,8 +791,18 @@ def execute_model(
                                             arange)
         selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
+
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        logprobs_lists = logprobs.tolists() \
+            if tpu_sampling_metadata.logprobs else None
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
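The NOTE above is the key invariant behind this hunk: sampling must not mutate `logits`, so the same tensor can later feed the top-k logprob gather. A minimal sketch of that constraint in plain PyTorch (the function name and shapes are illustrative only, not the vLLM sampler API):

import torch

def sample_out_of_place(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Scale a copy; the caller's `logits` tensor is left untouched.
    scaled = logits / temperature
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(4, 32)
before = logits.clone()
sampled = sample_out_of_place(logits, temperature=0.7)

# Because sampling did not mutate `logits`, the raw distribution is still
# available for top-k logprob reporting afterwards.
assert torch.equal(logits, before)
raw_logprobs = torch.log_softmax(logits, dim=-1)

The dynamic parts, trimming the padded batch to `num_reqs` and converting to Python lists via `tolists()`, stay on the host so the XLA graph only ever sees fixed shapes.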
@@ -862,7 +872,7 @@ def execute_model(
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
@@ -1121,6 +1131,22 @@ def _precompile_sample_from_logits(self) -> None:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
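`_precompile_gather_logprobs` follows the same warm-up pattern as the other `_precompile_*` helpers: call the traced function once per padded batch size so every shape it can see at serving time is already compiled. A rough sketch of that idea, with hypothetical padding buckets and a plain-PyTorch stand-in for the compiled graph (the real code traces through torch.compile/openxla and waits on `xm.wait_device_ops()`):

import time
import torch

num_reqs_paddings = [8, 16, 32, 64]      # hypothetical batch-size buckets
vocab_size = 32000

def compiled_fn(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Stand-in for the traced gather_logprobs graph.
    return torch.log_softmax(logits, dim=-1).gather(-1, tokens)

start = time.perf_counter()
for num_reqs in num_reqs_paddings:
    dummy_logits = torch.zeros((num_reqs, vocab_size))
    dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64)
    compiled_fn(dummy_logits, dummy_tokens)   # one graph per padded shape
print(f"warm-up finished in {time.perf_counter() - start:.2f} [secs]")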
@@ -1131,6 +1157,7 @@ def capture_model(self) -> None:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,13 +1281,31 @@ def compute_logits(self,
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with xla-friendly function. This function is to be traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
             out_tokens = self.sampler(logits,
                                       sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,
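For intuition, the fixed-K gather amounts to a log-softmax, a top-`max_logprobs` selection, and a lookup of the sampled token's logprob and rank, all at shapes that depend only on `max_logprobs`, never on what an individual request asked for. An approximate sketch in plain PyTorch (this is not the `Sampler.compute_logprobs`/`gather_logprobs` API, just the shape of the computation):

import torch

def gather_logprobs_sketch(logits: torch.Tensor, sampled: torch.Tensor,
                           max_logprobs: int):
    # Shapes depend only on max_logprobs, never on per-request settings.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    top_vals, top_ids = torch.topk(logprobs, max_logprobs, dim=-1)
    sampled_lp = logprobs.gather(-1, sampled)
    # Rank here = number of tokens with a strictly higher logprob.
    ranks = (logprobs > sampled_lp).sum(dim=-1)
    return top_ids, top_vals, sampled_lp.squeeze(-1), ranks

logits = torch.randn(4, 32)
sampled = torch.argmax(logits, dim=-1, keepdim=True)
top_ids, top_vals, sampled_lp, ranks = gather_logprobs_sketch(logits, sampled, 5)
assert (ranks == 0).all()   # greedy picks have no higher-probability token

Trimming each request down to the number of logprobs it actually asked for then happens later on the CPU, outside the compiled graph.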