[V1][BugFix] Clean up rejection sampler & Fix warning msg #13362

Merged · 1 commit · Feb 16, 2025
109 changes: 69 additions & 40 deletions vllm/v1/sample/rejection_sampler.py
@@ -3,7 +3,9 @@
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 
+from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -19,27 +21,50 @@
 
 class RejectionSampler(nn.Module):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for rejection sampling.")
+                    self.forward_method = self.flashinfer_sample
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "rejection sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward_method = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of rejection sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward_method = self.forward_native
+        else:
+            self.forward_method = self.forward_native
+
     def forward(self, logits: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
-                "Only greedy sampling is supported by rejection sampler.")
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
+        return self.forward_method(logits, sampling_metadata)
 
-        if is_flashinfer_available:
-            logger.info("User FlashInfer for rejection sampling.")
-            return RejectionSampler.flashinfer_sample(logits,
-                                                      sampling_metadata)
-        else:
-            logger.warning(
-                "FlashInfer is not available. Falling back to the PyTorch-"
-                "native implementation of rejection sampling.")
-            return RejectionSampler.greedy_sample_native(
-                logits, sampling_metadata)
 
-    @staticmethod
     def flashinfer_sample(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
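
As an aside, the tri-state semantics described in the NOTE above can be summarized in a tiny sketch (illustrative helper names, not part of the diff): an unset `VLLM_USE_FLASHINFER_SAMPLER` (`None`) behaves like `False` for the V0 sampler but like `True` for the V1 rejection sampler, which is why the check is `is not False` rather than a plain truthiness test.

```python
from typing import Optional

def v0_uses_flashinfer(flag: Optional[bool]) -> bool:
    # V0 behavior: FlashInfer sampling is opt-in, so None acts like False.
    return flag is True

def v1_uses_flashinfer(flag: Optional[bool]) -> bool:
    # V1 behavior: FlashInfer sampling is opt-out, so None acts like True.
    return flag is not False

assert v0_uses_flashinfer(None) is False
assert v1_uses_flashinfer(None) is True
assert v1_uses_flashinfer(False) is False
```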
@@ -71,10 +96,10 @@ def flashinfer_sample(
         vocab_size = logits.size(-1)
         # NOTE: CPU <-> GPU synchronization happens here.
         draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = RejectionSampler._create_greedy_token_probs(
-            draft_token_ids, vocab_size, logits.device)
-        target_probs = RejectionSampler._create_greedy_token_probs(
-            target_token_ids, vocab_size, logits.device)
+        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
+                                                 logits.device)
+        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
+                                                  logits.device)
         uniform_samples = torch.zeros(batch_size,
                                       max_spec_len + 1,
                                       device=logits.device)
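
Side note on why `uniform_samples` can be all zeros here: with greedy (one-hot) draft and target distributions, the standard rejection-sampling acceptance test degenerates into an exact token match. A minimal sketch of that rule (plain Python, illustrative only):

```python
def accepts(p_draft_x: float, p_target_x: float, u: float) -> bool:
    # Standard rejection test: accept draft token x iff u < min(1, p_tgt(x) / p_drf(x)).
    ratio = p_target_x / p_draft_x if p_draft_x > 0 else 0.0
    return u < min(1.0, ratio)

# With one-hot distributions, the ratio is 1 when the draft token equals the
# target's argmax token and 0 otherwise, so u = 0 accepts exactly on a match.
assert accepts(1.0, 1.0, 0.0) is True   # tokens match -> accept
assert accepts(1.0, 0.0, 0.0) is False  # mismatch -> reject
```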
@@ -89,10 +114,11 @@ def flashinfer_sample(
                              logprobs_tensors=None)
 
     # TODO: The following method can be optimized for better performance.
-    @staticmethod
-    def greedy_sample_native(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]
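
For context, `forward_native` is the PyTorch-native rejection-sampling path referenced by the warnings above. A simplified, single-request sketch of the greedy acceptance rule it corresponds to (illustrative names; the real method is batched and uses `pad_sequence` for padding): draft tokens are accepted while they match the target's greedy tokens, the first mismatch is replaced by the target's token, and the bonus token is appended only when every draft token was accepted.

```python
import torch

def greedy_accept(draft_token_ids: torch.Tensor,
                  target_token_ids: torch.Tensor) -> torch.Tensor:
    # draft_token_ids: [num_spec_tokens], proposed by the draft model.
    # target_token_ids: [num_spec_tokens + 1], greedy (argmax) tokens from the
    # target model; the last entry is the 'bonus' token.
    accepted = []
    for i in range(draft_token_ids.numel()):
        accepted.append(target_token_ids[i])
        if draft_token_ids[i] != target_token_ids[i]:
            break  # first mismatch: keep the target's token and stop
    else:
        accepted.append(target_token_ids[-1])  # all matched: emit the bonus token
    return torch.stack(accepted)
```

For example, with draft tokens `[5, 7]` and target tokens `[5, 9, 3]`, the sketch returns `[5, 9]`.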
@@ -137,24 +163,27 @@ def greedy_sample_native(
         return SamplerOutput(sampled_token_ids=output_token_ids,
                              logprobs_tensors=None)
 
-    @staticmethod
-    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
-                                   out_device: torch.device) -> torch.Tensor:
-        batch_size, num_tokens = token_ids.shape
+
+def _create_greedy_token_probs(
+    token_ids: torch.Tensor,
+    vocab_size: int,
+    out_device: torch.device,
+) -> torch.Tensor:
+    batch_size, num_tokens = token_ids.shape
 
-        token_probs = torch.zeros(batch_size,
-                                  num_tokens,
-                                  vocab_size,
-                                  dtype=torch.float,
-                                  device=out_device)
+    token_probs = torch.zeros(batch_size,
+                              num_tokens,
+                              vocab_size,
+                              dtype=torch.float,
+                              device=out_device)
 
-        # Ignore INVALID_TOKEN_ID.
-        valid_mask = (token_ids != INVALID_TOKEN_ID)
-        valid_indices = token_ids.clone()
-        valid_indices[~valid_mask] = 0
+    # Ignore INVALID_TOKEN_ID.
+    valid_mask = (token_ids != INVALID_TOKEN_ID)
+    valid_indices = token_ids.clone()
+    valid_indices[~valid_mask] = 0
 
-        token_probs.scatter_(dim=2,
-                             index=valid_indices.unsqueeze(-1),
-                             src=valid_mask.unsqueeze(-1).float())
+    token_probs.scatter_(dim=2,
+                         index=valid_indices.unsqueeze(-1),
+                         src=valid_mask.unsqueeze(-1).float())
 
-        return token_probs
+    return token_probs
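
Finally, a small usage sketch of what the `_create_greedy_token_probs` helper above produces, with made-up token ids and vocab size (assuming the module's `INVALID_TOKEN_ID` sentinel is -1): each valid position becomes a one-hot distribution over the vocabulary, while padded positions stay all-zero, so greedy tokens can be fed to the rejection-sampling kernel as deterministic distributions.

```python
import torch

INVALID_TOKEN_ID = -1  # assumed sentinel value, mirroring the module constant

token_ids = torch.tensor([[5, 2, INVALID_TOKEN_ID]])  # [batch=1, num_tokens=3]
vocab_size = 8

probs = torch.zeros(1, 3, vocab_size)
valid_mask = token_ids != INVALID_TOKEN_ID
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0
probs.scatter_(dim=2,
               index=valid_indices.unsqueeze(-1),
               src=valid_mask.unsqueeze(-1).float())

assert probs[0, 0, 5] == 1.0     # one-hot at token id 5
assert probs[0, 1, 2] == 1.0     # one-hot at token id 2
assert probs[0, 2].sum() == 0.0  # padded position stays all-zero
```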