Skip to content

Commit b610861

Browse files
DarkLight1337Mu Huai
authored and
Mu Huai
committed
[Bugfix] Fix hybrid model tests (vllm-project#17182)
Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent c934aa4 commit b610861

File tree

3 files changed

+158
-534
lines changed

3 files changed

+158
-534
lines changed

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,10 @@ def _hidden_states_to_seq_logprobs(
531531
for _, hidden_state in enumerate(hidden_states):
532532
last_hidden_states = hidden_state[-1][0]
533533
logits = torch.matmul(
534-
last_hidden_states.to(output_embeddings.weight.device),
534+
last_hidden_states.to(
535+
device=output_embeddings.weight.device,
536+
dtype=output_embeddings.weight.dtype,
537+
),
535538
output_embeddings.weight.t(),
536539
)
537540
if getattr(output_embeddings, "bias", None) is not None:

0 commit comments

Comments
 (0)