Skip to content

Commit c407e22

Browse files
llsj14mzusman
authored andcommitted
[CI][Spec Decode] fix: broken test for EAGLE model (vllm-project#11972)
Signed-off-by: Sungjae Lee <[email protected]>
1 parent 5d23a81 commit c407e22

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,15 @@ steps:
231231
- pytest -v -s test_logits_processor.py
232232
- pytest -v -s model_executor/test_guided_processors.py
233233

234-
- label: Speculative decoding tests # 30min
234+
- label: Speculative decoding tests # 40min
235235
source_file_dependencies:
236236
- vllm/spec_decode
237237
- tests/spec_decode
238+
- vllm/model_executor/models/eagle.py
238239
commands:
239240
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
240241
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
242+
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
241243

242244
- label: LoRA Test %N # 15min each
243245
mirror_hardwares: [amd]

vllm/model_executor/models/eagle.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
class DummyInputLayerNorm(nn.Module):
2121

22+
def __init__(self, weight=None, bias=None):
23+
super().__init__()
24+
self.weight = nn.Parameter(weight) if weight is not None else None
25+
self.bias = nn.Parameter(bias) if bias is not None else None
26+
2227
def forward(self, x):
2328
return x
2429

@@ -69,7 +74,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
6974

7075
# Modify layer normalization and residual connections as suggested
7176
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
72-
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
77+
# While weights and biases are generally not needed,
78+
# they are retained here to support certain unit tests
79+
# (e.g., spec_decode/e2e/test_eagle_correctness.py).
80+
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
81+
weight=self.model.model.layers[0].input_layernorm.weight)
7382
self.model.model.norm = DummyOutputNorm()
7483

7584
self.orig_vocab_size = config.vocab_size

0 commit comments

Comments
 (0)