17
17
from .utils import maybe_prefix
18
18
19
19
20
class DummyInputLayerNorm(nn.Module):
    """No-op replacement for a decoder layer's ``input_layernorm``.

    EAGLE disables the first decoder layer's input layernorm; swapping in
    this identity module keeps the layer's call structure intact while
    performing no normalization.
    """

    def forward(self, x):
        # Identity: hand the input back untouched.
        return x
24
+
25
+
26
class DummyOutputNorm(nn.Module):
    """No-op replacement for the model's final norm.

    Mirrors the calling convention of the norm it stands in for: when no
    residual is supplied the hidden states are returned as-is; otherwise
    the ``(x, residual)`` pair is passed straight through unmodified.
    """

    def forward(self, x, residual):
        # Guard clause: no residual means a plain tensor is expected back.
        if residual is None:
            return x
        # Otherwise preserve the two-value contract of a real norm layer.
        return x, residual
33
+
34
+
20
35
class EAGLE (nn .Module ):
21
36
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
22
37
Reference implementation: https://github.com/SafeAILab/EAGLE
23
38
24
39
Differences from reference implementation:
25
40
    1. In reference, LlamaDecoderLayer implementation doesn't have
       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following this approach, our implementation also disables
       the input_layernorm for the first decoder layer.
28
44
2. We allow any decoder layer to be used in EAGLE whereas in reference
29
45
decoder layer is fixed to be LlamaDecoderLayer.
30
46
3. We have an optional token_map which reduces draft vocab to most
@@ -46,10 +62,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
46
62
47
63
self .model = model_cls (vllm_config = vllm_config ,
48
64
prefix = maybe_prefix (prefix , "model" ))
65
+
49
66
self .fc = nn .Linear (config .model .hidden_size * 2 ,
50
67
config .model .hidden_size ,
51
68
bias = getattr (self .config , "eagle_fc_bias" , False ))
52
69
70
+ # Modify layer normalization and residual connections as suggested
71
+ # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
72
+ self .model .model .layers [0 ].input_layernorm = DummyInputLayerNorm ()
73
+ self .model .model .norm = DummyOutputNorm ()
74
+
53
75
self .orig_vocab_size = config .vocab_size
54
76
self .truncated_vocab_size = config .truncated_vocab_size
55
77
self .unpadded_vocab_size = self .truncated_vocab_size
0 commit comments