[V1][Spec Decode] Non greedy sample with EAGLE / Reduce memory allocation for Rejection Sampler #16077
base: main
Changes from all commits: fbd1630, 8aaf041, 96bd40b, 57aed3d, a901022, 6fff2d0, 660a6b0, e5fc75a, 8e49750, f69ebc0, 57bd14b, 610cfaf, 6f577a4
@@ -31,6 +31,39 @@
             device=device,
             dtype=torch.int32)
 
+        max_batch_size = vllm_config.scheduler_config.max_num_seqs
+        vocab_size = vllm_config.model_config.get_vocab_size()
+
+        # Set up buffers for draft token ids and probs to be
+        # reused across steps for rejection sampling.
+        self.curr_num_tokens = -1
+        self.curr_batch_size = -1
+
+        # Packed tensor for [max_batch_size, num_speculative_tokens].
+        self._draft_token_ids_buffer = torch.zeros(max_batch_size,
+                                                   self.num_speculative_tokens,
+                                                   dtype=torch.long,
+                                                   device=device)
+        self._draft_token_ids_buffer_shape = self._draft_token_ids_buffer.shape
+
+        # Packed tensor for [max_batch_size * num_speculative_tokens, vocab_size].
+        self._draft_probs_buffer = torch.zeros(
+            max_batch_size * self.num_speculative_tokens,
+            vocab_size,
+            # TODO(ekagra): pass dtype
+            dtype=torch.float32,
+            device=device)
+        self._draft_probs_buffer_shape = self._draft_probs_buffer.shape
+
+    def get_draft_token_ids(self) -> torch.Tensor:
+        # [batch_size, num_speculative_tokens]
+        assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet."
+        return self._draft_token_ids_buffer[:self.curr_batch_size]
+
+    def get_draft_probs(self) -> torch.Tensor:
+        # [batch_size * num_speculative_tokens, vocab_size]
+        assert self.curr_num_tokens != -1, "EagleProposer hasn't proposed yet."
+        return self._draft_probs_buffer[:self.curr_num_tokens]
+
     def propose(
         self,
         # [num_tokens]
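The buffers above are allocated once at construction and handed out as slices, so every propose() call reuses the same storage. A minimal standalone sketch of this preallocate-and-view pattern (illustrative names, not vLLM's actual API):

```python
import torch

class DraftBuffers:
    """Illustrative stand-in for the EagleProposer buffer setup."""

    def __init__(self, max_batch_size: int, num_spec_tokens: int,
                 vocab_size: int, device: str = "cpu"):
        self.num_spec_tokens = num_spec_tokens
        self.curr_batch_size = -1  # -1 means "nothing proposed yet"
        # Allocated once; reused by every propose() call.
        self.token_ids = torch.zeros(max_batch_size, num_spec_tokens,
                                     dtype=torch.long, device=device)
        self.probs = torch.zeros(max_batch_size * num_spec_tokens, vocab_size,
                                 dtype=torch.float32, device=device)

    def get_token_ids(self) -> torch.Tensor:
        assert self.curr_batch_size != -1, "nothing proposed yet"
        # Slicing returns a view into the preallocated storage: no copy.
        return self.token_ids[:self.curr_batch_size]

bufs = DraftBuffers(max_batch_size=4, num_spec_tokens=2, vocab_size=8)
bufs.curr_batch_size = 2
view = bufs.get_token_ids()
assert view.data_ptr() == bufs.token_ids.data_ptr()  # shares memory
```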
@@ -48,9 +81,33 @@
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ):
+        # Make sure that the buffer sizes have not been changed
+        # by any outside operation.
+        assert self._draft_probs_buffer_shape.numel(
+        ) == self._draft_probs_buffer.numel(
+        ), ("Size of self._draft_probs_buffer has been changed. "
+            "Make sure it remains the same.")
+
+        assert self._draft_token_ids_buffer_shape.numel(
+        ) == self._draft_token_ids_buffer.numel(
+        ), ("Size of self._draft_token_ids_buffer has been changed. "
+            "Make sure it remains the same.")
+
+        # Restore the shape of the buffers if it has been
+        # changed by any outside operation.
+        if self._draft_probs_buffer.shape != self._draft_probs_buffer_shape:
+            self._draft_probs_buffer = self._draft_probs_buffer.reshape(
+                self._draft_probs_buffer_shape)
+
+        if (self._draft_token_ids_buffer.shape
+                != self._draft_token_ids_buffer_shape):
+            self._draft_token_ids_buffer = self._draft_token_ids_buffer.reshape(
+                self._draft_token_ids_buffer_shape)

Review discussion:
Q: Why might this buffer be reshaped by any operation?
A (ekagra-ranjan): Since we pass the buffer out of this function, the caller gets a handle to it and might accidentally reshape it. It's not obvious that they shouldn't, so this check guards against that case. Let me know if the check should be removed.
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
+        self.curr_batch_size = batch_size
+        self.curr_num_tokens = batch_size * self.num_speculative_tokens
         last_token_indices = cu_num_tokens[1:] - 1
 
         input_ids = torch.empty_like(target_token_ids)
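The guard above assumes a caller that received the buffer handle only changed its shape, never its element count, so the canonical shape can always be restored without reallocating. A short sketch of that reasoning (standalone, assumed names):

```python
import torch

# Minimal sketch of the guard's logic above (illustrative, not vLLM code).
buf = torch.zeros(6, 4)
saved_shape = buf.shape

# Suppose some code elsewhere flattened the buffer it was handed:
buf = buf.view(-1)                      # shape is now [24]

# The guard: the element count must be unchanged...
assert buf.numel() == saved_shape.numel(), "buffer size was changed"

# ...and if the shape drifted, restore it. For a contiguous tensor,
# reshape() returns a view over the same storage, so this is free.
if buf.shape != saved_shape:
    buf = buf.reshape(saved_shape)
assert buf.shape == saved_shape
```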
@@ -91,26 +148,27 @@
         )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        compute_probs_and_sample_next_token(logits, sampling_metadata,
+                                            0,
+                                            self.num_speculative_tokens,
+                                            batch_size,
+                                            self.arange,
+                                            self._draft_token_ids_buffer,
+                                            self._draft_probs_buffer)
 
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            return
 
         # Generate the remaining draft tokens.
-        draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]
-
         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-        for _ in range(self.num_speculative_tokens - 1):
+        for speculative_token_idx in range(self.num_speculative_tokens - 1):
             # Update the inputs.
-            input_ids = draft_token_ids_list[-1]
+            input_ids = self._draft_token_ids_buffer[:batch_size,
+                                                     speculative_token_idx]
             positions += 1
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1

@@ -130,16 +188,13 @@
                 positions=positions,
             )
             logits = self.model.compute_logits(hidden_states, None)
-            draft_token_ids, probs = compute_probs_and_sample_next_token(
-                logits, sampling_metadata)
-            draft_token_ids_list.append(draft_token_ids)
-            draft_probs_list.append(probs)
-
-        # [batch_size, num_speculative_tokens]
-        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+            compute_probs_and_sample_next_token(logits, sampling_metadata,
+                                                speculative_token_idx + 1,
+                                                self.num_speculative_tokens,
+                                                batch_size,
+                                                self.arange,
+                                                self._draft_token_ids_buffer,
+                                                self._draft_probs_buffer)
 
     @staticmethod
     def prepare_inputs(
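Each loop iteration reads the tokens written in the previous step's column and writes its own samples into the next column of the preallocated buffer, replacing the old list-append-and-stack approach. A toy sketch of that column-wise flow (the sampler below is a stand-in, not vLLM's model call):

```python
import torch

batch_size, num_spec, vocab = 3, 4, 11
token_ids_buffer = torch.zeros(batch_size, num_spec, dtype=torch.long)

def fake_sample(input_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for "run drafter, sample next tokens".
    return (input_ids + 1) % vocab

input_ids = torch.tensor([5, 7, 9])          # next_token_ids for step 0
token_ids_buffer[:, 0] = fake_sample(input_ids)
for idx in range(num_spec - 1):
    # Each step reads the column written by the previous step...
    input_ids = token_ids_buffer[:batch_size, idx]
    # ...and writes its samples into the next column, in place.
    token_ids_buffer[:batch_size, idx + 1] = fake_sample(input_ids)

print(token_ids_buffer)  # rows are per-request draft sequences
```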
@@ -214,36 +269,64 @@
 def compute_probs_and_sample_next_token(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
+    # index of the speculative token among num_speculative_tokens
+    speculative_token_idx: int,
+    # max number of speculative tokens
+    num_speculative_tokens: int,
+    # current batch size
+    batch_size: int,
+    # [batch_size + 1]
+    arange: torch.Tensor,
+    # [batch_size, num_speculative_tokens]
+    draft_token_ids_buffer: torch.Tensor,
+    # [max_batch_size * num_speculative_tokens, vocab_size]
+    draft_probs_buffer: torch.Tensor,
+):
+    # We pass in the entire preallocated buffers draft_token_ids_buffer
+    # and draft_probs_buffer and select the portion to fill in using
+    # batch_size and speculative_token_idx. This allows us to write
+    # in-place. If we passed the specific tensor slices to this function,
+    # i.e., draft_token_ids_buffer[:batch_size, speculative_token_idx]
+    # as draft_token_ids, then draft_token_ids = logits.argmax(dim=-1)
+    # would rebind the name to a new tensor instead of writing in-place.
+
+    draft_probs_buffer_indices = (arange[:batch_size] * num_speculative_tokens
+                                  + speculative_token_idx)
+
     if sampling_metadata.all_greedy:
         # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
+        # Therefore, we can just store the logits.
+        draft_probs_buffer[draft_probs_buffer_indices] = logits.to(
+            dtype=torch.float32)
+        draft_token_ids_buffer[:batch_size,
+                               speculative_token_idx] = logits.argmax(dim=-1)
+        return
 
     is_greedy = sampling_metadata.temperature == -1
     temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
     logits.div_(temperature.view(-1, 1))
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    draft_probs_buffer[draft_probs_buffer_indices] = logits.softmax(
+        dim=-1, dtype=torch.float32)
 
     # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
     # generating the draft tokens. We only use the temperature. While this
     # could degrade the acceptance rate, it does not affect the distribution
     # of the generated tokens after rejection sampling.
 
     # TODO(woosuk): Consider seeds.
-    q = torch.empty_like(probs)
+    q = torch.empty_like(draft_probs_buffer[draft_probs_buffer_indices])
     q.exponential_()
-    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
+    draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
+        draft_probs_buffer[draft_probs_buffer_indices] \
+        .div_(q) \
+        .argmax(dim=-1) \
+        .view(-1)
     if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
+        greedy_token_ids = draft_probs_buffer[
+            draft_probs_buffer_indices].argmax(dim=-1)
+        draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
+            torch.where(
             is_greedy,
             greedy_token_ids,
-            next_token_ids,
+            draft_token_ids_buffer[:batch_size, speculative_token_idx],
         )
-    return next_token_ids, probs
 
 
 @triton.jit
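The non-greedy path samples via the exponential-noise trick (equivalent to Gumbel-max): dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws an index exactly from the categorical distribution, with no call to torch.multinomial. A quick standalone check of that claim:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.6, 0.3]])

counts = torch.zeros(3)
for _ in range(20_000):
    # q_i ~ Exp(1), so q_i / p_i ~ Exp(rate=p_i); the argmax of p_i / q_i
    # is the index of the fastest "exponential clock", chosen w.p. p_i.
    q = torch.empty_like(probs).exponential_()
    counts[probs.div(q).argmax(dim=-1)] += 1

print(counts / counts.sum())  # ≈ tensor([0.1, 0.6, 0.3])
```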
Review comment (memory footprint): Just want to point this out; we might need to fix it. For the draft_probs_buffer, plugging in numbers for Llama-3-8B, its size is 256 * 10 * 128256 * 4 bytes ≈ 1.2 GiB. There is a small chance this could trigger an OOM if it is allocated after vLLM preallocates all memory for the KV cache, but it should not be a big problem.
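For reference, the reviewer's arithmetic (assumed config values: max_num_seqs=256, num_speculative_tokens=10, Llama-3 vocab of 128256, fp32):

```python
# draft_probs_buffer footprint = max_num_seqs * num_spec_tokens * vocab * 4 B
size_bytes = 256 * 10 * 128256 * 4
print(size_bytes / 1024**3)  # ≈ 1.22 GiB
```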
Reply: Is there a way to allocate this before vLLM preallocates memory for the KV cache?
Reply (ekagra-ranjan): The preallocation of the rejection-sampler buffers happens here, when the GPUModelRunner is created. Could you point me to the line of code that computes the available GPU memory and allocates the KV cache from it?