Skip to content

Commit 57e87e4

Browse files
simon-mo, lulmer
authored and committed
[bugfix] spec decode worker get tp group only when initialized (vllm-project#13578)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 8d2dca1 commit 57e87e4

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

vllm/spec_decode/spec_decode_worker.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from vllm.distributed.communication_op import (broadcast_tensor_dict,
1313
get_tp_group,
1414
tensor_model_parallel_gather)
15+
from vllm.distributed.parallel_state import model_parallel_is_initialized
1516
from vllm.logger import init_logger
1617
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
1718
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -366,8 +367,12 @@ def init_device(self) -> None:
366367
target_lm_head_weight)
367368

368369
self._metrics.init_tensors(self.rank, device_type=self.device)
369-
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
370-
device_type=self.device)
370+
if model_parallel_is_initialized():
371+
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
372+
device_type=self.device)
373+
else:
374+
self.spec_decode_sampler.init_tensors(self.rank,
375+
device_type=self.device)
371376

372377
scorer_cls: Type[SpeculativeScorer]
373378
if self.disable_mqa_scorer:

0 commit comments

Comments
 (0)