
Commit ed5d408

[Neuron] Remove bypass on EAGLEConfig and add a test (#18514)
Signed-off-by: Elaine Zhao <[email protected]>
1 parent 583507d

File tree: 4 files changed (+95, -5 lines)


.buildkite/scripts/hardware_ci/run-neuron-test.sh

Lines changed: 8 additions & 1 deletion
@@ -53,4 +53,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \
     -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
     --name "${container_name}" \
     ${image_name} \
-    /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
+    /bin/bash -c "
+        python3 /workspace/vllm/examples/offline_inference/neuron.py;
+        python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
+        for f in /workspace/vllm/tests/neuron/2_core/*.py; do
+            echo 'Running test file: '\$f;
+            python3 -m pytest \$f -v --capture=tee-sys;
+        done
+    "

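The new command runs each 2_core test file in its own pytest process instead of a single pytest invocation over the directory, and the \$f escape defers expansion of the loop variable to the container's shell rather than the host's. A rough Python equivalent of the loop, as a sketch only (the in-container paths come from the script above):

    # Sketch: run every 2_core test file in a separate pytest process,
    # mirroring the shell loop above; paths are the in-container paths.
    import glob
    import subprocess

    for f in sorted(glob.glob("/workspace/vllm/tests/neuron/2_core/*.py")):
        print(f"Running test file: {f}")
        # check=False keeps the loop going if one file fails, like ';' in shell
        subprocess.run(
            ["python3", "-m", "pytest", f, "-v", "--capture=tee-sys"],
            check=False)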
tests/neuron/2_core/test_eagle.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+import shutil
+import tempfile
+
+import torch
+from huggingface_hub import snapshot_download
+from safetensors import safe_open
+
+from vllm import LLM, SamplingParams
+
+
+def patch_eagle_draft_with_lm_head(target_model_id: str,
+                                   draft_model_id: str) -> str:
+    # In NxDI, draft model checkpoint must include lm_head weights from target
+    # model. For more details see https://awsdocs-neuron.readthedocs-hosted.com
+    # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
+    # #eagle-checkpoint-compatibility
+    final_draft_dir = "/tmp/patched_eagle_draft"
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        target_dir = snapshot_download(repo_id=target_model_id,
+                                       local_dir=os.path.join(
+                                           tmp_dir, "target"))
+        draft_dir = snapshot_download(repo_id=draft_model_id,
+                                      local_dir=os.path.join(tmp_dir, "draft"))
+
+        lm_head_key = "lm_head.weight"
+        index_path = os.path.join(target_dir, "model.safetensors.index.json")
+        with open(index_path) as f:
+            index = json.load(f)
+        shard_name = index["weight_map"][lm_head_key]
+        target_safetensor_path = os.path.join(target_dir, shard_name)
+
+        with safe_open(target_safetensor_path, framework="pt") as f:
+            target_lm_head = f.get_tensor(lm_head_key)
+
+        draft_path = os.path.join(draft_dir, "pytorch_model.bin")
+        draft_state_dict = torch.load(draft_path, map_location="cpu")
+        draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
+        torch.save(draft_state_dict, draft_path)
+
+        shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)
+
+    return final_draft_dir
+
+
+def test_eagle():
+    patched_draft_path = patch_eagle_draft_with_lm_head(
+        target_model_id="meta-llama/Llama-2-7b-hf",
+        draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
+    llm = LLM(
+        model="meta-llama/Llama-2-7b-hf",
+        speculative_config={
+            "model": patched_draft_path,
+            "num_speculative_tokens": 5,
+            "max_model_len": 128
+        },
+        max_num_seqs=1,
+        max_model_len=128,
+        tensor_parallel_size=2,
+        override_neuron_config={
+            "enable_eagle_speculation": True,
+            "enable_fused_speculation": True,
+            "fused_qkv": True
+        },
+    )
+    prompts = [
+        "The president of the United States is",
+    ]
+    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+    expected_output = " the head of state and head of government of " \
+                      "the United States. The president direct"
+
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
+        assert (expected_output == generated_text)
+
+    print("Neuron Eagle speculation test passed.")
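A quick way to spot-check the patched draft produced by patch_eagle_draft_with_lm_head, as a hypothetical snippet that is not part of the commit:

    # Hypothetical sanity check: the patched draft checkpoint should now carry
    # the target model's lm_head weights, cast to float16 by the patch helper.
    import torch

    state_dict = torch.load("/tmp/patched_eagle_draft/pytorch_model.bin",
                            map_location="cpu")
    assert "lm_head.weight" in state_dict
    assert state_dict["lm_head.weight"].dtype == torch.float16
    print(state_dict["lm_head.weight"].shape)  # [vocab_size, hidden_size]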

tests/neuron/2_core/test_mistral.py

Lines changed: 4 additions & 2 deletions
@@ -12,8 +12,7 @@ def test_mistral():
         override_neuron_config={
             "sequence_parallel_enabled": False,
             "skip_warmup": True
-        },
-        device="neuron")
+        })

     # Send more prompts than the compiled batch size (4) and request
     # varying generation lengths to test accuracy related to Neuron
@@ -59,4 +58,7 @@ def test_mistral():

     for expected_output, output in zip(expected_outputs, outputs):
         generated_text = output.outputs[0].text
+        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
         assert (expected_output == generated_text)
+
+    print("Neuron Mistral test passed.")

vllm/config.py

Lines changed: 1 addition & 2 deletions
@@ -2529,11 +2529,10 @@ def __post_init__(self):
                     "Chunked prefill and EAGLE are not compatible "
                     "when using V0.")

-                from vllm.platforms import current_platform
                from vllm.transformers_utils.configs.eagle import (
                    EAGLEConfig)
                if isinstance(self.draft_model_config.hf_config,
-                             EAGLEConfig) or current_platform.is_neuron():
+                             EAGLEConfig):
                    pass
                else:
                    eagle_config = EAGLEConfig(
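In effect, the Neuron-specific bypass is gone: a draft model config is left untouched only when it is already an EAGLEConfig, and Neuron now takes the same wrapping path as every other platform. A simplified, self-contained sketch of that control flow (names are stand-ins; the real check lives in SpeculativeConfig.__post_init__):

    # Stand-in sketch of the post-change control flow; EAGLEConfig here is a
    # placeholder class, not vllm's real one.
    class EAGLEConfig:
        def __init__(self, model):
            self.model = model

    def wrap_draft_config(draft_hf_config):
        # After this commit there is no current_platform.is_neuron() escape
        # hatch: only a config that is already an EAGLEConfig passes through.
        if isinstance(draft_hf_config, EAGLEConfig):
            return draft_hf_config
        return EAGLEConfig(model=draft_hf_config)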
