Commit e0d0c5e

2015aroras authored and Yuqi Zhang committed
[Bugfix][Model] Make Olmo2Model weight loading return loaded weights (vllm-project#18504)
Signed-off-by: Shane A <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
1 parent 3770c65 commit e0d0c5e

1 file changed, +5 -1 lines changed

vllm/model_executor/models/olmo2.py

Lines changed: 5 additions & 1 deletion
@@ -314,7 +314,8 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -325,6 +326,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
             if is_pp_missing_parameter(name, self):
                 continue
@@ -347,6 +349,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class Olmo2ForCausalLM(nn.Module, SupportsPP):
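
The change above makes `load_weights` report which parameter names it actually loaded. As a rough illustration of why a caller might want that return value, here is a minimal, framework-free sketch. It is not vLLM code: the `params` dictionary of plain lists, the `find_unloaded` helper, and the skip-unknown-names policy are all assumptions made for the example.

```python
from collections.abc import Iterable


def load_weights(params: dict[str, list[float]],
                 weights: Iterable[tuple[str, list[float]]]) -> set[str]:
    """Copy checkpoint entries into params and return the names that were loaded."""
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if name not in params:
            # Assumed policy for this sketch: skip entries the model does not own.
            continue
        params[name][:] = loaded_weight  # in-place copy, standing in for a weight loader
        loaded_params.add(name)
    return loaded_params


def find_unloaded(params: dict[str, list[float]], loaded: set[str]) -> set[str]:
    """Hypothetical caller-side check: parameters the loader never touched."""
    return set(params) - loaded


# Toy model with two parameters; the checkpoint covers only one of them.
params = {"embed.weight": [0.0, 0.0], "norm.weight": [0.0]}
checkpoint = [("embed.weight", [1.0, 2.0]), ("extra.bias", [9.0])]

loaded = load_weights(params, checkpoint)
missing = find_unloaded(params, loaded)
```

With a return value of `None`, as before this commit, a caller has no way to perform this kind of comparison; with a `set[str]`, parameters left uninitialized can be detected by a simple set difference.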
