From fbd1630bc476f54d929d1893970982768047c8d7 Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 4 Apr 2025 22:31:00 +0000 Subject: [PATCH 01/11] use data buffer and avoid data copy via torch stack --- vllm/v1/spec_decode/eagle.py | 53 ++++++++++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 6 +++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3aaaf34bc79..32061e76c3b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -24,6 +24,22 @@ def __init__( self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, device=device) + max_batch_size = vllm_config.scheduler_config.max_num_seqs + vocab_size = vllm_config.model_config.get_vocab_size() + + # setup buffers for draft token ids and probs to be reused across steps for Rejection Sampling + self.draft_token_ids_buffer = torch.zeros(max_batch_size, + self.num_speculative_tokens, + dtype=torch.int32, + device=device + ) + self.draft_probs_buffer = torch.zeros(max_batch_size, + self.num_speculative_tokens, + vocab_size, + dtype=torch.float32, + device=device + ) + def propose( self, # [num_tokens] @@ -83,7 +99,10 @@ def propose( sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids, draft_probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + logits, sampling_metadata, + 0, + self.draft_token_ids_buffer, + self.draft_probs_buffer) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: @@ -99,7 +118,7 @@ def propose( attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size] - for _ in range(self.num_speculative_tokens - 1): + for speculative_token_idx in range(self.num_speculative_tokens - 1): # Update the inputs. input_ids = draft_token_ids_list[-1] positions += 1 @@ -121,10 +140,13 @@ def propose( positions=positions, ) logits = self.model.compute_logits(hidden_states, None) - draft_token_ids, probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) - draft_token_ids_list.append(draft_token_ids) - draft_probs_list.append(probs) + compute_probs_and_sample_next_token( + logits, sampling_metadata, + speculative_token_idx + 1, + self.draft_token_ids_buffer, + self.draft_probs_buffer) + # draft_token_ids_list.append(draft_token_ids) + # draft_probs_list.append(probs) # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) @@ -203,18 +225,22 @@ def forward( def compute_probs_and_sample_next_token( logits: torch.Tensor, sampling_metadata: SamplingMetadata, + speculative_token_idx: int, + draft_token_ids_buffer: torch.Tensor, # [batch_size, num_speculative_tokens] + draft_probs_buffer: torch.Tensor, # [batch_size, num_speculative_tokens, vocab_size] ) -> tuple[torch.Tensor, torch.Tensor]: if sampling_metadata.all_greedy: # For greedy requests, draft_probs is not used in rejection sampling. # Therefore, we can just return the logits. 
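# ----------------------------------------------------------------------------
# Illustrative sketch (standalone, not part of the patch): the core change in
# PATCH 01 is to stop accumulating per-step draft tensors in Python lists and
# stacking them with torch.stack, and instead write each step's result into a
# preallocated buffer.  Sizes and names below are made up; only the pattern is
# the point.
import torch

batch_size, num_spec_tokens, vocab_size = 4, 3, 10
per_step_logits = [torch.randn(batch_size, vocab_size) for _ in range(num_spec_tokens)]

# Old pattern: collect per-step results, then stack (extra allocation + copy).
stacked = torch.stack([lg.argmax(dim=-1) for lg in per_step_logits], dim=1)

# New pattern: write each step's result into a slice of one reused buffer.
token_ids_buffer = torch.zeros(batch_size, num_spec_tokens, dtype=torch.long)
for step, lg in enumerate(per_step_logits):
    token_ids_buffer[:, step] = lg.argmax(dim=-1)  # in-place slice write, no final stack

assert torch.equal(stacked, token_ids_buffer)      # same values, one fewer copy per step
# ----------------------------------------------------------------------------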
- probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs + draft_probs_buffer[:, speculative_token_idx, :] = logits + draft_token_ids_buffer[:, speculative_token_idx] = logits.argmax(dim=-1) + # return next_token_ids, draft_probs is_greedy = sampling_metadata.temperature == -1 temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) logits.div_(temperature.view(-1, 1)) - probs = logits.softmax(dim=-1, dtype=torch.float32) + # probs = logits.softmax(dim=-1, dtype=torch.float32) # REMOVE + draft_probs_buffer[:, speculative_token_idx, :] = logits.softmax(dim=-1, dtype=torch.float32) # NOTE(woosuk): Currently, we ignore most of the sampling parameters in # generating the draft tokens. We only use the temperature. While this @@ -222,17 +248,16 @@ def compute_probs_and_sample_next_token( # of the generated tokens after rejection sampling. # TODO(woosuk): Consider seeds. - q = torch.empty_like(probs) + q = torch.empty_like(draft_probs_buffer[:, speculative_token_idx, :]) q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + draft_token_ids_buffer[:, speculative_token_idx] = draft_probs_buffer[:, speculative_token_idx, :].div_(q).argmax(dim=-1).view(-1) if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) + greedy_token_ids = draft_probs_buffer[:, speculative_token_idx, :].argmax(dim=-1) next_token_ids = torch.where( is_greedy, greedy_token_ids, next_token_ids, ) - return next_token_ids, probs @triton.jit diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 513806332ef..a0278c64d70 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -102,6 +102,10 @@ def __init__( self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + + # REMOVE + print(f"self.block_size: {self.block_size}, self.max_model_len: {self.max_model_len}, self.max_num_blocks_per_req: {self.max_num_blocks_per_req}") + self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -1197,7 +1201,7 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - draft_token_ids, draft_probs = self.drafter.propose( + self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, From 8aaf0410a9c5ef78cea80d6cb8ede823577509bb Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 7 Apr 2025 21:24:09 +0000 Subject: [PATCH 02/11] add getter for buffer and reuse it across step in RS --- vllm/v1/spec_decode/eagle.py | 51 +++++++++++++++++++-------- vllm/v1/spec_decode/ngram_proposer.py | 6 ++++ vllm/v1/worker/gpu_model_runner.py | 10 ++---- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 32061e76c3b..32023eba568 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -28,18 +28,29 @@ def __init__( vocab_size = vllm_config.model_config.get_vocab_size() # setup buffers for draft token ids and probs to be reused across steps for Rejection Sampling - self.draft_token_ids_buffer = torch.zeros(max_batch_size, + self.curr_batch_size = -1 + self._draft_token_ids_buffer = torch.zeros(max_batch_size, self.num_speculative_tokens, 
dtype=torch.int32, device=device ) - self.draft_probs_buffer = torch.zeros(max_batch_size, + self._draft_probs_buffer = torch.zeros(max_batch_size, self.num_speculative_tokens, vocab_size, dtype=torch.float32, device=device ) + def get_draft_token_ids(self) -> torch.Tensor: + # [batch_size, num_speculative_tokens] + assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet." + return self._draft_token_ids_buffer[:self.curr_batch_size, :] + + def get_draft_probs(self) -> torch.Tensor: + # [batch_size, num_speculative_tokens, vocab_size] + assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet." + return self._draft_probs_buffer[:self.curr_batch_size, :, :] + def propose( self, # [num_tokens] @@ -60,6 +71,7 @@ def propose( ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] + self.curr_batch_size = batch_size last_token_indices = cu_num_tokens[1:] - 1 input_ids = torch.empty_like(target_token_ids) @@ -101,8 +113,9 @@ def propose( draft_token_ids, draft_probs = compute_probs_and_sample_next_token( logits, sampling_metadata, 0, - self.draft_token_ids_buffer, - self.draft_probs_buffer) + batch_size, + self._draft_token_ids_buffer, + self._draft_probs_buffer) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: @@ -143,16 +156,17 @@ def propose( compute_probs_and_sample_next_token( logits, sampling_metadata, speculative_token_idx + 1, - self.draft_token_ids_buffer, - self.draft_probs_buffer) + batch_size, + self._draft_token_ids_buffer, + self._draft_probs_buffer) # draft_token_ids_list.append(draft_token_ids) # draft_probs_list.append(probs) # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + # draft_token_ids = torch.stack(draft_token_ids_list, dim=1) # [batch_size, num_speculative_tokens, vocab_size] - draft_probs = torch.stack(draft_probs_list, dim=1) - return draft_token_ids, draft_probs + # draft_probs = torch.stack(draft_probs_list, dim=1) + # return draft_token_ids, draft_probs @staticmethod def prepare_inputs( @@ -226,21 +240,28 @@ def compute_probs_and_sample_next_token( logits: torch.Tensor, sampling_metadata: SamplingMetadata, speculative_token_idx: int, + batch_size: int, draft_token_ids_buffer: torch.Tensor, # [batch_size, num_speculative_tokens] draft_probs_buffer: torch.Tensor, # [batch_size, num_speculative_tokens, vocab_size] ) -> tuple[torch.Tensor, torch.Tensor]: + # We pass in the entire preallocated buffers draft_token_ids_buffer and draft_probs_buffer + # and select the portion of the buffer that we need to fill in using batch_size and speculative_token_idx. + # This allows us to write in-place. If we passed in the specific tensors slices directly to func, i.e., + # draft_token_ids_buffer[:batch_size, speculative_token_idx] as draft_token_ids, then + # draft_token_ids = logits.argmax(dim=-1) would create a new tensor and not allow in-place writes. + if sampling_metadata.all_greedy: # For greedy requests, draft_probs is not used in rejection sampling. # Therefore, we can just return the logits. 
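# ----------------------------------------------------------------------------
# Illustrative sketch (standalone, not the vLLM code) of the comment above:
# rebinding a parameter name to a fresh tensor never touches the caller's
# buffer, while an indexed assignment on the buffer itself does, which is why
# the whole buffer plus (batch_size, speculative_token_idx) is passed in.
# Function names and sizes here are hypothetical.
import torch

def fill_by_rebinding(draft_token_ids, logits):
    # 'draft_token_ids' is only a local name; this rebinds it to a new tensor
    # and leaves the caller's buffer untouched.
    draft_token_ids = logits.argmax(dim=-1)

def fill_in_place(buffer, step_idx, batch_size, logits):
    # Indexed assignment writes through to the caller's storage.
    buffer[:batch_size, step_idx] = logits.argmax(dim=-1)

batch_size, num_spec_tokens, vocab_size = 4, 3, 10
buffer = torch.zeros(batch_size, num_spec_tokens, dtype=torch.long)
logits = torch.randn(batch_size, vocab_size)

fill_by_rebinding(buffer[:batch_size, 0], logits)
print(buffer[:, 0])   # still all zeros
fill_in_place(buffer, 0, batch_size, logits)
print(buffer[:, 0])   # now holds logits.argmax(dim=-1)
# ----------------------------------------------------------------------------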
- draft_probs_buffer[:, speculative_token_idx, :] = logits - draft_token_ids_buffer[:, speculative_token_idx] = logits.argmax(dim=-1) + draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits + draft_token_ids_buffer[:batch_size, speculative_token_idx] = logits.argmax(dim=-1) # return next_token_ids, draft_probs is_greedy = sampling_metadata.temperature == -1 temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) logits.div_(temperature.view(-1, 1)) # probs = logits.softmax(dim=-1, dtype=torch.float32) # REMOVE - draft_probs_buffer[:, speculative_token_idx, :] = logits.softmax(dim=-1, dtype=torch.float32) + draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits.softmax(dim=-1, dtype=torch.float32) # NOTE(woosuk): Currently, we ignore most of the sampling parameters in # generating the draft tokens. We only use the temperature. While this @@ -248,11 +269,11 @@ def compute_probs_and_sample_next_token( # of the generated tokens after rejection sampling. # TODO(woosuk): Consider seeds. - q = torch.empty_like(draft_probs_buffer[:, speculative_token_idx, :]) + q = torch.empty_like(draft_probs_buffer[:batch_size, speculative_token_idx, :]) q.exponential_() - draft_token_ids_buffer[:, speculative_token_idx] = draft_probs_buffer[:, speculative_token_idx, :].div_(q).argmax(dim=-1).view(-1) + draft_token_ids_buffer[:batch_size, speculative_token_idx] = draft_probs_buffer[:batch_size, speculative_token_idx, :].div_(q).argmax(dim=-1).view(-1) if not sampling_metadata.all_random: - greedy_token_ids = draft_probs_buffer[:, speculative_token_idx, :].argmax(dim=-1) + greedy_token_ids = draft_probs_buffer[:batch_size, speculative_token_idx, :].argmax(dim=-1) next_token_ids = torch.where( is_greedy, greedy_token_ids, diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 8f6d20d11ff..7304a649fdb 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -12,6 +12,12 @@ class NgramProposer: def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config + def get_draft_token_ids(self): + return None + + def get_draft_probs(self): + return None + def propose( self, context_token_ids: np.ndarray, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a0278c64d70..b8265e26089 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -103,9 +103,6 @@ def __init__( self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - # REMOVE - print(f"self.block_size: {self.block_size}, self.max_model_len: {self.max_model_len}, self.max_num_blocks_per_req: {self.max_num_blocks_per_req}") - self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -1092,7 +1089,7 @@ def execute_model( target_logits = logits[spec_decode_metadata.target_logits_indices] output_token_ids = self.rejection_sampler( spec_decode_metadata, - None, # draft_probs + self.drafter.get_draft_probs(), target_logits, bonus_token_ids, sampling_metadata, @@ -1211,10 +1208,7 @@ def execute_model( block_table=attn_metadata.block_table, sampling_metadata=sampling_metadata, ) - spec_token_ids = draft_token_ids.tolist() - # TODO(woosuk): Cache draft_probs and use it for rejection sampling - # in the next step. 
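# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical stand-in class, not a vLLM class): with the
# getter-based interface, the runner calls propose() purely for its side
# effects and then reads the drafts back, so the cached draft probs can be fed
# to rejection sampling instead of None.
import torch

class TinyDrafter:
    def __init__(self, max_batch_size=8, num_spec_tokens=3, vocab_size=10):
        self.num_spec_tokens = num_spec_tokens
        self._token_ids = torch.zeros(max_batch_size, num_spec_tokens, dtype=torch.long)
        self._probs = torch.zeros(max_batch_size, num_spec_tokens, vocab_size)
        self._batch_size = -1

    def propose(self, logits):                    # [batch_size, vocab_size]
        self._batch_size = logits.shape[0]
        probs = logits.softmax(dim=-1)
        for step in range(self.num_spec_tokens):  # toy: reuse the same logits each step
            self._probs[:self._batch_size, step] = probs
            self._token_ids[:self._batch_size, step] = probs.argmax(dim=-1)

    def get_draft_token_ids(self):
        return self._token_ids[:self._batch_size]

    def get_draft_probs(self):
        return self._probs[:self._batch_size]

drafter = TinyDrafter()
drafter.propose(torch.randn(4, 10))
spec_token_ids = drafter.get_draft_token_ids().tolist()  # what the runner turns into lists
draft_probs = drafter.get_draft_probs()                  # handed to the rejection sampler
print(len(spec_token_ids), draft_probs.shape)            # 4 torch.Size([4, 3, 10])
# ----------------------------------------------------------------------------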
- del draft_probs + spec_token_ids = self.drafter.get_draft_token_ids().tolist() return ModelRunnerOutput( req_ids=self.input_batch.req_ids, From 96bd40bd6601f1044ca4e47756198e8b4ad4e6ce Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 7 Apr 2025 21:29:15 +0000 Subject: [PATCH 03/11] cleanup comment --- vllm/v1/spec_decode/eagle.py | 10 ---------- vllm/v1/worker/gpu_model_runner.py | 1 - 2 files changed, 11 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 32023eba568..15d19bbd305 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -159,14 +159,6 @@ def propose( batch_size, self._draft_token_ids_buffer, self._draft_probs_buffer) - # draft_token_ids_list.append(draft_token_ids) - # draft_probs_list.append(probs) - - # [batch_size, num_speculative_tokens] - # draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - # [batch_size, num_speculative_tokens, vocab_size] - # draft_probs = torch.stack(draft_probs_list, dim=1) - # return draft_token_ids, draft_probs @staticmethod def prepare_inputs( @@ -255,12 +247,10 @@ def compute_probs_and_sample_next_token( # Therefore, we can just return the logits. draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits draft_token_ids_buffer[:batch_size, speculative_token_idx] = logits.argmax(dim=-1) - # return next_token_ids, draft_probs is_greedy = sampling_metadata.temperature == -1 temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) logits.div_(temperature.view(-1, 1)) - # probs = logits.softmax(dim=-1, dtype=torch.float32) # REMOVE draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits.softmax(dim=-1, dtype=torch.float32) # NOTE(woosuk): Currently, we ignore most of the sampling parameters in diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b8265e26089..97ff2cdf354 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -102,7 +102,6 @@ def __init__( self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs From 57aed3d134e9019a32633079d3a4a6795e05213d Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 7 Apr 2025 22:38:42 +0000 Subject: [PATCH 04/11] linting --- vllm/v1/spec_decode/eagle.py | 109 ++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 47 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 15d19bbd305..f023e34e94c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -27,19 +27,20 @@ def __init__( max_batch_size = vllm_config.scheduler_config.max_num_seqs vocab_size = vllm_config.model_config.get_vocab_size() - # setup buffers for draft token ids and probs to be reused across steps for Rejection Sampling + # setup buffers for draft token ids and probs to be + # reused across steps for Rejection Sampling self.curr_batch_size = -1 - self._draft_token_ids_buffer = torch.zeros(max_batch_size, - self.num_speculative_tokens, - dtype=torch.int32, - device=device - ) - self._draft_probs_buffer = torch.zeros(max_batch_size, - self.num_speculative_tokens, - vocab_size, - dtype=torch.float32, - device=device - ) + self._draft_token_ids_buffer = 
torch.zeros(max_batch_size, + self.num_speculative_tokens, + dtype=torch.int32, + device=device) + self._draft_token_ids_buffer_shape = self._draft_token_ids_buffer.shape + self._draft_probs_buffer = torch.zeros(max_batch_size, + self.num_speculative_tokens, + vocab_size, + dtype=torch.float32, + device=device) + self._draft_probs_buffer_shape = self._draft_probs_buffer.shape def get_draft_token_ids(self) -> torch.Tensor: # [batch_size, num_speculative_tokens] @@ -68,7 +69,17 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, torch.Tensor]: + ): + # make sure that the buffers size has not changed by any future operation + assert self._draft_probs_buffer_shape.numel() == self._draft_probs_buffer.numel(), "Size of self._draft_probs_buffer has been changed. Make sure it remaiins the same." + assert self._draft_token_ids_buffer_shape.numel() == self._draft_token_ids_buffer.numel(), "Size of self._draft_token_ids_buffer has been changed. Make sure it remaiins the same." + # restore shape of buffers if it has been changed by any future operation + if (self._draft_probs_buffer.shape != self._draft_probs_buffer_shape): + self._draft_probs_buffer.reshape(self._draft_probs_buffer_shape) + + if (self._draft_token_ids_buffer.shape != self._draft_token_ids_buffer_shape): + self._draft_token_ids_buffer.reshape(self._draft_token_ids_buffer_shape) + num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] self.curr_batch_size = batch_size @@ -110,22 +121,15 @@ def propose( ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids, draft_probs = compute_probs_and_sample_next_token( - logits, sampling_metadata, - 0, - batch_size, - self._draft_token_ids_buffer, - self._draft_probs_buffer) + compute_probs_and_sample_next_token( + logits, sampling_metadata, 0, batch_size, + self._draft_token_ids_buffer, self._draft_probs_buffer) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - # [batch_size, 1] and [batch_size, 1, vocab_size] - return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) + return # Generate the remaining draft tokens. - draft_token_ids_list = [draft_token_ids] - draft_probs_list = [draft_probs] - positions = target_positions[last_token_indices] hidden_states = sample_hidden_states attn_metadata.num_actual_tokens = batch_size @@ -133,7 +137,8 @@ def propose( attn_metadata.query_start_loc = self.arange[:batch_size] for speculative_token_idx in range(self.num_speculative_tokens - 1): # Update the inputs. - input_ids = draft_token_ids_list[-1] + input_ids = self._draft_token_ids_buffer[:batch_size, + speculative_token_idx] positions += 1 attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 @@ -153,12 +158,11 @@ def propose( positions=positions, ) logits = self.model.compute_logits(hidden_states, None) - compute_probs_and_sample_next_token( - logits, sampling_metadata, - speculative_token_idx + 1, - batch_size, - self._draft_token_ids_buffer, - self._draft_probs_buffer) + compute_probs_and_sample_next_token(logits, sampling_metadata, + speculative_token_idx + 1, + batch_size, + self._draft_token_ids_buffer, + self._draft_probs_buffer) @staticmethod def prepare_inputs( @@ -229,29 +233,33 @@ def forward( # FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. 
def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - speculative_token_idx: int, - batch_size: int, - draft_token_ids_buffer: torch.Tensor, # [batch_size, num_speculative_tokens] - draft_probs_buffer: torch.Tensor, # [batch_size, num_speculative_tokens, vocab_size] -) -> tuple[torch.Tensor, torch.Tensor]: - # We pass in the entire preallocated buffers draft_token_ids_buffer and draft_probs_buffer + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + speculative_token_idx: int, + batch_size: int, + # [batch_size, num_speculative_tokens] + draft_token_ids_buffer: torch.Tensor, + # [batch_size, num_speculative_tokens, vocab_size] + draft_probs_buffer: torch.Tensor, +): + # We pass in the entire preallocated buffers draft_token_ids_buffer and draft_probs_buffer # and select the portion of the buffer that we need to fill in using batch_size and speculative_token_idx. # This allows us to write in-place. If we passed in the specific tensors slices directly to func, i.e., - # draft_token_ids_buffer[:batch_size, speculative_token_idx] as draft_token_ids, then + # draft_token_ids_buffer[:batch_size, speculative_token_idx] as draft_token_ids, then # draft_token_ids = logits.argmax(dim=-1) would create a new tensor and not allow in-place writes. if sampling_metadata.all_greedy: # For greedy requests, draft_probs is not used in rejection sampling. # Therefore, we can just return the logits. draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits - draft_token_ids_buffer[:batch_size, speculative_token_idx] = logits.argmax(dim=-1) + draft_token_ids_buffer[:batch_size, + speculative_token_idx] = logits.argmax(dim=-1) is_greedy = sampling_metadata.temperature == -1 temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) logits.div_(temperature.view(-1, 1)) - draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits.softmax(dim=-1, dtype=torch.float32) + draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits.softmax( + dim=-1, dtype=torch.float32) # NOTE(woosuk): Currently, we ignore most of the sampling parameters in # generating the draft tokens. We only use the temperature. While this @@ -259,15 +267,22 @@ def compute_probs_and_sample_next_token( # of the generated tokens after rejection sampling. # TODO(woosuk): Consider seeds. 
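# ----------------------------------------------------------------------------
# Illustrative sketch (standalone): the q.exponential_() / div_ / argmax steps
# used by this helper sample from the categorical distribution given by the
# softmax probabilities without calling torch.multinomial.  With E_i ~ Exp(1)
# i.i.d., argmax_i p_i / E_i selects index i with probability p_i (the
# exponential-clocks form of the Gumbel-max trick).  The toy 1-D distribution
# below is made up; the real code applies this per request row.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
num_draws = 200_000

q = torch.empty(num_draws, probs.numel())
q.exponential_()                                  # i.i.d. Exp(1) noise
samples = (probs / q).argmax(dim=-1)              # one categorical sample per row

empirical = torch.bincount(samples, minlength=probs.numel()).float() / num_draws
print(empirical)                                  # ~= tensor([0.1, 0.2, 0.3, 0.4])
# ----------------------------------------------------------------------------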
- q = torch.empty_like(draft_probs_buffer[:batch_size, speculative_token_idx, :]) + q = torch.empty_like(draft_probs_buffer[:batch_size, + speculative_token_idx, :]) q.exponential_() - draft_token_ids_buffer[:batch_size, speculative_token_idx] = draft_probs_buffer[:batch_size, speculative_token_idx, :].div_(q).argmax(dim=-1).view(-1) + draft_token_ids_buffer[:batch_size, speculative_token_idx] = \ + draft_probs_buffer[:batch_size, speculative_token_idx, :] \ + .div_(q) \ + .argmax(dim=-1) \ + .view(-1) if not sampling_metadata.all_random: - greedy_token_ids = draft_probs_buffer[:batch_size, speculative_token_idx, :].argmax(dim=-1) - next_token_ids = torch.where( + greedy_token_ids = draft_probs_buffer[:batch_size, + speculative_token_idx, :].argmax( + dim=-1) + draft_token_ids_buffer[:batch_size, speculative_token_idx] = torch.where( is_greedy, greedy_token_ids, - next_token_ids, + draft_token_ids_buffer[:batch_size, speculative_token_idx], ) From a90102210d4f56ff43028c3f7e5433942f9e130b Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Mon, 7 Apr 2025 22:44:26 +0000 Subject: [PATCH 05/11] linting Signed-off-by: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/eagle.py | 67 ++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index f023e34e94c..ae8f87607e2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -27,7 +27,7 @@ def __init__( max_batch_size = vllm_config.scheduler_config.max_num_seqs vocab_size = vllm_config.model_config.get_vocab_size() - # setup buffers for draft token ids and probs to be + # setup buffers for draft token ids and probs to be # reused across steps for Rejection Sampling self.curr_batch_size = -1 self._draft_token_ids_buffer = torch.zeros(max_batch_size, @@ -70,15 +70,27 @@ def propose( block_table: torch.Tensor, sampling_metadata: SamplingMetadata, ): - # make sure that the buffers size has not changed by any future operation - assert self._draft_probs_buffer_shape.numel() == self._draft_probs_buffer.numel(), "Size of self._draft_probs_buffer has been changed. Make sure it remaiins the same." - assert self._draft_token_ids_buffer_shape.numel() == self._draft_token_ids_buffer.numel(), "Size of self._draft_token_ids_buffer has been changed. Make sure it remaiins the same." - # restore shape of buffers if it has been changed by any future operation + # make sure that the buffers size has not changed + # by any future operation + assert self._draft_probs_buffer_shape.numel( + ) == self._draft_probs_buffer.numel( + ), "Size of self._draft_probs_buffer has been changed. " + "Make sure it remaiins the same." + + assert self._draft_token_ids_buffer_shape.numel( + ) == self._draft_token_ids_buffer.numel( + ), "Size of self._draft_token_ids_buffer has been changed. " + "Make sure it remaiins the same." 
+ + # restore shape of buffers if it has been + # changed by any future operation if (self._draft_probs_buffer.shape != self._draft_probs_buffer_shape): self._draft_probs_buffer.reshape(self._draft_probs_buffer_shape) - if (self._draft_token_ids_buffer.shape != self._draft_token_ids_buffer_shape): - self._draft_token_ids_buffer.reshape(self._draft_token_ids_buffer_shape) + if (self._draft_token_ids_buffer.shape + != self._draft_token_ids_buffer_shape): + self._draft_token_ids_buffer.reshape( + self._draft_token_ids_buffer_shape) num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -121,13 +133,14 @@ def propose( ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - compute_probs_and_sample_next_token( - logits, sampling_metadata, 0, batch_size, - self._draft_token_ids_buffer, self._draft_probs_buffer) + compute_probs_and_sample_next_token(logits, sampling_metadata, 0, + batch_size, + self._draft_token_ids_buffer, + self._draft_probs_buffer) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - return + return # Generate the remaining draft tokens. positions = target_positions[last_token_indices] @@ -233,20 +246,23 @@ def forward( # FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - speculative_token_idx: int, - batch_size: int, - # [batch_size, num_speculative_tokens] - draft_token_ids_buffer: torch.Tensor, - # [batch_size, num_speculative_tokens, vocab_size] - draft_probs_buffer: torch.Tensor, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + speculative_token_idx: int, + batch_size: int, + # [batch_size, num_speculative_tokens] + draft_token_ids_buffer: torch.Tensor, + # [batch_size, num_speculative_tokens, vocab_size] + draft_probs_buffer: torch.Tensor, ): - # We pass in the entire preallocated buffers draft_token_ids_buffer and draft_probs_buffer - # and select the portion of the buffer that we need to fill in using batch_size and speculative_token_idx. - # This allows us to write in-place. If we passed in the specific tensors slices directly to func, i.e., - # draft_token_ids_buffer[:batch_size, speculative_token_idx] as draft_token_ids, then - # draft_token_ids = logits.argmax(dim=-1) would create a new tensor and not allow in-place writes. + # We pass in the entire preallocated buffers draft_token_ids_buffer + # and draft_probs_buffer and select the portion of the buffer that + # we need to fill in using batch_size and speculative_token_idx. + # This allows us to write in-place. If we passed in the specific + # tensors slices directly to func, i.e., + # draft_token_ids_buffer[:batch_size, speculative_token_idx] + # as draft_token_ids, then draft_token_ids = logits.argmax(dim=-1) + # would create a new tensor and not allow in-place writes. if sampling_metadata.all_greedy: # For greedy requests, draft_probs is not used in rejection sampling. 
@@ -279,7 +295,8 @@ def compute_probs_and_sample_next_token( greedy_token_ids = draft_probs_buffer[:batch_size, speculative_token_idx, :].argmax( dim=-1) - draft_token_ids_buffer[:batch_size, speculative_token_idx] = torch.where( + draft_token_ids_buffer[:batch_size, speculative_token_idx] = \ + torch.where( is_greedy, greedy_token_ids, draft_token_ids_buffer[:batch_size, speculative_token_idx], From 6fff2d0f7db56f572aa3b4df61109fb701f4f50f Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Tue, 8 Apr 2025 17:39:08 +0000 Subject: [PATCH 06/11] make ngram have same interface as eagle Signed-off-by: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> --- vllm/v1/spec_decode/ngram_proposer.py | 10 ++++++---- vllm/v1/worker/gpu_model_runner.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 7304a649fdb..72a380ccebe 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -11,9 +11,10 @@ class NgramProposer: def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config + self._draft_token_ids = None def get_draft_token_ids(self): - return None + return self._draft_token_ids def get_draft_probs(self): return None @@ -24,7 +25,7 @@ def propose( min_n: int, max_n: int, k: int, - ) -> Optional[np.ndarray]: + ): """Proposes the next sequence of tokens based on n-gram pattern matching in the context. The function finds matches of the last n tokens in the previous context, and returns k tokens that followed @@ -58,8 +59,9 @@ def propose( for n in range(max_n, min_n - 1, -1): result = _find_subarray_kmp(context_token_ids, n, k) if result is not None: - return result - return None + self._draft_token_ids = result + return + return def load_model(self, *args, **kwargs): # No model to load. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 97ff2cdf354..92098d0e6b4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1242,12 +1242,13 @@ def generate_draft_token_ids( start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids - drafter_output = self.drafter.propose( + self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx], self.speculative_config.prompt_lookup_min, self.speculative_config.prompt_lookup_max, self.speculative_config.num_speculative_tokens, ) + drafter_output = self.drafter.get_draft_token_ids() if drafter_output is None or len(drafter_output) == 0: draft_token_ids.append([]) else: From 8e497507fff8922626fbcb9c93702669103d7bca Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:40:09 -0400 Subject: [PATCH 07/11] Update vllm/v1/spec_decode/eagle.py Co-authored-by: Lily Liu --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 54638247348..631c3df6510 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -82,7 +82,7 @@ def propose( assert self._draft_probs_buffer_shape.numel( ) == self._draft_probs_buffer.numel( ), "Size of self._draft_probs_buffer has been changed. " - "Make sure it remaiins the same." + "Make sure it remains the same." 
assert self._draft_token_ids_buffer_shape.numel( ) == self._draft_token_ids_buffer.numel( From f69ebc050b41ff4aa94b345e852b69df6c72a1dd Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 11 Apr 2025 16:13:38 +0000 Subject: [PATCH 08/11] fix return when greedy --- vllm/v1/spec_decode/eagle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 631c3df6510..5b07be7b43d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -281,6 +281,7 @@ def compute_probs_and_sample_next_token( draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits draft_token_ids_buffer[:batch_size, speculative_token_idx] = logits.argmax(dim=-1) + return is_greedy = sampling_metadata.temperature == -1 temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) From 57bd14bf0d8eac168b4946d89109f9867d019d3e Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 11 Apr 2025 21:05:53 +0000 Subject: [PATCH 09/11] pack draft probs --- vllm/v1/spec_decode/eagle.py | 54 +++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5b07be7b43d..2cd0d230d0b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -36,15 +36,20 @@ def __init__( # setup buffers for draft token ids and probs to be # reused across steps for Rejection Sampling + self.curr_num_tokens = -1 self.curr_batch_size = -1 - self._draft_token_ids_buffer = torch.zeros(max_batch_size, + + # packed tensor for [bs, num_speculative_tokens] + self._draft_token_ids_buffer = torch.zeros(max_batch_size, self.num_speculative_tokens, - dtype=torch.int32, + dtype=torch.long, device=device) + + # packed tensor for [num_tokens, vocab_size] self._draft_token_ids_buffer_shape = self._draft_token_ids_buffer.shape - self._draft_probs_buffer = torch.zeros(max_batch_size, - self.num_speculative_tokens, + self._draft_probs_buffer = torch.zeros(max_batch_size * self.num_speculative_tokens, vocab_size, + # TODO(ekagra): pass dtype dtype=torch.float32, device=device) self._draft_probs_buffer_shape = self._draft_probs_buffer.shape @@ -52,12 +57,12 @@ def __init__( def get_draft_token_ids(self) -> torch.Tensor: # [batch_size, num_speculative_tokens] assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet." - return self._draft_token_ids_buffer[:self.curr_batch_size, :] + return self._draft_token_ids_buffer[:self.curr_batch_size] def get_draft_probs(self) -> torch.Tensor: # [batch_size, num_speculative_tokens, vocab_size] - assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet." - return self._draft_probs_buffer[:self.curr_batch_size, :, :] + assert self.curr_num_tokens != -1, "EagleProposer hasn't proposed yet." 
+ return self._draft_probs_buffer[:self.curr_num_tokens] def propose( self, @@ -102,6 +107,7 @@ def propose( num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] self.curr_batch_size = batch_size + self.curr_num_tokens = batch_size * self.num_speculative_tokens last_token_indices = cu_num_tokens[1:] - 1 input_ids = torch.empty_like(target_token_ids) @@ -142,8 +148,11 @@ def propose( ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - compute_probs_and_sample_next_token(logits, sampling_metadata, 0, + compute_probs_and_sample_next_token(logits, sampling_metadata, + 0, + self.num_speculative_tokens, batch_size, + self.arange, self._draft_token_ids_buffer, self._draft_probs_buffer) @@ -159,8 +168,7 @@ def propose( attn_metadata.query_start_loc = self.arange[:batch_size + 1] for speculative_token_idx in range(self.num_speculative_tokens - 1): # Update the inputs. - input_ids = self._draft_token_ids_buffer[:batch_size, - speculative_token_idx] + input_ids = self._draft_token_ids_buffer[:batch_size, speculative_token_idx] positions += 1 attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 @@ -182,7 +190,9 @@ def propose( logits = self.model.compute_logits(hidden_states, None) compute_probs_and_sample_next_token(logits, sampling_metadata, speculative_token_idx + 1, + self.num_speculative_tokens, batch_size, + self.arange, self._draft_token_ids_buffer, self._draft_probs_buffer) @@ -259,8 +269,14 @@ def load_model(self, target_model: nn.Module) -> None: def compute_probs_and_sample_next_token( logits: torch.Tensor, sampling_metadata: SamplingMetadata, + # index of the speculative token among num_speculative_tokens speculative_token_idx: int, + # max number of speculative tokens + num_speculative_tokens: int, + # current batch size batch_size: int, + # [batch_size + 1] + arange: torch.Tensor, # [batch_size, num_speculative_tokens] draft_token_ids_buffer: torch.Tensor, # [batch_size, num_speculative_tokens, vocab_size] @@ -275,18 +291,19 @@ def compute_probs_and_sample_next_token( # as draft_token_ids, then draft_token_ids = logits.argmax(dim=-1) # would create a new tensor and not allow in-place writes. + draft_probs_buffer_indices = arange[:batch_size] * num_speculative_tokens + speculative_token_idx + if sampling_metadata.all_greedy: # For greedy requests, draft_probs is not used in rejection sampling. # Therefore, we can just return the logits. - draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits - draft_token_ids_buffer[:batch_size, - speculative_token_idx] = logits.argmax(dim=-1) + draft_probs_buffer[draft_probs_buffer_indices] = logits.to(dtype=torch.float32) + draft_token_ids_buffer[:batch_size, speculative_token_idx] = logits.argmax(dim=-1) return is_greedy = sampling_metadata.temperature == -1 temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) logits.div_(temperature.view(-1, 1)) - draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits.softmax( + draft_probs_buffer[draft_probs_buffer_indices] = logits.softmax( dim=-1, dtype=torch.float32) # NOTE(woosuk): Currently, we ignore most of the sampling parameters in @@ -295,18 +312,15 @@ def compute_probs_and_sample_next_token( # of the generated tokens after rejection sampling. # TODO(woosuk): Consider seeds. 
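# ----------------------------------------------------------------------------
# Illustrative sketch (standalone): with the probs buffer packed as
# [max_batch_size * num_speculative_tokens, vocab_size], the row holding
# (request b, speculative step s) is b * num_speculative_tokens + s, which is
# what arange[:batch_size] * num_speculative_tokens + speculative_token_idx
# computes above.  Sizes below are made up.
import torch

max_batch_size, num_spec_tokens, vocab_size = 8, 3, 10
batch_size, step = 4, 1
arange = torch.arange(max_batch_size)

packed = torch.zeros(max_batch_size * num_spec_tokens, vocab_size)
rows = arange[:batch_size] * num_spec_tokens + step   # rows for step 1 of each request

probs = torch.randn(batch_size, vocab_size).softmax(dim=-1)
packed[rows] = probs                                  # scatter this step's probs

# Viewed as [max_batch_size, num_spec_tokens, vocab_size], the same data sits
# at [:batch_size, step].
unpacked = packed.view(max_batch_size, num_spec_tokens, vocab_size)
assert torch.equal(unpacked[:batch_size, step], probs)
# ----------------------------------------------------------------------------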
- q = torch.empty_like(draft_probs_buffer[:batch_size, - speculative_token_idx, :]) + q = torch.empty_like(draft_probs_buffer[draft_probs_buffer_indices]) q.exponential_() draft_token_ids_buffer[:batch_size, speculative_token_idx] = \ - draft_probs_buffer[:batch_size, speculative_token_idx, :] \ + draft_probs_buffer[draft_probs_buffer_indices] \ .div_(q) \ .argmax(dim=-1) \ .view(-1) if not sampling_metadata.all_random: - greedy_token_ids = draft_probs_buffer[:batch_size, - speculative_token_idx, :].argmax( - dim=-1) + greedy_token_ids = draft_probs_buffer[draft_probs_buffer_indices].argmax(dim=-1) draft_token_ids_buffer[:batch_size, speculative_token_idx] = \ torch.where( is_greedy, From 610cfaf88be9c64f2baab0400694d4739acbcb7a Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 11 Apr 2025 23:05:13 +0000 Subject: [PATCH 10/11] add log to eagle example --- examples/offline_inference/eagle.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 453ae7b6f56..f15b37d5b66 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -45,6 +45,7 @@ def main(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--log_output_filename", type=str, default="eagle_output.txt") args = parser.parse_args() model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" @@ -90,6 +91,22 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + # save output text in eagle_output.txt file for quality check + log_output_data = [] + for i, output in enumerate(outputs): + input_text = tokenizer.decode(output.prompt_token_ids) + log_output_data.append({ + "input": input_text, + "output": output.outputs[0].text + }) + + with open("eagle_output.txt", "w") as f: + f.write( + json.dumps(log_output_data, indent=4, ensure_ascii=False)) + print("-" * 50) + print(f"Output texts saved to {args.log_output_filename}") + print("-" * 50) + # calculate the average number of accepted tokens per forward pass, +1 is # to account for the token from the target model that's always going to be # accepted From 6f577a4fc0717ad5c64b526069f4fb7ba291b22e Mon Sep 17 00:00:00 2001 From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com> Date: Fri, 11 Apr 2025 23:06:06 +0000 Subject: [PATCH 11/11] add non greedy to test --- tests/v1/e2e/test_spec_decode.py | 81 +++++++++++++++++++------------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 67371498059..409b31f4797 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -44,8 +44,13 @@ def test_prompts(): @pytest.fixture def sampling_config(): - # Only support greedy for now - return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) + # return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) + return [ SamplingParams(temperature=0, max_tokens=10, ignore_eos=False), + SamplingParams(temperature=0.1, max_tokens=10, ignore_eos=False), + SamplingParams(temperature=0.2, max_tokens=10, ignore_eos=False), + SamplingParams(temperature=0.3, max_tokens=10, ignore_eos=False), + # SamplingParams(temperature=1, top_p=0.75, max_tokens=10, ignore_eos=False), + ] @pytest.fixture @@ -72,7 
+77,9 @@ def test_ngram_correctness( m.setenv("VLLM_USE_V1", "1") ref_llm = LLM(model=model_name, max_model_len=1024) - ref_outputs = ref_llm.chat(test_prompts, sampling_config) + ref_outputs = [] + for sampling_param in sampling_config: + ref_outputs.append(ref_llm.chat(test_prompts, sampling_param)) del ref_llm spec_llm = LLM( @@ -85,20 +92,22 @@ def test_ngram_correctness( }, max_model_len=1024, ) - spec_outputs = spec_llm.chat(test_prompts, sampling_config) - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") - - # Heuristic: expect at least 70% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.7 * len(ref_outputs)) + + for i, sampling_param in enumerate(sampling_config): + spec_output = spec_llm.chat(test_prompts, sampling_param) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs[i], spec_output): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.7 * len(ref_outputs[i])) del spec_llm @@ -115,9 +124,13 @@ def test_eagle_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + ref_outputs = [] + spec_outputs = [] ref_llm = LLM(model=model_name, max_model_len=1024) - ref_outputs = ref_llm.chat(test_prompts, sampling_config) + + for sampling_param in sampling_config: + ref_outputs.append(ref_llm.chat(test_prompts, sampling_param)) del ref_llm spec_llm = LLM( @@ -129,18 +142,22 @@ def test_eagle_correctness( }, max_model_len=1024, ) - spec_outputs = spec_llm.chat(test_prompts, sampling_config) - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") - - # Heuristic: expect at least 70% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.7 * len(ref_outputs)) + + for sampling_param in sampling_config: + spec_outputs.append(spec_llm.chat(test_prompts, sampling_param)) del spec_llm + + for i in range(len(sampling_config)): + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs[i], spec_outputs[i]): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.7 * len(ref_outputs[i])), "Failed for sampling_param: " + str(sampling_config[i])