Skip to content

Commit c5b4b11

Browse files
authored
[Bugfix] Fix k_proj's bias for whisper self attention (#12342)
Signed-off-by: Isotr0py <[email protected]>
1 parent 8ae5ff2 commit c5b4b11

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

vllm/model_executor/models/whisper.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,22 @@ def sample(
729729
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint weights into this model.

    Remaps HF's ``.fc1.``/``.fc2.`` names onto the local ``.mlp.`` modules,
    skips the tied ``proj_out.`` head, and augments the stream with zero
    biases for self-attention k_proj (the checkpoint ships k_proj without a
    bias, but the fused qkv_proj expects one per projection).

    Returns the set of parameter names that were loaded.
    """
    name_mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
    weight_loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
    # Inject fake zero biases for k_proj so qkv_proj's k bias slice
    # initializes to zeros instead of being reported missing.
    augmented_weights = _create_fake_bias_for_k_proj(weights)
    return weight_loader.load_weights(augmented_weights, mapper=name_mapper)
736+
737+
738+
def _create_fake_bias_for_k_proj(
739+
weights: Iterable[Tuple[str, torch.Tensor]]
740+
) -> Iterable[Tuple[str, torch.Tensor]]:
741+
"""
742+
Create full zeros bias for k_proj weight in self-attention layers.
743+
So that the bias for k_proj in qkv_proj can be initialized with zeros.
744+
"""
745+
for name, weight in weights:
746+
if ".self_attn.k_proj.weight" in name:
747+
bias = torch.zeros(weight.size(0))
748+
bias_name = name.replace("weight", "bias")
749+
yield from [(name, weight), (bias_name, bias)]
750+
yield name, weight

0 commit comments

Comments
 (0)