File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change 12
12
from vllm .distributed .communication_op import (broadcast_tensor_dict ,
13
13
get_tp_group ,
14
14
tensor_model_parallel_gather )
15
+ from vllm .distributed .parallel_state import model_parallel_is_initialized
15
16
from vllm .logger import init_logger
16
17
from vllm .model_executor .layers .rejection_sampler import RejectionSampler
17
18
from vllm .model_executor .layers .sampler import SamplerOutput
@@ -366,8 +367,12 @@ def init_device(self) -> None:
366
367
target_lm_head_weight )
367
368
368
369
self ._metrics .init_tensors (self .rank , device_type = self .device )
369
- self .spec_decode_sampler .init_tensors (get_tp_group ().local_rank ,
370
- device_type = self .device )
370
+ if model_parallel_is_initialized ():
371
+ self .spec_decode_sampler .init_tensors (get_tp_group ().local_rank ,
372
+ device_type = self .device )
373
+ else :
374
+ self .spec_decode_sampler .init_tensors (self .rank ,
375
+ device_type = self .device )
371
376
372
377
scorer_cls : Type [SpeculativeScorer ]
373
378
if self .disable_mqa_scorer :
You can’t perform that action at this time.
0 commit comments