@@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
611
611
612
612
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
613
613
614
- def __init__ (self , load_config : LoadConfig ):
614
+ def __init__ (self ,
615
+ load_config : LoadConfig ,
616
+ runai_model_streamer : bool = False ):
615
617
super ().__init__ (load_config )
618
+
619
+ self .runai_model_streamer = runai_model_streamer
616
620
extra_config = ({} if load_config .model_loader_extra_config is None
617
621
else load_config .model_loader_extra_config .copy ())
618
622
self .pattern = extra_config .pop ("pattern" , self .DEFAULT_PATTERN )
@@ -659,7 +663,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int:
659
663
660
664
def _prepare_weights (self , model_name_or_path : str ,
661
665
revision : Optional [str ]):
662
- if os .path .isdir (model_name_or_path ):
666
+ if is_s3 ( model_name_or_path ) or os .path .isdir (model_name_or_path ):
663
667
return model_name_or_path
664
668
else :
665
669
allow_patterns = ["*.safetensors" ]
@@ -678,12 +682,13 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
678
682
device_config = vllm_config .device_config
679
683
model_config = vllm_config .model_config
680
684
target_device = torch .device (device_config .device )
681
- from safetensors .torch import safe_open
682
685
683
686
from vllm .distributed import get_tensor_model_parallel_rank
684
687
685
- local_model_path = self ._prepare_weights (model_config .model ,
686
- model_config .revision )
688
+ model_weights = model_config .model
689
+ if hasattr (model_config , "model_weights" ):
690
+ model_weights = model_config .model_weights
691
+ local_model_path = model_weights
687
692
688
693
with set_default_torch_dtype (model_config .dtype ):
689
694
with target_device :
@@ -695,40 +700,56 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
695
700
local_model_path ,
696
701
self .pattern .format (rank = rank , part = "*" ),
697
702
)
698
- filepaths = glob .glob (pattern )
703
+
704
+            filepaths = []
+            if is_s3(local_model_path):
+                file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
+                filepaths = s3_glob(path=local_model_path,
+                                    allow_pattern=[file_pattern])
+            else:
+                filepaths = glob.glob(pattern)
709
+ else :
710
+ filepaths = glob .glob (pattern )
699
711
if not filepaths :
700
712
# TODO: support un-sharded checkpoints too
701
713
raise ValueError (
702
714
f"Could not find checkpoint files '{ pattern } ', only "
703
715
f"pre-sharded checkpoints are currently supported!" )
704
716
state_dict = self ._filter_subtensors (model .state_dict ())
705
- for path in filepaths :
706
- with safe_open (path , framework = "pt" ) as f :
707
- for key in f .keys (): # noqa: SIM118
708
- tensor = f .get_tensor (key )
709
- # If loading with LoRA enabled, additional padding may
710
- # be added to certain parameters. We only load into a
711
- # narrowed view of the parameter data.
712
- param_data = state_dict [key ].data
713
- param_shape = state_dict [key ].shape
714
- for dim , size in enumerate (tensor .shape ):
715
- if size < param_shape [dim ]:
716
- param_data = param_data .narrow (dim , 0 , size )
717
- if tensor .shape != param_shape :
718
- logger .warning (
719
- "loading tensor of shape %s into "
720
- "parameter '%s' of shape %s" ,
721
- tensor .shape ,
722
- key ,
723
- param_shape ,
724
- )
725
- param_data .copy_ (tensor )
726
- state_dict .pop (key )
717
+ for key , tensor in self .iterate_over_files (filepaths ):
718
+ # If loading with LoRA enabled, additional padding may
719
+ # be added to certain parameters. We only load into a
720
+ # narrowed view of the parameter data.
721
+ param_data = state_dict [key ].data
722
+ param_shape = state_dict [key ].shape
723
+ for dim , size in enumerate (tensor .shape ):
724
+ if size < param_shape [dim ]:
725
+ param_data = param_data .narrow (dim , 0 , size )
726
+ if tensor .shape != param_shape :
727
+ logger .warning (
728
+ "loading tensor of shape %s into "
729
+ "parameter '%s' of shape %s" ,
730
+ tensor .shape ,
731
+ key ,
732
+ param_shape ,
733
+ )
734
+ param_data .copy_ (tensor )
735
+ state_dict .pop (key )
727
736
if state_dict :
728
737
raise ValueError (
729
738
f"Missing keys { tuple (state_dict )} in loaded state!" )
730
739
return model .eval ()
731
740
741
+ def iterate_over_files (
742
+ self , paths ) -> Generator [Tuple [str , torch .Tensor ], None , None ]:
743
+ if self .runai_model_streamer :
744
+ yield from runai_safetensors_weights_iterator (paths , True )
745
+ else :
746
+ from safetensors .torch import safe_open
747
+ for path in paths :
748
+ with safe_open (path , framework = "pt" ) as f :
749
+ for key in f .keys (): # noqa: SIM118
750
+ tensor = f .get_tensor (key )
751
+ yield key , tensor
752
+
732
753
@staticmethod
733
754
def save_model (
734
755
model : torch .nn .Module ,
@@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
1515
1536
if load_config .load_format == LoadFormat .RUNAI_STREAMER :
1516
1537
return RunaiModelStreamerLoader (load_config )
1517
1538
1539
+ if load_config .load_format == LoadFormat .RUNAI_STREAMER_SHARDED :
1540
+ return ShardedStateLoader (load_config , runai_model_streamer = True )
1541
+
1518
1542
return DefaultModelLoader (load_config )
0 commit comments