
Commit 104dba4

add request level, per-step acceptance counts tracking for spec dec
Signed-off-by: Bryan Lu <[email protected]>
1 parent eb07c8c commit 104dba4

File tree

8 files changed: +171 -117 lines changed

examples/offline_inference/eagle.py

Lines changed: 110 additions & 86 deletions
@@ -7,89 +7,113 @@

 from vllm import LLM, SamplingParams

-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    "--dataset",
-    type=str,
-    default="./examples/data/gsm8k.jsonl",
-    help="downloaded from the eagle repo " \
-    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
-)
-parser.add_argument("--max_num_seqs", type=int, default=8)
-parser.add_argument("--num_prompts", type=int, default=80)
-parser.add_argument("--num_spec_tokens", type=int, default=2)
-parser.add_argument("--tp", type=int, default=1)
-parser.add_argument("--draft_tp", type=int, default=1)
-parser.add_argument("--enforce_eager", action='store_true')
-parser.add_argument("--enable_chunked_prefill", action='store_true')
-parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
-parser.add_argument("--temp", type=float, default=0)
-
-args = parser.parse_args()
-
-print(args)
-
-model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
-
-max_model_len = 2048
-
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-if os.path.exists(args.dataset):
-    prompts = []
-    num_prompts = args.num_prompts
-    with open(args.dataset) as f:
-        for line in f:
-            data = json.loads(line)
-            prompts.append(data["turns"][0])
-else:
-    prompts = ["The future of AI is", "The president of the United States is"]
-
-prompts = prompts[:args.num_prompts]
-num_prompts = len(prompts)
-
-prompt_ids = [
-    tokenizer.apply_chat_template([{
-        "role": "user",
-        "content": prompt
-    }],
-                                  add_generation_prompt=True)
-    for prompt in prompts
-]
-
-llm = LLM(
-    model=model_dir,
-    trust_remote_code=True,
-    tensor_parallel_size=args.tp,
-    enable_chunked_prefill=args.enable_chunked_prefill,
-    max_num_batched_tokens=args.max_num_batched_tokens,
-    enforce_eager=args.enforce_eager,
-    max_model_len=max_model_len,
-    max_num_seqs=args.max_num_seqs,
-    gpu_memory_utilization=0.8,
-    speculative_config={
-        "model": eagle_dir,
-        "num_speculative_tokens": args.num_spec_tokens,
-        "draft_tensor_parallel_size": args.draft_tp,
-        "max_model_len": max_model_len,
-    },
-    disable_log_stats=False,
-)
-
-sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
-
-outputs = llm.generate(prompt_token_ids=prompt_ids,
-                       sampling_params=sampling_params)
-
-# calculate the average number of accepted tokens per forward pass, +1 is
-# to account for the token from the target model that's always going to be
-# accepted
-acceptance_counts = [0] * (args.num_spec_tokens + 1)
-for output in outputs:
-    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-        acceptance_counts[step] += count
-
-print(f"mean acceptance length: \
-{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+
+def load_prompts(dataset_path, num_prompts):
+    if os.path.exists(dataset_path):
+        prompts = []
+        try:
+            with open(dataset_path) as f:
+                for line in f:
+                    data = json.loads(line)
+                    prompts.append(data["turns"][0])
+        except Exception as e:
+            print(f"Error reading dataset: {e}")
+            return []
+    else:
+        prompts = [
+            "The future of AI is", "The president of the United States is"
+        ]
+
+    return prompts[:num_prompts]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="./examples/data/gsm8k.jsonl",
+        help="downloaded from the eagle repo " \
+        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+    )
+    parser.add_argument("--max_num_seqs", type=int, default=8)
+    parser.add_argument("--num_prompts", type=int, default=80)
+    parser.add_argument("--num_spec_tokens", type=int, default=2)
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--draft_tp", type=int, default=1)
+    parser.add_argument("--enforce_eager", action='store_true')
+    parser.add_argument("--enable_chunked_prefill", action='store_true')
+    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+    parser.add_argument("--temp", type=float, default=0)
+    parser.add_argument("--use_v1", type=str, default="1", help='1 or 0')
+    args = parser.parse_args()
+
+    # TODO: remove this option once EAGLE in v1 is ready.
+    os.environ["VLLM_USE_V1"] = args.use_v1
+
+    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+
+    max_model_len = 2048
+
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+    prompts = load_prompts(args.dataset, args.num_prompts)
+
+    prompt_ids = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      add_generation_prompt=True)
+        for prompt in prompts
+    ]
+
+    llm = LLM(
+        model=model_dir,
+        trust_remote_code=True,
+        tensor_parallel_size=args.tp,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        enforce_eager=args.enforce_eager,
+        max_model_len=max_model_len,
+        max_num_seqs=args.max_num_seqs,
+        gpu_memory_utilization=0.8,
+        speculative_config={
+            "method": "eagle",
+            "model": eagle_dir,
+            "num_speculative_tokens": args.num_spec_tokens,
+            "draft_tensor_parallel_size": args.draft_tp,
+            "max_model_len": max_model_len,
+        },
+        disable_log_stats=False,
+    )
+
+    sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
+
+    outputs = llm.generate(prompt_token_ids=prompt_ids,
+                           sampling_params=sampling_params)
+
+    # calculate the average number of accepted tokens per forward pass, +1 is
+    # to account for the token from the target model that's always going to be
+    # accepted
+    acceptance_counts = [0] * (args.num_spec_tokens + 1)
+    if args.use_v1 == '1':
+        for output in outputs:
+            for step, count in enumerate(
+                    output.spec_token_acceptance_counts[0]):
+                acceptance_counts[step] += count
+    else:
+        for output in outputs:
+            for step, count in enumerate(
+                    output.metrics.spec_token_acceptance_counts):
+                acceptance_counts[step] += count
+
+    print("-" * 50)
+    print(f"mean acceptance length: \
+{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
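
Note on the metric printed above: acceptance_counts[0] is incremented once per forward pass (the token proposed by the target model is always kept), so it doubles as the number of engine steps, and sum(acceptance_counts) / acceptance_counts[0] is the mean number of tokens emitted per step. A minimal worked example with invented numbers, not taken from this commit:

    # Hypothetical run: num_spec_tokens = 2 and 80 forward passes in total.
    acceptance_counts = [80, 60, 28]  # [every step, 1st draft accepted, 2nd draft accepted]
    mean_acceptance_length = sum(acceptance_counts) / acceptance_counts[0]
    print(f"mean acceptance length: {mean_acceptance_length:.2f}")  # 2.10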

vllm/outputs.py

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,7 @@ class CompletionOutput:
     finish_reason: Optional[str] = None
     stop_reason: Union[int, str, None] = None
     lora_request: Optional[LoRARequest] = None
+    spec_token_acceptance_counts: Optional[list[int]] = None

     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -133,6 +134,9 @@ def __init__(
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
+        self.spec_token_acceptance_counts = [
+            o.spec_token_acceptance_counts for o in outputs
+        ]

     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""

vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 1 deletion
@@ -590,6 +590,10 @@ def update_from_output(
                     num_draft_tokens=len(scheduled_spec_token_ids),
                     num_accepted_tokens=len(generated_token_ids) - 1)

+            for i in range(len(generated_token_ids)):
+                if request.spec_token_acceptance_counts is not None:
+                    request.spec_token_acceptance_counts[i] += 1
+
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
             # OPTIMIZATION: Avoid list(set) if the set is empty.
@@ -651,7 +655,9 @@ def update_from_output(
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
-                        events=request.take_events()))
+                        events=request.take_events(),
+                        spec_token_acceptance_counts=request.
+                        spec_token_acceptance_counts))
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
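
The increment loop added above builds a cumulative histogram: index 0 is bumped on every step (the always-accepted token from the target model), and index k is bumped only on steps where the k-th drafted token was also accepted. A standalone sketch of that bookkeeping, with step values invented for illustration rather than taken from vLLM internals:

    # Tokens emitted by one request on each engine step
    # (1 bonus token + however many draft tokens were accepted).
    num_spec_tokens = 2
    counts = [0] * (num_spec_tokens + 1)
    for num_generated in [3, 1, 2, 3, 1]:
        for i in range(num_generated):
            counts[i] += 1
    print(counts)                   # [5, 3, 2]: 5 steps, 3 kept the 1st draft, 2 kept both
    print(sum(counts) / counts[0])  # 2.0 tokens per step on average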

vllm/v1/engine/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,7 @@ class EngineCoreRequest(
     eos_token_id: Optional[int]
     arrival_time: float
     lora_request: Optional[LoRARequest]
+    spec_token_acceptance_counts: Optional[list[int]]


 class EngineCoreEventType(enum.IntEnum):
@@ -101,6 +102,7 @@ class EngineCoreOutput(
     finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
     events: Optional[list[EngineCoreEvent]] = None
+    spec_token_acceptance_counts: Optional[list[int]] = None

     @property
     def finished(self) -> bool:

vllm/v1/engine/llm_engine.py

Lines changed: 14 additions & 5 deletions
@@ -183,11 +183,20 @@ def add_request(
         priority: int = 0,
     ) -> None:
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        num_spec_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            num_spec_tokens = (
+                self.vllm_config.speculative_config.num_speculative_tokens)
+        request = self.processor.process_inputs(
+            request_id,
+            prompt,
+            params,
+            arrival_time,
+            lora_request,
+            trace_headers,
+            prompt_adapter_request,
+            priority,
+            num_spec_tokens=num_spec_tokens)

         n = params.n if isinstance(params, SamplingParams) else 1

vllm/v1/engine/output_processor.py

Lines changed: 16 additions & 10 deletions
@@ -136,10 +136,9 @@ def from_new_request(
         )

     def make_request_output(
-        self,
-        new_token_ids: list[int],
-        finish_reason: Optional[FinishReason],
+        self, new_token_ids: list[int], finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        spec_token_acceptance_counts: Optional[list[int]]
     ) -> Optional[RequestOutput]:

         finished = finish_reason is not None
@@ -150,7 +149,10 @@ def make_request_output(
             return None

         completion_output = self._new_completion_output(
-            new_token_ids, finish_reason, stop_reason)
+            new_token_ids,
+            finish_reason,
+            stop_reason,
+            spec_token_acceptance_counts=spec_token_acceptance_counts)

         request_id = self.request_id
         if self.parent_req is None:
@@ -186,10 +188,9 @@ def _new_request_output(
         )

     def _new_completion_output(
-        self,
-        token_ids: list[int],
-        finish_reason: Optional[FinishReason],
-        stop_reason: Union[int, str, None],
+        self, token_ids: list[int], finish_reason: Optional[FinishReason],
+        stop_reason: Union[int, str, None],
+        spec_token_acceptance_counts: Optional[list[int]]
     ) -> CompletionOutput:

         finished = finish_reason is not None
@@ -212,7 +213,8 @@ def _new_completion_output(
             logprobs=logprobs,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
-            stop_reason=stop_reason if finished else None)
+            stop_reason=stop_reason if finished else None,
+            spec_token_acceptance_counts=spec_token_acceptance_counts)


 class OutputProcessor:
@@ -337,7 +339,11 @@ def process_outputs(

             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
-                    new_token_ids, finish_reason, stop_reason):
+                    new_token_ids,
+                    finish_reason,
+                    stop_reason,
+                    spec_token_acceptance_counts=engine_core_output.
+                    spec_token_acceptance_counts):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)

vllm/v1/engine/processor.py

Lines changed: 2 additions & 1 deletion
@@ -176,6 +176,7 @@ def process_inputs(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        num_spec_tokens: int = 0,
     ) -> EngineCoreRequest:

         # TODO(woosuk): Support pooling models.
@@ -278,7 +279,7 @@ def process_inputs(
             eos_token_id=eos_token_id,
             arrival_time=arrival_time,
             lora_request=lora_request,
-        )
+            spec_token_acceptance_counts=[0] * (num_spec_tokens + 1))

     def _validate_model_inputs(self,
                                inputs: ProcessorInputs,
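
Taken together, the counts list is sized from the engine's speculative config: add_request reads num_speculative_tokens, process_inputs seeds the request with num_spec_tokens + 1 zeroed slots, the scheduler increments them every step, and the output path carries them back onto RequestOutput. A small sketch of the sizing, mirroring the values eagle.py passes (not an exact vLLM API call):

    # Hypothetical config, as in the example script above.
    speculative_config = {"method": "eagle", "num_speculative_tokens": 2}
    num_spec_tokens = speculative_config["num_speculative_tokens"]
    initial_counts = [0] * (num_spec_tokens + 1)  # [0, 0, 0]: bonus token + 2 draft slots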
