Commit 3a24309

Optimize _get_ranks in Sampler (#3623)

1 parent 64172a9

File tree

1 file changed: +17, -10 lines


vllm/model_executor/layers/sampler.py

Lines changed: 17 additions & 10 deletions
@@ -506,22 +506,23 @@ def _sample(
     #     sampling_tensors)
 
 
-def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
 
     Args:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
-        indices (List[int]): List of chosen token indices.
+        indices (torch.Tensor): List of chosen token indices.
 
     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                     Each element in the returned tensor represents the rank
                     of the chosen token in the input logprob tensor.
     """
-    vals = x[range(len(x)), indices]
-    return (x > vals[:, None]).long().sum(1) + 1
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)
 
 
 def _get_logprobs(
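For context on what the rewritten function computes: a token's rank is one plus the number of vocab entries with a strictly higher logprob, and the new version keeps the whole computation on the tensor's device. A minimal standalone sketch (the toy inputs are illustrative; only _get_ranks itself comes from this commit):

import torch

def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Gather each row's chosen-token logprob entirely on the device.
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    # Rank = 1 + count of vocab entries with a strictly higher logprob.
    return (x > vals[:, None]).long().sum(1).add_(1)

logprobs = torch.log_softmax(torch.tensor([[2.0, 1.0, 0.5],
                                           [0.1, 3.0, 0.2]]), dim=-1)
chosen = torch.tensor([1, 1])        # token 1 picked in both rows
print(_get_ranks(logprobs, chosen))  # tensor([2, 1])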
@@ -561,12 +562,21 @@ def _get_logprobs(
         sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)
 
+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
+
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]
 
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
+
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
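Why the pre-built *_gpu tensors help: indexing a CUDA tensor with plain Python lists makes PyTorch construct an index tensor (and copy it to the device) on every indexing expression, whereas converting once on logprobs.device lets both the logprob gather and the new _get_ranks call reuse the same tensors. A small before/after sketch (toy shapes and names, assumed for illustration):

import torch

logprobs = torch.randn(8, 32000)  # (num query tokens, vocab size)
seq_indices = [0, 0, 1, 2]        # index lists accumulated on the host
token_indices = [5, 17, 3, 9]

# Before: each expression like this converts the lists to tensors anew.
old = logprobs[seq_indices, token_indices]

# After: convert once on the target device, then reuse everywhere.
seq_gpu = torch.tensor(seq_indices, device=logprobs.device)
tok_gpu = torch.tensor(token_indices, device=logprobs.device)
new = logprobs[seq_gpu, tok_gpu]

assert torch.equal(old, new)  # same gather, fewer host-to-device copies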
@@ -578,10 +588,7 @@ def _get_logprobs(
         top_logprobs, top_token_ids = None, None
 
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
-
-    batched_ranks_query_result = _get_ranks(
-        logprobs[batched_logprobs_query_seq_indices],
-        batched_logprobs_query_token_indices)
+    batched_ranks_query_result = batched_ranks_query_result.cpu()
 
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
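Launching _get_ranks before the .cpu() copies also changes where the host blocks: CUDA kernels are enqueued asynchronously and only the device-to-host copy synchronizes, so all GPU work is in flight before the first transfer. A sketch of the resulting pattern (gather_results is a hypothetical helper, reusing _get_ranks from the sketch above):

import torch

# Illustrative pattern only, not vLLM code: queue every kernel first,
# then pay the synchronization cost once at the end.
def gather_results(logprobs: torch.Tensor, seq_gpu: torch.Tensor,
                   tok_gpu: torch.Tensor):
    query = logprobs[seq_gpu, tok_gpu]              # async gather kernel
    ranks = _get_ranks(logprobs[seq_gpu], tok_gpu)  # async rank kernels
    # These copies block only on work already in flight on the device.
    return query.cpu(), ranks.cpu()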
