
Commit d6070fd

wuisawesome authored and Mu Huai committed
Support loading transformers models with named parameters (vllm-project#16868)
Signed-off-by: Alex <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent c5d3367 commit d6070fd

File tree

1 file changed: +23 -0 lines changed


vllm/model_executor/models/transformers.py

Lines changed: 23 additions & 0 deletions

@@ -166,6 +166,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)
 
+        # Initialize parameters
+        self.init_parameters(self.model)
+
         # Move remaining meta tensors to device (should happen last)
         self.meta_to_empty(self.model)
 
@@ -298,6 +301,25 @@ def init_buffers(self, module: nn.Module):
         for child in module.children():
             self.init_buffers(child)
 
+    def init_parameters(self, module: nn.Module):
+        """
+        If a `parameter` is on the `meta` device, then its parent
+        `module` is the original module created by:
+
+        ```python
+        with torch.device("meta"):
+            self.model: PreTrainedModel = AutoModel.from_config(...)
+        ```
+        """
+        for name, param in module.named_parameters(recurse=False):
+            if param.device == torch.device("meta"):
+                new_param = nn.Parameter(
+                    torch.empty_like(param.data,
+                                     device=self.device_config.device))
+                setattr(module, name, new_param)
+        for child in module.children():
+            self.init_parameters(child)
+
     def meta_to_empty(self, module: nn.Module):
         tensors = list(chain(module.buffers(), module.parameters()))
         if tensors and all(t.device == torch.device("meta") for t in tensors):
@@ -342,6 +364,7 @@ def forward(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
+
         loaded_params = set[str]()
         for name, loaded_weight in weights:
             # Use "model" instead of base_model_prefix because
