Commit fdb5374

WoosukKwon authored and wuisawesome committed
[Minor][Models] Fix Return Types of Llama & Eagle (vllm-project#17220)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 0ed4e02 commit fdb5374

File tree: 3 files changed, +6 −5 lines

vllm/model_executor/models/llama.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -345,7 +345,8 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                        list[torch.Tensor]]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
```
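The widened annotation records that `LlamaModel.forward` can also return a pair of (final hidden states, list of auxiliary hidden states), alongside the existing tensor and `IntermediateTensors` cases. A minimal, self-contained sketch of how a caller might dispatch on that union; the helper name and the placeholder `IntermediateTensors` class are illustrative assumptions, not vLLM code:

```python
from typing import Union

import torch


class IntermediateTensors:
    """Placeholder standing in for vLLM's IntermediateTensors."""


def unpack_llama_output(
    output: Union[torch.Tensor, IntermediateTensors,
                  tuple[torch.Tensor, list[torch.Tensor]]],
) -> torch.Tensor:
    # Tuple case covered by the new annotation: final hidden states plus
    # auxiliary hidden states collected from intermediate layers.
    if isinstance(output, tuple):
        hidden_states, aux_hidden_states = output
        return hidden_states
    # Non-last pipeline-parallel rank: tensors destined for the next rank.
    if isinstance(output, IntermediateTensors):
        raise ValueError("intermediate PP rank output; nothing to unpack")
    # Plain case: the final hidden states themselves.
    return output
```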

vllm/model_executor/models/llama_eagle.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -70,7 +70,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
@@ -133,7 +133,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
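Both `forward` methods in llama_eagle.py already returned two tensors at runtime; only the annotation was stale. A hedged sketch of what the corrected signature buys a caller (the function and variable names are assumptions for illustration):

```python
import torch


def consume_draft_output(
    draft_out: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    # With the corrected tuple annotation, this unpacking type-checks
    # cleanly; under the old `-> torch.Tensor` annotation a checker such
    # as mypy would flag it even though it worked at runtime.
    hidden_states, extra_hidden_states = draft_out
    return hidden_states
```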

vllm/model_executor/models/llama_eagle3.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -117,7 +117,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
         if (hidden_states.shape[-1] != input_embeds.shape[-1]):
             hidden_states = self.fc(hidden_states)
@@ -194,7 +194,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)

     def compute_logits(
```
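The llama_eagle3.py context also shows Eagle-3's dimension-matching step: when the target model's hidden states are wider than the draft's embeddings, they are first projected through `self.fc`. A runnable sketch of that pattern under made-up sizes; the class name, dimensions, and the additive combination step are assumptions, not the vLLM implementation:

```python
import torch
import torch.nn as nn


class Eagle3CombineSketch(nn.Module):
    def __init__(self, target_dim: int = 8192, draft_dim: int = 4096) -> None:
        super().__init__()
        # Projects target-model hidden states down to the draft width.
        self.fc = nn.Linear(target_dim, draft_dim, bias=False)

    def forward(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Project only when the widths differ, mirroring the diff above.
        if hidden_states.shape[-1] != input_embeds.shape[-1]:
            hidden_states = self.fc(hidden_states)
        combined = input_embeds + hidden_states  # assumed combination step
        return combined, hidden_states


# Usage: both return values are tensors, matching the corrected annotation.
sketch = Eagle3CombineSketch()
embeds = torch.randn(2, 4096)
target_hidden = torch.randn(2, 8192)
out, projected = sketch(embeds, target_hidden)
```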
