Commit e7d1414

refactor logprobs into separate torch.compiled graph

Signed-off-by: NickLucche <[email protected]>

1 parent: d14da5d

File tree: 2 files changed (+55 -54 lines changed)


vllm/v1/sample/tpu/sampler.py (+2 -26)
@@ -13,58 +13,34 @@
 
 class Sampler(nn.Module):
 
-    def __init__(self, max_logprobs: int):
+    def __init__(self):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
-        # Gather a fixed amount of top logprobs. Defaults to 20.
-        self.max_logprobs = max_logprobs
 
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> SamplerOutput:
-        # NOTE(woosuk): Use the original logits (before any penalties or
-        # temperature scaling) for the top-k logprobs.
-        # This is different from the V0 sampler, which uses the logits that
-        # is used for sampling (after penalties and temperature scaling).
-        if sampling_metadata.logprobs:
-            raw_logprobs = self.compute_logprobs(logits)
-
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
 
-        # Gather the top_logprobs with corresponding tokens. Use a fixed number
-        # of logprobs as an alternative to having multiple pre-compiled graphs.
-        # Select the logprobs actually demanded by each request on CPU.
-        if sampling_metadata.logprobs:
-            logprobs_tensors = self.gather_logprobs(raw_logprobs,
-                                                    self.max_logprobs,
-                                                    token_ids=sampled)
-        else:
-            logprobs_tensors = None
-
-        # Use int32 to reduce the tensor size.
-        sampled = sampled.to(torch.int32)
-
         # These are TPU tensors.
         sampler_output = SamplerOutput(
             # The sampled tokens are expanded to 2D tensor with shape
             # [num_requests, 1], where each row represents one generated
             # token per request.
             sampled_token_ids=sampled.unsqueeze(-1),
-            logprobs_tensors=logprobs_tensors,
-        )
+            logprobs_tensors=None)
         return sampler_output
 
     def apply_temperature(
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use in-place division to avoid creating a new tensor.
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
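The constructor change above ripples to callers: code that built the TPU sampler with a max_logprobs argument now constructs it bare, and the sampler itself never returns logprobs. A minimal usage sketch, assuming the module path implied by this file and that the usual vLLM tensors are in scope:

# Hedged sketch: construction before vs. after this commit.
from vllm.v1.sample.tpu.sampler import Sampler

# sampler = Sampler(max_logprobs=20)   # old signature, removed here
sampler = Sampler()                     # new signature
# sampler(logits, sampling_metadata) now always returns a SamplerOutput
# whose logprobs_tensors is None; logprob gathering moved to the runner.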

vllm/v1/worker/tpu_model_runner.py (+53 -28)
@@ -786,8 +786,16 @@ def execute_model(
             logits = self.structured_decode(require_struct_decoding,
                                             grammar_bitmask_padded, logits,
                                             arange)
-        selected_token_ids, logprobs = self.sample_from_logits(
-            logits, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_logits(logits,
+                                                     tpu_sampling_metadata)
+
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         logprobs_lists = logprobs.tolists() \
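The NOTE in this hunk captures the core idea: sampling and logprob gathering become two independently compiled graphs, and the per-batch decision of whether logprobs are needed stays in eager Python on the host. A small, framework-agnostic sketch of that pattern (plain torch.compile instead of the openxla backend used in the diff, purely for illustration):

import torch

@torch.compile(dynamic=False)
def sample(logits: torch.Tensor) -> torch.Tensor:
    # graph 1: always executed
    return torch.argmax(logits, dim=-1, keepdim=True)

@torch.compile(dynamic=False)
def gather(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # graph 2: only executed when logprobs were requested
    return torch.log_softmax(logits, dim=-1).gather(-1, tokens)

logits = torch.randn(8, 128)
tokens = sample(logits)
need_logprobs = True  # decided per batch, outside any compiled graph
logprobs = gather(logits, tokens) if need_logprobs else None

Because the branch never enters a traced graph, toggling logprobs between batches does not force a second compiled variant of the sampling graph.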
@@ -894,7 +902,7 @@ def load_model(self) -> None:
         xm.mark_step()
         xm.wait_device_ops()
         self.model = model
-        self.sampler = TPUSampler(self.model_config.max_logprobs)
+        self.sampler = TPUSampler()
 
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int) -> None:
@@ -1105,23 +1113,37 @@ def _precompile_sample_from_logits(self) -> None:
             # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
-                for top_logprobs in [False, True]:
-                    sampling_metadata = (
-                        TPUSupportedSamplingMetadata.from_input_batch(
-                            self.input_batch,
-                            num_reqs,
-                            self.device,
-                            generate_params_if_all_greedy,
-                        ))
-                    sampling_metadata.logprobs = top_logprobs
-                    sampling_metadata.all_greedy = all_greedy
-                    self.sample_from_logits(dummy_logits, sampling_metadata)
+                sampling_metadata = (
+                    TPUSupportedSamplingMetadata.from_input_batch(
+                        self.input_batch,
+                        num_reqs,
+                        self.device,
+                        generate_params_if_all_greedy,
+                    ))
+                sampling_metadata.all_greedy = all_greedy
+                self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
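`_precompile_gather_logprobs` traces the new graph once per entry in `self.num_reqs_paddings`, so at serving time every batch is padded up to one of those pre-compiled sizes instead of triggering a fresh XLA compilation. An illustrative sketch of that bucketing idea (the bucket values here are made up for the example, not vLLM's defaults):

import bisect

NUM_REQS_PADDINGS = [8, 16, 32, 64]  # hypothetical padding buckets

def pad_num_reqs(num_reqs: int) -> int:
    # Round the live batch size up to the nearest pre-compiled bucket.
    idx = bisect.bisect_left(NUM_REQS_PADDINGS, num_reqs)
    return NUM_REQS_PADDINGS[idx]

assert pad_num_reqs(5) == 8
assert pad_num_reqs(9) == 16
assert pad_num_reqs(64) == 64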
@@ -1132,6 +1154,7 @@ def capture_model(self) -> None:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,29 +1277,31 @@ def compute_logits(self,
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_logits(
             self, logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata) -> \
-            tuple[torch.Tensor, Optional[LogprobsTensors]]:
+            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
         separately from `forward` for lighter compilation overhead.
-        Optionally (in a separate graph) returns top-logprobs too, by gathering
-        a fixed maximum number of logprobs for the whole batch, 20 by default.
         """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
-            if sampling_metadata.logprobs:
-                logprobs = self.sampler.compute_logprobs(logits)
-                logprobs_tensors = self.sampler.gather_logprobs(
-                    logprobs,
-                    self.model_config.max_logprobs,
-                    token_ids=out_tokens.squeeze(-1))
-            else:
-                logprobs_tensors = None
         else:
             sampler_out = self.sampler(logits, sampling_metadata)
             out_tokens = sampler_out.sampled_token_ids
-            logprobs_tensors = sampler_out.logprobs_tensors
-        return out_tokens, logprobs_tensors
+        return out_tokens
+
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
 
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
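The docstring of the new `gather_logprobs` graph spells out the trade-off: the compiled graph always gathers `max_logprobs` entries so its shapes stay static, and the counts each request actually asked for are honored later on the CPU side. A self-contained sketch of that technique (plain PyTorch, not vLLM's LogprobsTensors API):

import torch

def gather_fixed_topk(logits: torch.Tensor, k: int):
    # Static-shape gather: always top-k, regardless of what each request asked for.
    logprobs = torch.log_softmax(logits.to(torch.float32), dim=-1)
    top_vals, top_ids = logprobs.topk(k, dim=-1)   # both [num_reqs, k]
    return top_vals, top_ids

logits = torch.randn(4, 1000)          # [num_reqs, vocab_size]
vals, ids = gather_fixed_topk(logits, k=20)
requested = [0, 5, 1, 20]              # per-request logprob counts, known on CPU
trimmed = [ids[i, :n].tolist() for i, n in enumerate(requested)]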
