@@ -786,8 +786,16 @@ def execute_model(
             logits = self.structured_decode(require_struct_decoding,
                                             grammar_bitmask_padded, logits,
                                             arange)
-        selected_token_ids, logprobs = self.sample_from_logits(
-            logits, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_logits(logits,
+                                                     tpu_sampling_metadata)
+
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         logprobs_lists = logprobs.tolists() \
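
The NOTE above only holds if `sample_from_logits` never mutates `logits`: the
top-k logprobs are gathered from the raw logits after sampling has already run
on a (possibly penalized or temperature-scaled) copy. A minimal standalone
sketch of the pitfall this guards against (illustrative only, not vLLM code):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5]])

    # BAD: logits.div_(0.5) would scale in-place and corrupt the tensor
    # that is reused below for the top-k logprobs.
    scaled = logits / 0.5  # GOOD: out-of-place, the original logits survive
    sampled = torch.argmax(scaled, dim=-1, keepdim=True)

    # Logprobs are gathered from the ORIGINAL, unscaled logits.
    logprobs = torch.log_softmax(logits, dim=-1)
    sampled_logprob = logprobs.gather(-1, sampled)
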
@@ -894,7 +902,7 @@ def load_model(self) -> None:
         xm.mark_step()
         xm.wait_device_ops()
         self.model = model
-        self.sampler = TPUSampler(self.model_config.max_logprobs)
+        self.sampler = TPUSampler()
 
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int) -> None:
@@ -1105,23 +1113,37 @@ def _precompile_sample_from_logits(self) -> None:
             # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
-                for top_logprobs in [False, True]:
-                    sampling_metadata = (
-                        TPUSupportedSamplingMetadata.from_input_batch(
-                            self.input_batch,
-                            num_reqs,
-                            self.device,
-                            generate_params_if_all_greedy,
-                        ))
-                    sampling_metadata.logprobs = top_logprobs
-                    sampling_metadata.all_greedy = all_greedy
-                    self.sample_from_logits(dummy_logits, sampling_metadata)
+                sampling_metadata = (
+                    TPUSupportedSamplingMetadata.from_input_batch(
+                        self.input_batch,
+                        num_reqs,
+                        self.device,
+                        generate_params_if_all_greedy,
+                    ))
+                sampling_metadata.all_greedy = all_greedy
+                self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
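
Like the other precompile helpers, `_precompile_gather_logprobs` compiles one
XLA graph per padded batch size, so at runtime every batch is padded up to a
known bucket and never triggers recompilation. A toy sketch of that bucketing
idea (the helper name and bucket values are hypothetical, not vLLM's actual
padding logic):

    def pick_padding(num_reqs: int, paddings: list[int]) -> int:
        # paddings is sorted ascending, e.g. [8, 16, 32, 64]; pad the batch
        # up to the smallest bucket that fits so only these shapes ever
        # reach the compiled graphs.
        for bucket in paddings:
            if num_reqs <= bucket:
                return bucket
        raise ValueError(f"batch of {num_reqs} exceeds the largest bucket")
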
@@ -1132,6 +1154,7 @@ def capture_model(self) -> None:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,29 +1277,31 @@ def compute_logits(self,
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_logits(
             self, logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata) -> \
-                tuple[torch.Tensor, Optional[LogprobsTensors]]:
+            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
         separately from `forward` for lighter compilation overhead.
-        Optionally (in a separate graph) returns top-logprobs too, by gathering
-        a fixed maximum number of logprobs for the whole batch, 20 by default.
         """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
-            if sampling_metadata.logprobs:
-                logprobs = self.sampler.compute_logprobs(logits)
-                logprobs_tensors = self.sampler.gather_logprobs(
-                    logprobs,
-                    self.model_config.max_logprobs,
-                    token_ids=out_tokens.squeeze(-1))
-            else:
-                logprobs_tensors = None
         else:
             sampler_out = self.sampler(logits, sampling_metadata)
             out_tokens = sampler_out.sampled_token_ids
-            logprobs_tensors = sampler_out.logprobs_tensors
-        return out_tokens, logprobs_tensors
+        return out_tokens
+
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
 
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
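
For reference, a self-contained approximation of what the new
`gather_logprobs` computes; the TPUSampler internals are paraphrased here, so
treat this as a sketch rather than the actual implementation:

    import torch

    def gather_topk_logprobs(logits: torch.Tensor, sampled: torch.Tensor,
                             k: int):
        # Log-softmax in float32 for numerical stability.
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        # A fixed k (the server-wide max_logprobs) keeps the traced shape
        # static; each request's actual top-n is sliced out later on CPU.
        topk_vals, topk_ids = logprobs.topk(k, dim=-1)
        # Logprob of the token that was actually sampled, shape (reqs, 1).
        sampled_logprob = logprobs.gather(-1, sampled)
        return topk_vals, topk_ids, sampled_logprob
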