diff --git a/vllm/sequence.py b/vllm/sequence.py
index 98578ee04d5..45d0e5bc768 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1137,6 +1137,9 @@ def __getitem__(self, key: Union[str, slice]):
     def __setitem__(self, key: str, value: torch.Tensor):
         self.tensors[key] = value
 
+    def items(self):
+        return self.tensors.items()
+
     def __len__(self):
         return len(self.tensors)
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f1212c3554b..6acad65e1c9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -151,7 +151,8 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
-        # self.intermediate_tensors # Set after load_model
+        # None in the first PP rank. The rest are set after load_model.
+        self.intermediate_tensors: Optional[IntermediateTensors] = None
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -925,6 +926,11 @@ def execute_model(
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
+            assert intermediate_tensors is not None
+            assert self.intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                self.intermediate_tensors[k][:num_input_tokens].copy_(
+                    v[:num_input_tokens], non_blocking=True)
             intermediate_tensors = IntermediateTensors({
                 k: v[:num_input_tokens]
                 for k, v in self.intermediate_tensors.items()
@@ -1135,7 +1141,7 @@ def _dummy_run(
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
-            if not hasattr(self, "intermediate_tensors"):
+            if self.intermediate_tensors is None:
                 self.intermediate_tensors = (
                     self.model.make_empty_intermediate_tensors(
                         batch_size=self.max_num_tokens,