Skip to content

Commit b152831

Browse files
DarkLight1337 and lulmer
authored and committed
[Bugfix] Fix 2 Node and Spec Decode tests (vllm-project#13341)
Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: Louis Ulmer <[email protected]>
1 parent fc92cb4 commit b152831

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

tests/distributed/test_pipeline_parallel.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,11 @@ def _compare_tp(
275275
if load_format == "dummy":
276276
# Avoid OOM
277277
text_overrides = {
278-
"num_layers": 1,
279-
"num_hidden_layers": 1,
280-
"num_experts": 2,
281-
"num_experts_per_tok": 2,
282-
"num_local_experts": 2,
278+
"num_hidden_layers": 4,
279+
"hidden_size": 512,
280+
"intermediate_size": 800,
281+
"num_attention_heads": 4,
282+
"num_key_value_heads": 1,
283283
}
284284

285285
if is_multimodal:

vllm/spec_decode/ngram_worker.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
import torch.nn as nn
88

9+
from vllm.config import VllmConfig
910
from vllm.model_executor.layers.sampler import SamplerOutput
1011
from vllm.sequence import ExecuteModelRequest
1112
from vllm.spec_decode.interfaces import SpeculativeProposals
@@ -25,11 +26,18 @@ class NGramWorker(NonLLMProposerWorkerBase):
2526
which don't rely on LLM model to give proposals.
2627
"""
2728

28-
def __init__(self, *args, **kwargs):
29+
def __init__(
30+
self,
31+
vllm_config: VllmConfig,
32+
local_rank: int,
33+
device_type: str = "cuda",
34+
**kwargs,
35+
):
36+
super().__init__(vllm_config)
37+
2938
# Get local_rank/vocab_size from kwargs attribute
30-
self.local_rank = kwargs["local_rank"]
31-
self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
32-
self.device_type = kwargs.get("device_type", "cuda")
39+
self.local_rank = local_rank
40+
self.device_type = device_type
3341

3442
# Lazy initialization list.
3543
self._proposer: Top1Proposer

0 commit comments

Comments (0)