Commit bb452c7

omer-dayan authored and DarkLight1337 committed
Support S3 Sharded loading with RunAI Model Streamer (vllm-project#16317)
Signed-off-by: Omer Dayan (SW-GPU) <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
1 parent e0037b8 · commit bb452c7

File tree

2 files changed: +53 -28 lines


vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -1489,6 +1489,7 @@ class LoadFormat(str, enum.Enum):
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
     RUNAI_STREAMER = "runai_streamer"
+    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
     FASTSAFETENSORS = "fastsafetensors"

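With the new enum member in place, the sharded streamer is selected like any other load format. A minimal usage sketch, assuming a pre-sharded checkpoint at a placeholder S3 path (the LLM entry point and its load_format argument are existing vLLM API):

from vllm import LLM

# "s3://my-bucket/my-model/" is a placeholder; the checkpoint there must
# already be sharded per tensor-parallel rank (see ShardedStateLoader below).
llm = LLM(
    model="s3://my-bucket/my-model/",
    load_format="runai_streamer_sharded",
)
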
vllm/model_executor/model_loader/loader.py

Lines changed: 52 additions & 28 deletions
@@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self, load_config: LoadConfig):
+    def __init__(self,
+                 load_config: LoadConfig,
+                 runai_model_streamer: bool = False):
         super().__init__(load_config)
+
+        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@@ -659,7 +663,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int:
 
     def _prepare_weights(self, model_name_or_path: str,
                          revision: Optional[str]):
-        if os.path.isdir(model_name_or_path):
+        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
@@ -678,12 +682,13 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
-        from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
 
-        local_model_path = self._prepare_weights(model_config.model,
-                                                 model_config.revision)
+        model_weights = model_config.model
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+        local_model_path = model_weights
 
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -695,40 +700,56 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 local_model_path,
                 self.pattern.format(rank=rank, part="*"),
             )
-            filepaths = glob.glob(pattern)
+
+            filepaths = []
+            if is_s3(local_model_path):
+                file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
+                filepaths = s3_glob(path=local_model_path,
+                                    allow_pattern=[file_pattern])
+            else:
+                filepaths = glob.glob(pattern)
             if not filepaths:
                 # TODO: support un-sharded checkpoints too
                 raise ValueError(
                     f"Could not find checkpoint files '{pattern}', only "
                     f"pre-sharded checkpoints are currently supported!")
             state_dict = self._filter_subtensors(model.state_dict())
-            for path in filepaths:
-                with safe_open(path, framework="pt") as f:
-                    for key in f.keys():  # noqa: SIM118
-                        tensor = f.get_tensor(key)
-                        # If loading with LoRA enabled, additional padding may
-                        # be added to certain parameters. We only load into a
-                        # narrowed view of the parameter data.
-                        param_data = state_dict[key].data
-                        param_shape = state_dict[key].shape
-                        for dim, size in enumerate(tensor.shape):
-                            if size < param_shape[dim]:
-                                param_data = param_data.narrow(dim, 0, size)
-                        if tensor.shape != param_shape:
-                            logger.warning(
-                                "loading tensor of shape %s into "
-                                "parameter '%s' of shape %s",
-                                tensor.shape,
-                                key,
-                                param_shape,
-                            )
-                        param_data.copy_(tensor)
-                        state_dict.pop(key)
+            for key, tensor in self.iterate_over_files(filepaths):
+                # If loading with LoRA enabled, additional padding may
+                # be added to certain parameters. We only load into a
+                # narrowed view of the parameter data.
+                param_data = state_dict[key].data
+                param_shape = state_dict[key].shape
+                for dim, size in enumerate(tensor.shape):
+                    if size < param_shape[dim]:
+                        param_data = param_data.narrow(dim, 0, size)
+                if tensor.shape != param_shape:
+                    logger.warning(
+                        "loading tensor of shape %s into "
+                        "parameter '%s' of shape %s",
+                        tensor.shape,
+                        key,
+                        param_shape,
+                    )
+                param_data.copy_(tensor)
+                state_dict.pop(key)
         if state_dict:
             raise ValueError(
                 f"Missing keys {tuple(state_dict)} in loaded state!")
         return model.eval()
 
+    def iterate_over_files(
+            self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        if self.runai_model_streamer:
+            yield from runai_safetensors_weights_iterator(paths, True)
+        else:
+            from safetensors.torch import safe_open
+            for path in paths:
+                with safe_open(path, framework="pt") as f:
+                    for key in f.keys():  # noqa: SIM118
+                        tensor = f.get_tensor(key)
+                        yield key, tensor
+
     @staticmethod
     def save_model(
         model: torch.nn.Module,
@@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.RUNAI_STREAMER:
         return RunaiModelStreamerLoader(load_config)
 
+    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
+        return ShardedStateLoader(load_config, runai_model_streamer=True)
+
     return DefaultModelLoader(load_config)
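
For illustration, a self-contained sketch of the shard-matching rule the new S3 branch applies: the loader builds a glob by prefixing "*" to the rank-specific filename pattern, then filters the bucket listing with it via s3_glob. Here fnmatch stands in for the S3 listing filter, and the object keys are made up:

import fnmatch

DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

def shard_files_for_rank(keys: list[str], rank: int) -> list[str]:
    # Same glob shape as the loader's S3 path: any prefix, fixed rank,
    # any part number.
    file_pattern = f"*{DEFAULT_PATTERN.format(rank=rank, part='*')}"
    return [k for k in keys if fnmatch.fnmatch(k, file_pattern)]

keys = [
    "s3://bucket/model/model-rank-0-part-0.safetensors",
    "s3://bucket/model/model-rank-0-part-1.safetensors",
    "s3://bucket/model/model-rank-1-part-0.safetensors",
]
print(shard_files_for_rank(keys, rank=0))
# Prints only the two rank-0 shard files.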
