Skip to content

Commit 42ffb1b

Browse files
llsj14 authored and HwwwwwwwH committed
[Bugfix][SpecDecode] Adjust Eagle model architecture to align with intended design (vllm-project#11672)
Signed-off-by: Sungjae Lee <[email protected]> Signed-off-by: hzh <[email protected]>
1 parent 95230b9 commit 42ffb1b

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

vllm/model_executor/models/eagle.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,30 @@
1717
from .utils import maybe_prefix
1818

1919

20+
class DummyInputLayerNorm(nn.Module):
    """No-op replacement for a decoder layer's input LayerNorm.

    The EAGLE reference implementation disables input_layernorm on the
    first decoder layer; this identity module is swapped in so the
    layer's call interface stays unchanged.
    """

    def forward(self, x):
        # Identity: hand the hidden states back untouched.
        return x
25+
26+
class DummyOutputNorm(nn.Module):
    """No-op replacement for the model's final output norm.

    Mimics the ``(hidden_states, residual)`` calling convention of a
    real norm layer: when no residual is supplied only ``x`` comes
    back, otherwise both values are returned unchanged.
    """

    def forward(self, x, residual):
        # Preserve the caller-visible return shape (single value vs.
        # pair) while performing no normalization at all.
        return x if residual is None else (x, residual)
34+
2035
class EAGLE(nn.Module):
2136
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
2237
Reference implementation: https://github.com/SafeAILab/EAGLE
2338
2439
Differences from reference implementation:
2540
1. In reference, LlamaDecoderLayer implementation doesn't have
26-
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
27-
but we do as HF implementation also does.
41+
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
42+
Following this approach, our implementation also disables
43+
the input_layernorm for the first decoder layer.
2844
2. We allow any decoder layer to be used in EAGLE whereas in reference
2945
decoder layer is fixed to be LlamaDecoderLayer.
3046
3. We have an optional token_map which reduces draft vocab to most
@@ -46,10 +62,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
4662

4763
self.model = model_cls(vllm_config=vllm_config,
4864
prefix=maybe_prefix(prefix, "model"))
65+
4966
self.fc = nn.Linear(config.model.hidden_size * 2,
5067
config.model.hidden_size,
5168
bias=getattr(self.config, "eagle_fc_bias", False))
5269

70+
# Modify layer normalization and residual connections as suggested
71+
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
72+
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
73+
self.model.model.norm = DummyOutputNorm()
74+
5375
self.orig_vocab_size = config.vocab_size
5476
self.truncated_vocab_size = config.truncated_vocab_size
5577
self.unpadded_vocab_size = self.truncated_vocab_size

0 commit comments

Comments
 (0)