Commit 632c82d

separate logprobs graph
Signed-off-by: NickLucche <[email protected]>
1 parent ee18eb7 commit 632c82d

File tree

3 files changed: +39 -27 lines

vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/sampler.py
vllm/v1/worker/tpu_model_runner.py


vllm/v1/sample/tpu/metadata.py

Lines changed: 6 additions & 4 deletions

@@ -32,8 +32,8 @@ class TPUSupportedSamplingMetadata:
     all_greedy: bool = True

     # Maximum number of top logprobs requested in current batch.
-    # TODO use constant from sampler.py OR a bool
-    max_num_logprobs: Optional[int] = 24
+    # TODO specify why bool
+    logprobs: bool = False

     # TODO No penalties for now
     no_penalties: bool = True
@@ -85,10 +85,12 @@ def from_input_batch(
         we want to pre-compile a graph with sampling parameters, even if
         they are not strictly needed for greedy decoding.
         """
+        needs_logprobs = input_batch.max_num_logprobs>0 if \
+            input_batch.max_num_logprobs else False
         # Early return to avoid unnecessary cpu to tpu copy
         if (input_batch.all_greedy is True
                 and generate_params_if_all_greedy is False):
-            return cls(all_greedy=True)
+            return cls(all_greedy=True, logprobs=needs_logprobs)

         num_reqs = input_batch.num_reqs

@@ -117,4 +119,4 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
-            max_num_logprobs=input_batch.max_num_logprobs)
+            logprobs=needs_logprobs)

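For reference, a minimal standalone sketch (not part of the diff) of how the new boolean flag is derived: a max_num_logprobs of None or 0 means no request in the batch asked for logprobs, so the metadata only needs to record whether the logprobs graph should be used at all. FakeInputBatch below is a hypothetical stand-in for vLLM's InputBatch, with just the one field assumed:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeInputBatch:
    # Hypothetical stand-in for InputBatch; only the field used here.
    max_num_logprobs: Optional[int] = None

def needs_logprobs(batch: FakeInputBatch) -> bool:
    # Mirrors the patched expression: None -> False, 0 -> False, >0 -> True.
    return batch.max_num_logprobs > 0 if batch.max_num_logprobs else False

assert needs_logprobs(FakeInputBatch(None)) is False
assert needs_logprobs(FakeInputBatch(0)) is False
assert needs_logprobs(FakeInputBatch(5)) is True
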
vllm/v1/sample/tpu/sampler.py

Lines changed: 9 additions & 5 deletions

@@ -9,14 +9,15 @@
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata

 _SAMPLING_EPS = 1e-5
-MAX_TOP_LOGPROBS_TO_GATHER = 24


 class Sampler(nn.Module):

-    def __init__(self):
+    def __init__(self, max_logprobs: int):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
+        # Gather a fixed amount of top logprobs. Defaults to 20.
+        self.max_logprobs = max_logprobs

     def forward(
         self,
@@ -37,9 +38,12 @@ def forward(
         # Gather the top_logprobs with corresponding tokens. Use a fixed number
         # of logprobs as an alternative to having multiple pre-compiled graphs.
         # Select the logprobs actually demanded by each request on CPU.
-        logprobs_tensors = self.gather_logprobs(raw_logprobs,
-                                                MAX_TOP_LOGPROBS_TO_GATHER,
-                                                token_ids=sampled)
+        if sampling_metadata.logprobs:
+            logprobs_tensors = self.gather_logprobs(raw_logprobs,
+                                                    self.max_logprobs,
+                                                    token_ids=sampled)
+        else:
+            logprobs_tensors = None

         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)

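A rough sketch of why the sampler now takes a fixed max_logprobs instead of a per-batch value: gathering a constant top-k keeps the output shapes static, so a single compiled graph serves every batch that requests logprobs, and trimming to each request's own k can happen later on CPU. The helper below is a simplified illustration with plain torch, not the actual gather_logprobs implementation:

import torch

def gather_top_logprobs(logprobs: torch.Tensor, max_logprobs: int,
                        token_ids: torch.Tensor):
    # Fixed-size top-k: output shapes depend only on max_logprobs, not on
    # what each individual request asked for, so the graph is reusable.
    topk_logprobs, topk_ids = torch.topk(logprobs, max_logprobs, dim=-1)
    # Logprob of the token that was actually sampled, one per sequence.
    sampled_logprob = logprobs.gather(-1, token_ids.unsqueeze(-1))
    return topk_logprobs, topk_ids, sampled_logprob

# Example usage with a dummy batch of 4 sequences and a 32k vocab.
logits = torch.randn(4, 32000)
logprobs = torch.log_softmax(logits, dim=-1)
sampled = logprobs.argmax(dim=-1)
vals, ids, sampled_lp = gather_top_logprobs(logprobs, 20, sampled)
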
vllm/v1/worker/tpu_model_runner.py

Lines changed: 24 additions & 18 deletions

@@ -36,7 +36,7 @@
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
-from vllm.v1.sample.tpu.sampler import MAX_TOP_LOGPROBS_TO_GATHER, Sampler as TPUSampler
+from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -790,7 +790,7 @@ def execute_model(
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         logprobs_lists = logprobs.tolists() \
-            if tpu_sampling_metadata.max_num_logprobs else None
+            if tpu_sampling_metadata.logprobs else None

         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
@@ -893,7 +893,7 @@ def load_model(self) -> None:
             xm.mark_step()
             xm.wait_device_ops()
             self.model = model
-            self.sampler = TPUSampler()
+            self.sampler = TPUSampler(self.model_config.max_logprobs)

     @torch.no_grad()
     def _dummy_run(self, num_tokens: int) -> None:
@@ -1104,16 +1104,17 @@ def _precompile_sample_from_logits(self) -> None:
             # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
-                sampling_metadata = (
-                    TPUSupportedSamplingMetadata.from_input_batch(
-                        self.input_batch,
-                        num_reqs,
-                        self.device,
-                        generate_params_if_all_greedy,
-                    ))
-                print("COMPILING", sampling_metadata.max_num_logprobs)
-                sampling_metadata.all_greedy = all_greedy
-                self.sample_from_logits(dummy_logits, sampling_metadata)
+                for top_logprobs in [False, True]:
+                    sampling_metadata = (
+                        TPUSupportedSamplingMetadata.from_input_batch(
+                            self.input_batch,
+                            num_reqs,
+                            self.device,
+                            generate_params_if_all_greedy,
+                        ))
+                    sampling_metadata.logprobs = top_logprobs
+                    sampling_metadata.all_greedy = all_greedy
+                    self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info(" -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1256,14 +1257,19 @@ def sample_from_logits(
         """
         Sample with xla-friendly function. This function is to be traced
         separately from `forward` for lighter compilation overhead.
+        Optionally (in a separate graph) returns top-logprobs too, by gathering
+        a fixed maximum number of logprobs for the whole batch, 20 by default.
         """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
-            # TODO skip if not needed and compile for it
-            logprobs = self.sampler.compute_logprobs(logits)
-            logprobs_tensors = self.sampler.gather_logprobs(logprobs,
-                MAX_TOP_LOGPROBS_TO_GATHER,
-                token_ids=out_tokens.squeeze(-1))
+            if sampling_metadata.logprobs:
+                logprobs = self.sampler.compute_logprobs(logits)
+                logprobs_tensors = self.sampler.gather_logprobs(
+                    logprobs,
+                    self.model_config.max_logprobs,
+                    token_ids=out_tokens.squeeze(-1))
+            else:
+                logprobs_tensors = None
         else:
             sampler_out = self.sampler(logits, sampling_metadata)
             out_tokens = sampler_out.sampled_token_ids

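Taken together, the runner now pre-traces sampling for every combination of the two static flags (all_greedy, logprobs), so a request that wants logprobs hits a dedicated graph rather than every batch paying for the gather. A toy sketch of that enumeration idea, with a dummy trace function standing in for sample_from_logits and hypothetical padding buckets:

import itertools

def precompile(trace_fn, num_reqs_paddings):
    # One trace per (padded batch size, all_greedy, logprobs) combination;
    # the booleans are baked into each graph, so four variants per padding.
    for num_reqs in num_reqs_paddings:
        for all_greedy, want_logprobs in itertools.product((False, True),
                                                           repeat=2):
            trace_fn(num_reqs, all_greedy=all_greedy, logprobs=want_logprobs)

compiled = []
precompile(lambda n, **kw: compiled.append((n, kw)), [8, 16])
assert len(compiled) == 2 * 2 * 2  # two paddings x four flag combinations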