@@ -600,8 +600,10 @@ class ShardedStateLoader(BaseModelLoader):
600
600
601
601
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
602
602
603
- def __init__ (self , load_config : LoadConfig ):
603
+ def __init__ (self , load_config : LoadConfig , runai_model_streamer : bool = False ):
604
604
super ().__init__ (load_config )
605
+
606
+ self .runai_model_streamer = runai_model_streamer
605
607
extra_config = ({} if load_config .model_loader_extra_config is None
606
608
else load_config .model_loader_extra_config .copy ())
607
609
self .pattern = extra_config .pop ("pattern" , self .DEFAULT_PATTERN )
@@ -648,7 +650,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int:
648
650
649
651
def _prepare_weights (self , model_name_or_path : str ,
650
652
revision : Optional [str ]):
651
- if os .path .isdir (model_name_or_path ):
653
+ if is_s3 ( model_name_or_path ) or os .path .isdir (model_name_or_path ):
652
654
return model_name_or_path
653
655
else :
654
656
allow_patterns = ["*.safetensors" ]
@@ -667,12 +669,14 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
667
669
device_config = vllm_config .device_config
668
670
model_config = vllm_config .model_config
669
671
target_device = torch .device (device_config .device )
670
- from safetensors .torch import safe_open
671
672
672
673
from vllm .distributed import get_tensor_model_parallel_rank
673
674
674
- local_model_path = self ._prepare_weights (model_config .model ,
675
- model_config .revision )
675
+
676
+ model_weights = model_config .model
677
+ if hasattr (model_config , "model_weights" ):
678
+ model_weights = model_config .model_weights
679
+ local_model_path = model_weights
676
680
677
681
with set_default_torch_dtype (model_config .dtype ):
678
682
with target_device :
@@ -684,40 +688,54 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
684
688
local_model_path ,
685
689
self .pattern .format (rank = rank , part = "*" ),
686
690
)
687
- filepaths = glob .glob (pattern )
691
+
692
+ filepaths = []
693
+ if is_s3 (local_model_path ):
694
+ file_pattern = f"*{ self .pattern .format (rank = rank , part = "*" )} "
695
+ filepaths = s3_glob (path = local_model_path , allow_pattern = [file_pattern ])
696
+ else :
697
+ filepaths = glob .glob (pattern )
688
698
if not filepaths :
689
699
# TODO: support un-sharded checkpoints too
690
700
raise ValueError (
691
701
f"Could not find checkpoint files '{ pattern } ', only "
692
702
f"pre-sharded checkpoints are currently supported!" )
693
703
state_dict = self ._filter_subtensors (model .state_dict ())
694
- for path in filepaths :
695
- with safe_open (path , framework = "pt" ) as f :
696
- for key in f .keys (): # noqa: SIM118
697
- tensor = f .get_tensor (key )
698
- # If loading with LoRA enabled, additional padding may
699
- # be added to certain parameters. We only load into a
700
- # narrowed view of the parameter data.
701
- param_data = state_dict [key ].data
702
- param_shape = state_dict [key ].shape
703
- for dim , size in enumerate (tensor .shape ):
704
- if size < param_shape [dim ]:
705
- param_data = param_data .narrow (dim , 0 , size )
706
- if tensor .shape != param_shape :
707
- logger .warning (
708
- "loading tensor of shape %s into "
709
- "parameter '%s' of shape %s" ,
710
- tensor .shape ,
711
- key ,
712
- param_shape ,
713
- )
714
- param_data .copy_ (tensor )
715
- state_dict .pop (key )
704
+ for key , tensor in self .iterate_over_files (filepaths ):
705
+ # If loading with LoRA enabled, additional padding may
706
+ # be added to certain parameters. We only load into a
707
+ # narrowed view of the parameter data.
708
+ param_data = state_dict [key ].data
709
+ param_shape = state_dict [key ].shape
710
+ for dim , size in enumerate (tensor .shape ):
711
+ if size < param_shape [dim ]:
712
+ param_data = param_data .narrow (dim , 0 , size )
713
+ if tensor .shape != param_shape :
714
+ logger .warning (
715
+ "loading tensor of shape %s into "
716
+ "parameter '%s' of shape %s" ,
717
+ tensor .shape ,
718
+ key ,
719
+ param_shape ,
720
+ )
721
+ param_data .copy_ (tensor )
722
+ state_dict .pop (key )
716
723
if state_dict :
717
724
raise ValueError (
718
725
f"Missing keys { tuple (state_dict )} in loaded state!" )
719
726
return model .eval ()
720
727
728
def iterate_over_files(
        self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(key, tensor)`` pairs read from the given checkpoint files.

    Args:
        paths: Iterable of checkpoint file locations. It is either handed
            as-is to ``runai_safetensors_weights_iterator`` or walked one
            entry at a time and opened with ``safe_open``.

    Yields:
        A ``(key, tensor)`` tuple for every entry produced by the selected
        reader.
    """
    if self.runai_model_streamer:
        # Streaming path: delegate entirely to the Run:ai iterator.
        # NOTE(review): the positional ``True`` is forwarded unchanged; its
        # meaning is defined by the helper's signature -- confirm there.
        yield from runai_safetensors_weights_iterator(paths, True)
        return

    # Imported lazily so it is only required on the plain safetensors path.
    from safetensors.torch import safe_open

    for checkpoint_file in paths:
        with safe_open(checkpoint_file, framework="pt") as handle:
            for name in handle.keys():  # noqa: SIM118
                yield name, handle.get_tensor(name)
738
+
721
739
@staticmethod
722
740
def save_model (
723
741
model : torch .nn .Module ,
@@ -1504,4 +1522,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
1504
1522
if load_config .load_format == LoadFormat .RUNAI_STREAMER :
1505
1523
return RunaiModelStreamerLoader (load_config )
1506
1524
1525
+ if load_config .load_format == LoadFormat .RUNAI_STREAMER_SHARDED :
1526
+ return ShardedStateLoader (load_config , True )
1527
+
1507
1528
return DefaultModelLoader (load_config )
0 commit comments