
[V1][Spec Decode] Non greedy sample with EAGLE / Reduce memory allocation for Rejection Sampler #16077


Open · wants to merge 13 commits into base: main
17 changes: 17 additions & 0 deletions examples/offline_inference/eagle.py
@@ -45,6 +45,7 @@ def main():
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--log_output_filename", type=str, default="eagle_output.txt")
args = parser.parse_args()

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -90,6 +91,22 @@ def main():
outputs = llm.generate(prompt_token_ids=prompt_ids,
sampling_params=sampling_params)

# save the output text to args.log_output_filename for quality checks
log_output_data = []
for i, output in enumerate(outputs):
input_text = tokenizer.decode(output.prompt_token_ids)
log_output_data.append({
"input": input_text,
"output": output.outputs[0].text
})

with open(args.log_output_filename, "w") as f:
f.write(
json.dumps(log_output_data, indent=4, ensure_ascii=False))
print("-" * 50)
print(f"Output texts saved to {args.log_output_filename}")
print("-" * 50)

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
81 changes: 49 additions & 32 deletions tests/v1/e2e/test_spec_decode.py
@@ -44,8 +44,13 @@

@pytest.fixture
def sampling_config():
# Only support greedy for now
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
# return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
return [ SamplingParams(temperature=0, max_tokens=10, ignore_eos=False),
SamplingParams(temperature=0.1, max_tokens=10, ignore_eos=False),
SamplingParams(temperature=0.2, max_tokens=10, ignore_eos=False),
SamplingParams(temperature=0.3, max_tokens=10, ignore_eos=False),
# SamplingParams(temperature=1, top_p=0.75, max_tokens=10, ignore_eos=False),
]

Check failure on line 53 in tests/v1/e2e/test_spec_decode.py — GitHub Actions / pre-commit: Ruff (E501) Line too long (86 > 80)


@pytest.fixture
@@ -72,7 +77,9 @@
m.setenv("VLLM_USE_V1", "1")

ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
ref_outputs = []
for sampling_param in sampling_config:
ref_outputs.append(ref_llm.chat(test_prompts, sampling_param))
del ref_llm

spec_llm = LLM(
@@ -85,20 +92,22 @@
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))

for i, sampling_param in enumerate(sampling_config):
spec_output = spec_llm.chat(test_prompts, sampling_param)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs[i], spec_output):
if ref_output.outputs[0].text == spec_output.outputs[0].text:

Check failure on line 101 in tests/v1/e2e/test_spec_decode.py — GitHub Actions / pre-commit: Ruff (B020) Loop control variable `spec_output` overrides iterable it iterates
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs[i]))
del spec_llm


@@ -115,9 +124,13 @@
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_outputs = []
spec_outputs = []

ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)

for sampling_param in sampling_config:
ref_outputs.append(ref_llm.chat(test_prompts, sampling_param))
del ref_llm

spec_llm = LLM(
@@ -129,18 +142,22 @@
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))

for sampling_param in sampling_config:
spec_outputs.append(spec_llm.chat(test_prompts, sampling_param))
del spec_llm

for i in range(len(sampling_config)):
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs[i], spec_outputs[i]):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs[i])), "Failed for sampling_param: " + str(sampling_config[i])
145 changes: 114 additions & 31 deletions vllm/v1/spec_decode/eagle.py
@@ -31,6 +31,39 @@
device=device,
dtype=torch.int32)

max_batch_size = vllm_config.scheduler_config.max_num_seqs
vocab_size = vllm_config.model_config.get_vocab_size()

# setup buffers for draft token ids and probs to be
# reused across steps for Rejection Sampling
self.curr_num_tokens = -1
self.curr_batch_size = -1

# packed tensor for [bs, num_speculative_tokens]
self._draft_token_ids_buffer = torch.zeros(max_batch_size,
self.num_speculative_tokens,
dtype=torch.long,
device=device)

# packed tensor for [num_tokens, vocab_size]
self._draft_token_ids_buffer_shape = self._draft_token_ids_buffer.shape
self._draft_probs_buffer = torch.zeros(max_batch_size * self.num_speculative_tokens,
vocab_size,
# TODO(ekagra): pass dtype
dtype=torch.float32,
device=device)
self._draft_probs_buffer_shape = self._draft_probs_buffer.shape
Collaborator:
Just want to point this out, we might need to fix it. For the draft_probs_buffer (plugging in numbers for Llama-3-8B), it has size:
256 * 10 * 128256 * 4 bytes / 1024**3 ≈ 1.2 GiB
There is a low probability that this might trigger OOM if we do this after vLLM preallocates all memory for the KV cache, but it should not be a big problem.
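For reference, a quick back-of-the-envelope check of that figure (illustrative numbers only, not taken from the PR's configs):

max_batch_size = 256          # max_num_seqs
num_speculative_tokens = 10
vocab_size = 128256           # Llama-3 vocabulary
bytes_per_elem = 4            # float32
total = max_batch_size * num_speculative_tokens * vocab_size * bytes_per_elem
print(f"{total / 1024**3:.2f} GiB")  # ~1.22 GiB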

Contributor Author:
Is there a way to allocate this before vLLM preallocates memory for the KV cache?

Contributor Author:
The preallocation of the rejection-sampler buffer happens here, which is when the GPU model runner is created. Could you point me to the line of code that computes the available GPU memory and allocates the KV cache based on it?


def get_draft_token_ids(self) -> torch.Tensor:
Collaborator:
Since we have the get_draft_token_ids API, I'm wondering if it might be cleaner to move all proposing logic (https://github.com/vllm-project/vllm/blob/660a6b0ed756bb7ca0459786fd8302b9ede2c280/vllm/v1/worker/gpu_model_runner.py#L1171C8-L1229C14) under this function?

Contributor Author (@ekagra-ranjan, Apr 10, 2025):
get_draft_token_ids is more of a getter for the right slice of the preallocated buffer, so repeated calls just return a handle to the buffer. If we moved the proposer logic here, repeated calls would propose again. We could refactor the code and put that section under a new API if that makes sense.

# [batch_size, num_speculative_tokens]
assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet."
return self._draft_token_ids_buffer[:self.curr_batch_size]

def get_draft_probs(self) -> torch.Tensor:
# [batch_size, num_speculative_tokens, vocab_size]
assert self.curr_num_tokens != -1, "EagleProposer hasn't proposed yet."
return self._draft_probs_buffer[:self.curr_num_tokens]
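
As a side note on the getter discussion above, a minimal sketch (hypothetical names, not the PR's code) of why returning a slice of a preallocated buffer hands back the same storage on every call:

import torch

_buf = torch.zeros(256, 10, dtype=torch.long)   # [max_batch_size, k]
_curr_batch_size = 8

def get_draft_token_ids() -> torch.Tensor:
    # basic slicing returns a view, not a copy
    return _buf[:_curr_batch_size]

view = get_draft_token_ids()
view.fill_(7)                                   # writes land in _buf
assert bool((_buf[:_curr_batch_size] == 7).all())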

def propose(
self,
# [num_tokens]
Expand All @@ -48,9 +81,33 @@
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
):
# make sure that the buffer sizes have not been changed
# by any external operation
assert self._draft_probs_buffer_shape.numel() == \
    self._draft_probs_buffer.numel(), \
    "Size of self._draft_probs_buffer has been changed. " \
    "Make sure it remains the same."

assert self._draft_token_ids_buffer_shape.numel() == \
    self._draft_token_ids_buffer.numel(), \
    "Size of self._draft_token_ids_buffer has been changed. " \
    "Make sure it remains the same."

# restore the buffer shapes if they have been
# changed by an external operation
if (self._draft_probs_buffer.shape != self._draft_probs_buffer_shape):
Collaborator:
Why might this buffer be reshaped by any operation?

Contributor Author:
Since we pass the buffer outside of this function, the caller gets a handle to it and might accidentally reshape it. I'm assuming someone might do that in the future, since it isn't obvious that they shouldn't. The check helps in those cases. Let me know if it should be removed.
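
A minimal sketch of the failure mode this check guards against (hypothetical caller, not the PR's code); note that Tensor.reshape returns a new tensor, so restoring the shape needs an assignment:

import torch

buf = torch.zeros(4, 8)
saved_shape = buf.shape                     # torch.Size([4, 8])

buf.resize_(2, 16)                          # a caller with the handle changes the shape in place
assert saved_shape.numel() == buf.numel()   # element count is still intact
if buf.shape != saved_shape:
    buf = buf.reshape(saved_shape)          # reshape() is not in place
assert buf.shape == saved_shape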

self._draft_probs_buffer = self._draft_probs_buffer.reshape(
    self._draft_probs_buffer_shape)

if (self._draft_token_ids_buffer.shape
        != self._draft_token_ids_buffer_shape):
    self._draft_token_ids_buffer = \
        self._draft_token_ids_buffer.reshape(
            self._draft_token_ids_buffer_shape)

num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
self.curr_batch_size = batch_size
self.curr_num_tokens = batch_size * self.num_speculative_tokens
last_token_indices = cu_num_tokens[1:] - 1

input_ids = torch.empty_like(target_token_ids)
@@ -91,26 +148,27 @@
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
compute_probs_and_sample_next_token(logits, sampling_metadata,
0,
self.num_speculative_tokens,
batch_size,
self.arange,
self._draft_token_ids_buffer,
self._draft_probs_buffer)

# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1] and [batch_size, 1, vocab_size]
return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
return

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
draft_probs_list = [draft_probs]

positions = target_positions[last_token_indices]
hidden_states = sample_hidden_states
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
for speculative_token_idx in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
input_ids = self._draft_token_ids_buffer[:batch_size, speculative_token_idx]
positions += 1
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
@@ -130,16 +188,13 @@
positions=positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
draft_token_ids_list.append(draft_token_ids)
draft_probs_list.append(probs)

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs = torch.stack(draft_probs_list, dim=1)
return draft_token_ids, draft_probs
compute_probs_and_sample_next_token(logits, sampling_metadata,
speculative_token_idx + 1,
self.num_speculative_tokens,
batch_size,
self.arange,
self._draft_token_ids_buffer,
self._draft_probs_buffer)

@staticmethod
def prepare_inputs(
@@ -214,36 +269,64 @@
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
# index of the speculative token among num_speculative_tokens
speculative_token_idx: int,
# max number of speculative tokens
num_speculative_tokens: int,
# current batch size
batch_size: int,
# [batch_size + 1]
arange: torch.Tensor,
# [batch_size, num_speculative_tokens]
draft_token_ids_buffer: torch.Tensor,
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs_buffer: torch.Tensor,
):
# We pass in the entire preallocated buffers draft_token_ids_buffer
# and draft_probs_buffer and select the portion of the buffer that
# we need to fill in using batch_size and speculative_token_idx.
# This allows us to write in-place. If we passed in the specific
# tensors slices directly to func, i.e.,
# draft_token_ids_buffer[:batch_size, speculative_token_idx]
# as draft_token_ids, then draft_token_ids = logits.argmax(dim=-1)
# would create a new tensor and not allow in-place writes.

Check failure on line 293 in vllm/v1/spec_decode/eagle.py — GitHub Actions / pre-commit: Ruff (E501) Line too long (101 > 80)
draft_probs_buffer_indices = arange[:batch_size] * num_speculative_tokens + speculative_token_idx

if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
# Therefore, we can just return the logits.
probs = logits
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs
draft_probs_buffer[draft_probs_buffer_indices] = logits.to(dtype=torch.float32)
draft_token_ids_buffer[:batch_size, speculative_token_idx] = logits.argmax(dim=-1)
return

is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_buffer[draft_probs_buffer_indices] = logits.softmax(
dim=-1, dtype=torch.float32)

# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
# generating the draft tokens. We only use the temperature. While this
# could degrade the acceptance rate, it does not affect the distribution
# of the generated tokens after rejection sampling.

# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q = torch.empty_like(draft_probs_buffer[draft_probs_buffer_indices])
q.exponential_()
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
draft_probs_buffer[draft_probs_buffer_indices] \
.div_(q) \
.argmax(dim=-1) \
.view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
greedy_token_ids = draft_probs_buffer[draft_probs_buffer_indices].argmax(dim=-1)
draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
draft_token_ids_buffer[:batch_size, speculative_token_idx],
)
return next_token_ids, probs
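
For context, the q.exponential_() / div_ / argmax pattern above is the standard "exponential race" (Gumbel-max style) trick for drawing one sample per row from a categorical distribution. A minimal sketch checking it empirically (illustrative only, not part of the PR):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty_like(probs).exponential_()  # q_i ~ Exp(1)
    counts[(probs / q).argmax()] += 1
print(counts / counts.sum())  # roughly [0.10, 0.20, 0.70]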


@triton.jit
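A minimal sketch (illustrative names only) of the in-place-write rationale spelled out in the comment at the top of compute_probs_and_sample_next_token: rebinding a parameter name never touches the caller's buffer, while indexed assignment into the passed buffer does.

import torch

def sample_rebinding(draft_slice: torch.Tensor, logits: torch.Tensor) -> None:
    draft_slice = logits.argmax(dim=-1)       # rebinds the local name only

def sample_in_place(buf: torch.Tensor, col: int, logits: torch.Tensor) -> None:
    buf[:logits.shape[0], col] = logits.argmax(dim=-1)  # writes into buf

buf = torch.zeros(4, 3, dtype=torch.long)
logits = torch.randn(4, 11)

sample_rebinding(buf[:, 0], logits)
assert bool((buf == 0).all())                 # buffer unchanged

sample_in_place(buf, 0, logits)
assert bool((buf[:, 0] == logits.argmax(dim=-1)).all())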
14 changes: 11 additions & 3 deletions vllm/v1/spec_decode/ngram_proposer.py
@@ -10,6 +10,7 @@
class NgramProposer:

def __init__(self, vllm_config: VllmConfig):
self._draft_token_ids = None
# Minimum length of the n-gram to match.
self.min_n = vllm_config.speculative_config.prompt_lookup_min
# Maximum length of the n-gram to match.
@@ -22,10 +23,16 @@ def __init__(self, vllm_config: VllmConfig):
# This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32))

def get_draft_token_ids(self):
return self._draft_token_ids

def get_draft_probs(self):
return None

def propose(
self,
context_token_ids: np.ndarray,
) -> Optional[np.ndarray]:
):
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
@@ -54,8 +61,9 @@ def propose(
for n in range(self.max_n, self.min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, self.k)
if result is not None:
return result
return None
self._draft_token_ids = result
return
return

def load_model(self, *args, **kwargs):
# No model to load.
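A minimal sketch of the n-gram proposal idea described in the propose docstring above (simplified linear scan instead of the KMP-based matching used by the real code):

import numpy as np

def ngram_propose(context: np.ndarray, min_n: int, max_n: int, k: int):
    for n in range(max_n, min_n - 1, -1):
        suffix = context[-n:]
        # look for the most recent earlier occurrence of the suffix
        for start in range(len(context) - n - 1, -1, -1):
            if np.array_equal(context[start:start + n], suffix):
                follow = context[start + n:start + n + k]
                if len(follow) > 0:
                    return follow
    return None

ctx = np.array([1, 2, 3, 9, 1, 2, 3], dtype=np.int32)
print(ngram_propose(ctx, min_n=1, max_n=3, k=2))  # -> [9 1]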