@@ -153,6 +153,30 @@ def _initialize_model(
         return model_class(**kwargs)
 
 
+def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
+                                   target_device: torch.device) -> None:
+    for _, module in model.named_modules():
+        quant_method = getattr(module, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens after other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+                hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
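The CPU-offloading comment in the new helper is doing real work: `process_weights_after_loading` may repack or quantize tensors and expects them on the target device, so offloaded parameters must be moved there for processing and back afterwards. A minimal sketch of what a context manager like `device_loading_context` could look like (an illustrative simplification, not vLLM's actual implementation; the `_sketch` name is hypothetical):

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def device_loading_context_sketch(module: nn.Module,
                                  target_device: torch.device):
    """Hypothetical simplification of device_loading_context.

    Temporarily moves CPU-offloaded parameters onto target_device so a
    quant method can repack/quantize them there, then moves them back.
    """
    if target_device.type == "cpu":
        yield module  # nothing to move when the target itself is the CPU
        return

    offloaded = []
    for name, param in module.named_parameters():
        if param.device.type == "cpu":
            offloaded.append(name)
            param.data = param.data.to(target_device)
    try:
        yield module
    finally:
        # Move the offloaded parameters back off the device afterwards.
        for name, param in module.named_parameters():
            if name in offloaded:
                param.data = param.data.to("cpu")
```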
@@ -376,7 +400,6 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -394,23 +417,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
+
 
         return model.eval()
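Both the old `with torch.device(device_config.device):` and the new `with target_device:` rely on the same PyTorch 2.x behavior: a `torch.device` object works as a context manager that sets the default device for tensor and parameter construction inside it, which is what lets `_initialize_model` allocate weights directly on the target device. A self-contained example:

```python
import torch
import torch.nn as nn

# PyTorch >= 2.0: a torch.device is usable as a context manager that
# sets the default device for everything constructed inside it.
target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with target_device:
    layer = nn.Linear(16, 16)  # parameters allocated directly on target_device

print(layer.weight.device)  # cuda:0 when CUDA is available, else cpu
```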
@@ -429,29 +437,15 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
                 # NOTE(woosuk): For accurate performance evaluation, we assign
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(
-                            module, torch.device(device_config.device)):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
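`initialize_dummy_weights` is what makes the dummy loader useful for benchmarking: per the NOTE, weights get small random values rather than zeros, so kernels see realistic numerics without reading a checkpoint from disk. Roughly, the idea is something like the following (a sketch only; the `_sketch` name and the value range are assumptions, not vLLM's exact code, which also handles quantized dtypes):

```python
import torch
import torch.nn as nn


def initialize_dummy_weights_sketch(model: nn.Module,
                                    low: float = -1e-3,
                                    high: float = 1e-3) -> None:
    """Fill floating-point parameters with small uniform random values."""
    with torch.no_grad():
        for param in model.parameters():
            if torch.is_floating_point(param):
                param.uniform_(low, high)
```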
@@ -632,6 +626,7 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
@@ -640,18 +635,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                                         model_config.revision)
 
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                            hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(
-                            model_config.dtype)
+                _process_weights_after_loading(model, model_config,
+                                               target_device)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1401,16 +1388,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 self._get_weights_iterator(model_weights,
                                            model_config.revision))
 
-        for _, module in model.named_modules():
-            quant_method = getattr(module, "quant_method", None)
-            if quant_method is not None:
-                with device_loading_context(module, target_device):
-                    quant_method.process_weights_after_loading(module)
-            if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                # When attention modules need to process weights after
-                # currently only used by MLA
-                module.process_weights_after_loading(model_config.dtype)
+        _process_weights_after_loading(model, model_config, target_device)
 
         return model.eval()