Commit 31b91c7

Support Sharded loading with RunAI Model Streamer
Signed-off-by: Omer Dayan (SW-GPU) <[email protected]>
1 parent b932c04 commit 31b91c7

2 files changed: 50 additions & 28 deletions
vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -1306,6 +1306,7 @@ class LoadFormat(str, enum.Enum):
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
     RUNAI_STREAMER = "runai_streamer"
+    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
     FASTSAFETENSORS = "fastsafetensors"

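For orientation, a minimal sketch of selecting the new format from the offline API. This assumes the usual plumbing where LLM keyword arguments are forwarded to EngineArgs; the model path is illustrative:

    # Hypothetical usage of the load format added above. The S3 path is
    # illustrative; load_format strings map onto the LoadFormat enum
    # values in vllm/config.py.
    from vllm import LLM

    llm = LLM(
        model="s3://my-bucket/sharded-model/",  # illustrative location
        load_format="runai_streamer_sharded",
    )
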
vllm/model_executor/model_loader/loader.py

Lines changed: 49 additions & 28 deletions
@@ -600,8 +600,10 @@ class ShardedStateLoader(BaseModelLoader):
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self, load_config: LoadConfig):
+    def __init__(self, load_config: LoadConfig, runai_model_streamer: bool = False):
         super().__init__(load_config)
+
+        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
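The constructor keeps the existing extra-config handling, so the shard filename pattern stays overridable. A minimal construction sketch, assuming LoadConfig is built directly the way get_model_loader() does at the bottom of this file:

    # Hypothetical direct construction; normally get_model_loader() does
    # this. The "pattern" key overrides DEFAULT_PATTERN via
    # model_loader_extra_config.
    from vllm.config import LoadConfig

    load_config = LoadConfig(
        load_format="runai_streamer_sharded",
        model_loader_extra_config={
            "pattern": "model-rank-{rank}-part-{part}.safetensors",
        },
    )
    loader = ShardedStateLoader(load_config, runai_model_streamer=True)
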
@@ -648,7 +650,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int:
 
     def _prepare_weights(self, model_name_or_path: str,
                          revision: Optional[str]):
-        if os.path.isdir(model_name_or_path):
+        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
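The new branch relies on an is_s3() helper defined elsewhere in the tree (not part of this diff). A stand-in sketch, assuming it only inspects the URI scheme:

    # Hypothetical stand-in for the is_s3() helper used above; the real
    # implementation lives elsewhere in vLLM. Any "s3://..." location is
    # treated as remote, so _prepare_weights() returns it unchanged
    # instead of attempting a local snapshot download.
    def is_s3(model_name_or_path: str) -> bool:
        return model_name_or_path.lower().startswith("s3://")
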
@@ -667,12 +669,14 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
-        from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
 
-        local_model_path = self._prepare_weights(model_config.model,
-                                                 model_config.revision)
+
+        model_weights = model_config.model
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+        local_model_path = model_weights
 
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
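The hasattr() fallback above prefers an explicitly resolved weights location (for example an s3:// URI recorded on the model config) and otherwise falls back to the model path. Equivalently, as a one-liner:

    # Equivalent to the three-line fallback in the hunk above.
    model_weights = getattr(model_config, "model_weights", model_config.model)
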
@@ -684,40 +688,54 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                     local_model_path,
                     self.pattern.format(rank=rank, part="*"),
                 )
-                filepaths = glob.glob(pattern)
+
+                filepaths = []
+                if is_s3(local_model_path):
+                    file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
+                    filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
+                else:
+                    filepaths = glob.glob(pattern)
                 if not filepaths:
                     # TODO: support un-sharded checkpoints too
                     raise ValueError(
                         f"Could not find checkpoint files '{pattern}', only "
                         f"pre-sharded checkpoints are currently supported!")
                 state_dict = self._filter_subtensors(model.state_dict())
-                for path in filepaths:
-                    with safe_open(path, framework="pt") as f:
-                        for key in f.keys():  # noqa: SIM118
-                            tensor = f.get_tensor(key)
-                            # If loading with LoRA enabled, additional padding may
-                            # be added to certain parameters. We only load into a
-                            # narrowed view of the parameter data.
-                            param_data = state_dict[key].data
-                            param_shape = state_dict[key].shape
-                            for dim, size in enumerate(tensor.shape):
-                                if size < param_shape[dim]:
-                                    param_data = param_data.narrow(dim, 0, size)
-                            if tensor.shape != param_shape:
-                                logger.warning(
-                                    "loading tensor of shape %s into "
-                                    "parameter '%s' of shape %s",
-                                    tensor.shape,
-                                    key,
-                                    param_shape,
-                                )
-                            param_data.copy_(tensor)
-                            state_dict.pop(key)
+                for key, tensor in self.iterate_over_files(filepaths):
+                    # If loading with LoRA enabled, additional padding may
+                    # be added to certain parameters. We only load into a
+                    # narrowed view of the parameter data.
+                    param_data = state_dict[key].data
+                    param_shape = state_dict[key].shape
+                    for dim, size in enumerate(tensor.shape):
+                        if size < param_shape[dim]:
+                            param_data = param_data.narrow(dim, 0, size)
+                    if tensor.shape != param_shape:
+                        logger.warning(
+                            "loading tensor of shape %s into "
+                            "parameter '%s' of shape %s",
+                            tensor.shape,
+                            key,
+                            param_shape,
+                        )
+                    param_data.copy_(tensor)
+                    state_dict.pop(key)
         if state_dict:
             raise ValueError(
                 f"Missing keys {tuple(state_dict)} in loaded state!")
         return model.eval()
 
+    def iterate_over_files(self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        if self.runai_model_streamer:
+            yield from runai_safetensors_weights_iterator(paths, True)
+        else:
+            from safetensors.torch import safe_open
+            for path in paths:
+                with safe_open(path, framework="pt") as f:
+                    for key in f.keys():  # noqa: SIM118
+                        tensor = f.get_tensor(key)
+                        yield key, tensor
+
     @staticmethod
     def save_model(
         model: torch.nn.Module,
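Two details in this hunk are worth unpacking. First, the narrow-then-copy loop tolerates checkpoint tensors that are smaller than their (for example LoRA-padded) parameters; a self-contained illustration:

    # Self-contained sketch of the narrow-then-copy pattern used in
    # load_model() above: copy a smaller checkpoint tensor into a
    # zero-padded parameter buffer through a narrowed view.
    import torch

    param = torch.zeros(8, 16)   # parameter with padding on dim 0
    tensor = torch.ones(6, 16)   # checkpoint tensor without padding

    view = param
    for dim, size in enumerate(tensor.shape):
        if size < param.shape[dim]:
            view = view.narrow(dim, 0, size)  # view shares storage with param
    view.copy_(tensor)  # writes through into param

    assert param[:6].eq(1).all() and param[6:].eq(0).all()

Second, both branches of iterate_over_files() yield (key, tensor) pairs, so load_model() is agnostic to whether tensors arrive via the RunAI streamer or via safetensors' safe_open. A hypothetical call, with illustrative rank-0 shard paths matching DEFAULT_PATTERN:

    # Hypothetical usage; the paths are illustrative.
    paths = [
        "model-rank-0-part-0.safetensors",
        "model-rank-0-part-1.safetensors",
    ]
    for key, tensor in loader.iterate_over_files(paths):
        print(key, tuple(tensor.shape))
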
@@ -1504,4 +1522,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.RUNAI_STREAMER:
         return RunaiModelStreamerLoader(load_config)
 
+    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
+        return ShardedStateLoader(load_config, True)
+
     return DefaultModelLoader(load_config)
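
A quick dispatch sketch, assuming LoadConfig defaults are otherwise sufficient: the new branch reuses ShardedStateLoader and just flips the streamer flag.

    # Hypothetical check that the new format routes to ShardedStateLoader
    # with the RunAI streamer enabled.
    from vllm.config import LoadConfig, LoadFormat

    loader = get_model_loader(
        LoadConfig(load_format=LoadFormat.RUNAI_STREAMER_SHARDED))
    assert isinstance(loader, ShardedStateLoader)
    assert loader.runai_model_streamer is True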
