From d9420dd1f93bf8cf06148666cb538ab4df6a35fd Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 1 Jan 2025 16:07:03 +0000
Subject: [PATCH 1/9] fix: disable input layernorm and output norm

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index f138d136302..f625c91887f 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -17,6 +17,14 @@
 from .utils import maybe_prefix
 
 
+class DummyInputLayerNorm(nn.Module):
+    def forward(self, x):
+        return x
+
+class DummyNorm(nn.Module):
+    def forward(self, x, y):
+        return x, None
+
 class EAGLE(nn.Module):
     """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
     Reference implementation: https://github.com/SafeAILab/EAGLE
@@ -46,6 +54,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = model_cls(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
 
+        for layer in self.model.model.layers:
+            layer.input_layernorm = DummyInputLayerNorm()
+        self.model.model.norm = DummyNorm()
+
         self.fc = nn.Linear(config.model.hidden_size * 2,
                             config.model.hidden_size,
                             bias=getattr(self.config, "eagle_fc_bias", False))

From 60f863e4dbd665881c82ef61edf9d59ac04f6011 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 1 Jan 2025 16:07:23 +0000
Subject: [PATCH 2/9] fix: add residual path

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/llama.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 2902e6999c2..dbb3620f001 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -315,6 +315,12 @@ def __init__(self,
             )
         else:
             self.embed_tokens = PPMissingLayer()
+
+        if config.num_hidden_layers==1:
+            self.eagle = True
+        else:
+            self.eagle = False
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: layer_type(config=config,
@@ -368,6 +374,8 @@ def forward(
                 })
 
         hidden_states, _ = self.norm(hidden_states, residual)
+        if self.eagle:
+            hidden_states = residual + hidden_states
         return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str,
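
Note (annotation, not part of the patch series): patch 1 wires EAGLE's core
trick into vLLM: the draft model concatenates its token embeddings with the
target model's hidden states and projects the pair back to hidden_size
through the new fc layer, while the dummy modules disable the first decoder
layer's input RMSNorm and the final output norm. A minimal, self-contained
sketch of the fc fusion (sizes are illustrative, not from a real config):

    import torch
    import torch.nn as nn

    hidden_size = 8
    # Mirrors nn.Linear(config.model.hidden_size * 2,
    #                   config.model.hidden_size) from patch 1.
    fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    inputs_embeds = torch.randn(1, 4, hidden_size)  # draft token embeddings
    prev_hidden = torch.randn(1, 4, hidden_size)    # target-model hidden states
    fused = fc(torch.cat([inputs_embeds, prev_hidden], dim=-1))
    print(fused.shape)  # torch.Size([1, 4, 8])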
From c4358517fe78e60c35786e828fb2a022885afa66 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 1 Jan 2025 17:06:23 +0000
Subject: [PATCH 3/9] remove modification on llama model

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 15 +++++++++------
 vllm/model_executor/models/llama.py |  8 --------
 2 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index f625c91887f..dc1dcde3953 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -21,9 +21,10 @@ class DummyInputLayerNorm(nn.Module):
     def forward(self, x):
         return x
 
-class DummyNorm(nn.Module):
-    def forward(self, x, y):
-        return x, None
+class DummOutputNorm(nn.Module):
+    def forward(self, x, residual):
+        x = x + residual
+        return x, residual
 
 class EAGLE(nn.Module):
     """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
@@ -54,14 +55,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = model_cls(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
 
-        for layer in self.model.model.layers:
-            layer.input_layernorm = DummyInputLayerNorm()
-        self.model.model.norm = DummyNorm()
-
         self.fc = nn.Linear(config.model.hidden_size * 2,
                             config.model.hidden_size,
                             bias=getattr(self.config, "eagle_fc_bias", False))
 
+        # Modify layer normalization and residual connections as suggested
+        # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
+        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
+        self.model.model.norm = DummOutputNorm()
+
         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
         self.unpadded_vocab_size = self.truncated_vocab_size
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index dbb3620f001..2902e6999c2 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -315,12 +315,6 @@ def __init__(self,
             )
         else:
             self.embed_tokens = PPMissingLayer()
-
-        if config.num_hidden_layers==1:
-            self.eagle = True
-        else:
-            self.eagle = False
-
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: layer_type(config=config,
@@ -374,8 +368,6 @@ def forward(
                 })
 
         hidden_states, _ = self.norm(hidden_states, residual)
-        if self.eagle:
-            hidden_states = residual + hidden_states
         return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str,

From aa183ff010e354fb03ce27dc5700f51cb73831c3 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 1 Jan 2025 17:07:46 +0000
Subject: [PATCH 4/9] make format

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index dc1dcde3953..11e38655900 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -18,14 +18,18 @@
 
 
 class DummyInputLayerNorm(nn.Module):
+
     def forward(self, x):
         return x
 
+
 class DummOutputNorm(nn.Module):
+
     def forward(self, x, residual):
         x = x + residual
         return x, residual
 
+
 class EAGLE(nn.Module):
     """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
     Reference implementation: https://github.com/SafeAILab/EAGLE
@@ -60,7 +64,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                             config.model.hidden_size,
                             bias=getattr(self.config, "eagle_fc_bias", False))
 
-        # Modify layer normalization and residual connections as suggested 
+        # Modify layer normalization and residual connections as suggested
         # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
         self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
         self.model.model.norm = DummOutputNorm()
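
Note (annotation, not part of the patch series): patch 3 reverts the llama.py
change and instead swaps modules on the already-built model, so LlamaModel
stays EAGLE-agnostic. The output replacement must keep the (output, residual)
tuple shape, because the caller in llama.py unpacks
`hidden_states, _ = self.norm(hidden_states, residual)`. A standalone
rendering of the module added in patch 3 (class name kept with its typo,
which patch 5 fixes):

    import torch
    import torch.nn as nn

    class DummOutputNorm(nn.Module):
        # Skips RMSNorm but keeps the tuple contract; the caller reads
        # only the first element, which already contains the residual sum.
        def forward(self, x, residual):
            x = x + residual
            return x, residual

    norm = DummOutputNorm()
    out, _ = norm(torch.ones(2, 4), torch.full((2, 4), 2.0))
    print(out)  # all elements 3.0: residual folded in, nothing normalized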
From f2751c80c74f066977cb0169e9fd993eaaa0de26 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Fri, 3 Jan 2025 06:26:33 +0000
Subject: [PATCH 5/9] fix comment and typo

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 9 +++++----
 vllm/spec_decode/metrics.py         | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index 11e38655900..b7ce95152bb 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -23,7 +23,7 @@ def forward(self, x):
         return x
 
 
-class DummOutputNorm(nn.Module):
+class DummyOutputNorm(nn.Module):
 
     def forward(self, x, residual):
         x = x + residual
@@ -36,8 +36,9 @@ class EAGLE(nn.Module):
 
     Differences from reference implementation:
     1. In reference, LlamaDecoderLayer implementation doesn't have
-       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
-       but we do as HF implementation also does.
+       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
+       Following this approach, our implementation also disables
+       the input_layernormfor the first decoder layer.
     2. We allow any decoder layer to be used in EAGLE whereas in reference
        decoder layer is fixed to be LlamaDecoderLayer.
     3. We have an optional token_map which reduces draft vocab to most
@@ -67,7 +68,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Modify layer normalization and residual connections as suggested
         # in the EAGLE framework: https://github.com/SafeAILab/EAGLE
         self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
-        self.model.model.norm = DummOutputNorm()
+        self.model.model.norm = DummyOutputNorm()
 
         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index d678f457849..37f372b4250 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -117,7 +117,7 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool:
         if self._rank != 0:
             return False
 
-        return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s  # noqa: E501
+        return now - self._last_metrics_collect_time >= 0.1  # noqa: E501
 
     def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         """Copy rejection/typical-acceptance sampling metrics

From 300f58cf84f0dc4d3b18b2af57d6dd2852608c67 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Fri, 3 Jan 2025 06:42:14 +0000
Subject: [PATCH 6/9] fix typo

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index b7ce95152bb..a2a3c0e3345 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -38,7 +38,7 @@ class EAGLE(nn.Module):
     1. In reference, LlamaDecoderLayer implementation doesn't have
        input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
        Following this approach, our implementation also disables
-       the input_layernormfor the first decoder layer.
+       the input_layernorm for the first decoder layer.
     2. We allow any decoder layer to be used in EAGLE whereas in reference
        decoder layer is fixed to be LlamaDecoderLayer.
     3. We have an optional token_map which reduces draft vocab to most
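
Note (annotation, not part of the patch series): point 3 of the docstring
above refers to vocab truncation: the draft scores only the
truncated_vocab_size most frequent tokens, and a token_map carries those ids
back into the full target vocabulary. A hedged sketch of the idea (the
tensor names and the -inf fill are illustrative, not the exact vLLM code):

    import torch

    vocab_size, truncated_vocab_size = 10, 4
    # Hypothetical map: draft (truncated) vocab id -> target vocab id.
    token_map = torch.tensor([2, 5, 7, 9])

    draft_logits = torch.randn(1, truncated_vocab_size)
    # Unmapped tokens stay at -inf so the draft never proposes them.
    full_logits = torch.full((1, vocab_size), float("-inf"))
    full_logits[:, token_map] = draft_logits
    print(full_logits.argmax(dim=-1))  # always one of token_map's ids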
From 0c8d357d2497830b82b735682312ba6d40f3f052 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Fri, 3 Jan 2025 06:44:25 +0000
Subject: [PATCH 7/9] revert updating metric part

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/spec_decode/metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index 37f372b4250..d678f457849 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -117,7 +117,7 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool:
         if self._rank != 0:
             return False
 
-        return now - self._last_metrics_collect_time >= 0.1  # noqa: E501
+        return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s  # noqa: E501
 
     def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         """Copy rejection/typical-acceptance sampling metrics

From e25a37e4a4e8440f1c4ba183d23d184ff4bdcbaf Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 8 Jan 2025 03:53:21 +0000
Subject: [PATCH 8/9] add condition before residual operation

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index a2a3c0e3345..94b283d60b1 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -26,7 +26,8 @@ def forward(self, x):
 class DummyOutputNorm(nn.Module):
 
     def forward(self, x, residual):
-        x = x + residual
+        if residual is not None:
+            x = x + residual
         return x, residual

From 21ef7ee1941ebef1ef7e0cbe549d85b5aa168553 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Sat, 11 Jan 2025 01:11:57 +0000
Subject: [PATCH 9/9] fix: DummyOutputNorm

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/model_executor/models/eagle.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index 94b283d60b1..eb7b5af19ae 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -26,9 +26,10 @@ def forward(self, x):
 class DummyOutputNorm(nn.Module):
 
     def forward(self, x, residual):
-        if residual is not None:
-            x = x + residual
-        return x, residual
+        if residual is None:
+            return x
+        else:
+            return x, residual
 
 
 class EAGLE(nn.Module):
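
Note (annotation, not part of the patch series): patch 9 makes
DummyOutputNorm safe for both call patterns: with residual=None it behaves
as a plain identity, and with a residual tensor it returns the (x, residual)
pair exactly as committed above, preserving the tuple shape that
tuple-unpacking callers expect. A standalone rendering for quick
verification:

    import torch
    import torch.nn as nn

    class DummyOutputNorm(nn.Module):
        # Final form after patch 9: identity when no residual is given,
        # otherwise return the pair unchanged.
        def forward(self, x, residual):
            if residual is None:
                return x
            else:
                return x, residual

    norm = DummyOutputNorm()
    print(norm(torch.ones(2), None))                # plain tensor
    out, res = norm(torch.ones(2), torch.zeros(2))  # tuple form
    print(out, res)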