@@ -36,7 +36,7 @@
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
-from vllm.v1.sample.tpu.sampler import MAX_TOP_LOGPROBS_TO_GATHER, Sampler as TPUSampler
+from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -790,7 +790,7 @@ def execute_model(
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         logprobs_lists = logprobs.tolists() \
-            if tpu_sampling_metadata.max_num_logprobs else None
+            if tpu_sampling_metadata.logprobs else None
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
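The `tolists()` conversion, like the padding removal above it, is deliberately done after `.cpu()`: the dynamically sized slice and the Python list conversion never enter the XLA graph, which only ever sees padded, static shapes. A toy illustration of that host/device boundary (the shapes and names below are made up for the example):

```python
import torch

PADDED_REQS = 8   # device-side batch is padded to a static size
num_reqs = 3      # true batch size, known only on the host

device_tokens = torch.zeros(PADDED_REQS, dtype=torch.long)  # stand-in for an XLA tensor
# Move to host first, then do the dynamic slice and list conversion there,
# so the traced graph keeps a fixed shape regardless of num_reqs.
host_tokens = device_tokens.cpu()[:num_reqs]
token_list = host_tokens.tolist()
```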
@@ -893,7 +893,7 @@ def load_model(self) -> None:
         xm.mark_step()
         xm.wait_device_ops()
         self.model = model
-        self.sampler = TPUSampler()
+        self.sampler = TPUSampler(self.model_config.max_logprobs)
 
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int) -> None:
@@ -1104,16 +1104,17 @@ def _precompile_sample_from_logits(self) -> None:
             # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
-                sampling_metadata = (
-                    TPUSupportedSamplingMetadata.from_input_batch(
-                        self.input_batch,
-                        num_reqs,
-                        self.device,
-                        generate_params_if_all_greedy,
-                    ))
-                print("COMPILING", sampling_metadata.max_num_logprobs)
-                sampling_metadata.all_greedy = all_greedy
-                self.sample_from_logits(dummy_logits, sampling_metadata)
+                for top_logprobs in [False, True]:
+                    sampling_metadata = (
+                        TPUSupportedSamplingMetadata.from_input_batch(
+                            self.input_batch,
+                            num_reqs,
+                            self.device,
+                            generate_params_if_all_greedy,
+                        ))
+                    sampling_metadata.logprobs = top_logprobs
+                    sampling_metadata.all_greedy = all_greedy
+                    self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info(" -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
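The new inner loop exists because `sample_from_logits` now branches on `sampling_metadata.logprobs` in host Python, and each branch traces to a different XLA graph; warming all four `(all_greedy, logprobs)` combinations here keeps compilation out of the serving path. A minimal sketch of the pattern, detached from the runner (`precompile_variant` is a hypothetical stand-in, not a vLLM API):

```python
from itertools import product

def precompile_variant(all_greedy: bool, want_logprobs: bool) -> None:
    # Stand-in for self.sample_from_logits(dummy_logits, sampling_metadata):
    # tracing once per flag combination caches one XLA graph per Python branch.
    ...

# Every boolean read on the host (rather than inside the traced tensor ops)
# multiplies the number of graphs, so all combinations are compiled up front.
for all_greedy, want_logprobs in product([False, True], repeat=2):
    precompile_variant(all_greedy, want_logprobs)
```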
@@ -1256,14 +1257,19 @@ def sample_from_logits(
         """
         Sample with xla-friendly function. This function is to be traced
         separately from `forward` for lighter compilation overhead.
+        Optionally (in a separate graph) returns top-logprobs too, by gathering
+        a fixed maximum number of logprobs for the whole batch, 20 by default.
         """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
-            # TODO skip if not needed and compile for it
-            logprobs = self.sampler.compute_logprobs(logits)
-            logprobs_tensors = self.sampler.gather_logprobs(logprobs,
-                MAX_TOP_LOGPROBS_TO_GATHER,
-                token_ids=out_tokens.squeeze(-1))
+            if sampling_metadata.logprobs:
+                logprobs = self.sampler.compute_logprobs(logits)
+                logprobs_tensors = self.sampler.gather_logprobs(
+                    logprobs,
+                    self.model_config.max_logprobs,
+                    token_ids=out_tokens.squeeze(-1))
+            else:
+                logprobs_tensors = None
         else:
             sampler_out = self.sampler(logits, sampling_metadata)
             out_tokens = sampler_out.sampled_token_ids
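On the greedy path, `compute_logprobs` plus `gather_logprobs` boil down to a log-softmax followed by a fixed-width top-k and a gather of the chosen token's logprob, all with static shapes so the graph does not depend on how many logprobs each request asked for. A rough, self-contained approximation of that step in plain PyTorch (not the actual `TPUSampler` methods):

```python
import torch

def gather_top_logprobs(logits: torch.Tensor, sampled: torch.Tensor,
                        max_logprobs: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Approximate the greedy-path logprob gathering with static shapes.

    logits:  [num_reqs, vocab_size]
    sampled: [num_reqs] token ids picked by argmax
    Returns (token_ids, logprobs), each [num_reqs, max_logprobs + 1], with the
    sampled token in column 0 and the batch-wide top-k after it.
    """
    logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
    topk_vals, topk_ids = logprobs.topk(max_logprobs, dim=-1)
    sampled_lp = logprobs.gather(-1, sampled.unsqueeze(-1))
    return (torch.cat([sampled.unsqueeze(-1), topk_ids], dim=-1),
            torch.cat([sampled_lp, topk_vals], dim=-1))

# Per-request `logprobs=k` limits can then be applied on the host from these
# fixed-width tensors, after they leave the device.
```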