@@ -166,6 +166,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
166
166
# Initialize buffers (e.g. rotary embedding inverse frequency)
167
167
self .init_buffers (self .model )
168
168
169
+ # Initialize parameters
170
+ self .init_parameters (self .model )
171
+
169
172
# Move remaining meta tensors to device (should happen last)
170
173
self .meta_to_empty (self .model )
171
174
@@ -298,6 +301,25 @@ def init_buffers(self, module: nn.Module):
298
301
for child in module .children ():
299
302
self .init_buffers (child )
300
303
304
def init_parameters(self, module: nn.Module):
    """
    Materialize any parameters of ``module`` that still live on the
    ``meta`` device.

    If a parameter is on the ``meta`` device, then its parent ``module``
    is the original module created by:

    ```python
    with torch.device("meta"):
        self.model: PreTrainedModel = AutoModel.from_config(...)
    ```

    Each meta parameter is replaced by an *uninitialized*
    (``torch.empty_like``) tensor on ``self.device_config.device``; the
    real values are filled in later by the weight loader, so no
    initialization scheme is applied here. Recurses into all children.
    """
    # Only direct parameters: children are handled by the explicit
    # recursion below so every module level gets the same check.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            new_param = nn.Parameter(
                torch.empty_like(param.data,
                                 device=self.device_config.device),
                # Fix: preserve the requires_grad flag. nn.Parameter()
                # defaults to requires_grad=True, which would silently
                # un-freeze parameters that were frozen on the meta
                # module.
                requires_grad=param.requires_grad)
            setattr(module, name, new_param)
    for child in module.children():
        self.init_parameters(child)
301
323
def meta_to_empty (self , module : nn .Module ):
302
324
tensors = list (chain (module .buffers (), module .parameters ()))
303
325
if tensors and all (t .device == torch .device ("meta" ) for t in tensors ):
@@ -342,6 +364,7 @@ def forward(
342
364
def load_weights (self , weights : Iterable [tuple [str ,
343
365
torch .Tensor ]]) -> set [str ]:
344
366
params_dict = dict (self .named_parameters ())
367
+
345
368
loaded_params = set [str ]()
346
369
for name , loaded_weight in weights :
347
370
# Use "model" instead of base_model_prefix because
0 commit comments