Support S3 Sharded loading with RunAI Model Streamer #16317

Merged

Changes from 1 commit
1 change: 1 addition & 0 deletions vllm/config.py
@@ -1306,6 +1306,7 @@ class LoadFormat(str, enum.Enum):
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
     RUNAI_STREAMER = "runai_streamer"
+    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
     FASTSAFETENSORS = "fastsafetensors"


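For orientation, the new enum value is what users select through vLLM's load_format option. A minimal usage sketch, assuming a checkpoint that has already been pre-sharded per tensor-parallel rank (the S3 bucket and path are placeholders):

    from vllm import LLM

    # Placeholder bucket/path; the checkpoint must already be pre-sharded
    # (e.g. written via ShardedStateLoader.save_model), one file set per rank.
    llm = LLM(
        model="s3://my-bucket/my-sharded-model/",
        load_format="runai_streamer_sharded",
        tensor_parallel_size=2,
    )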
80 changes: 52 additions & 28 deletions vllm/model_executor/model_loader/loader.py
@@ -600,8 +600,12 @@ class ShardedStateLoader(BaseModelLoader):

     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self, load_config: LoadConfig):
+    def __init__(self,
+                 load_config: LoadConfig,
+                 runai_model_streamer: bool = False):
         super().__init__(load_config)
+
+        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
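The constructor still honors the "pattern" override from model_loader_extra_config, so custom shard filenames work for the streamed path too. A sketch, assuming LoadConfig is constructed directly (the pattern shown is just the default, spelled out):

    from vllm.config import LoadConfig, LoadFormat

    # Sketch: override the shard filename pattern; __init__ above pops
    # "pattern" out of model_loader_extra_config.
    load_config = LoadConfig(
        load_format=LoadFormat.RUNAI_STREAMER_SHARDED,
        model_loader_extra_config={
            "pattern": "model-rank-{rank}-part-{part}.safetensors",
        },
    )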
@@ -648,7 +652,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int:

     def _prepare_weights(self, model_name_or_path: str,
                          revision: Optional[str]):
-        if os.path.isdir(model_name_or_path):
+        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
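_prepare_weights now returns S3 URLs untouched instead of treating them as local paths or Hugging Face IDs to download. is_s3 is imported from vLLM's S3 utilities; a minimal sketch of the check it presumably performs (assumed implementation, for illustration only):

    from urllib.parse import urlparse

    def is_s3(model_or_path: str) -> bool:
        # Assumed behavior: an s3:// scheme marks a path to stream from S3.
        return urlparse(model_or_path).scheme == "s3"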
@@ -667,12 +671,13 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
-        from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
 
-        local_model_path = self._prepare_weights(model_config.model,
-                                                 model_config.revision)
+        model_weights = model_config.model
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+        local_model_path = model_weights
 
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
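The hasattr check is needed because, for S3-hosted models, model_config.model may point at a local temp directory holding only the downloaded config files while model_weights keeps the original URL. The two-step lookup is equivalent to a getattr with a default:

    # Equivalent, more compact form of the lookup above (illustrative only):
    local_model_path = getattr(model_config, "model_weights",
                               model_config.model)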
@@ -684,40 +689,56 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 local_model_path,
                 self.pattern.format(rank=rank, part="*"),
             )
-            filepaths = glob.glob(pattern)
+
+            filepaths = []
+            if is_s3(local_model_path):
+                file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
+                filepaths = s3_glob(path=local_model_path,
+                                    allow_pattern=[file_pattern])
+            else:
+                filepaths = glob.glob(pattern)
             if not filepaths:
                 # TODO: support un-sharded checkpoints too
                 raise ValueError(
                     f"Could not find checkpoint files '{pattern}', only "
                     f"pre-sharded checkpoints are currently supported!")
             state_dict = self._filter_subtensors(model.state_dict())
-            for path in filepaths:
-                with safe_open(path, framework="pt") as f:
-                    for key in f.keys():  # noqa: SIM118
-                        tensor = f.get_tensor(key)
-                        # If loading with LoRA enabled, additional padding may
-                        # be added to certain parameters. We only load into a
-                        # narrowed view of the parameter data.
-                        param_data = state_dict[key].data
-                        param_shape = state_dict[key].shape
-                        for dim, size in enumerate(tensor.shape):
-                            if size < param_shape[dim]:
-                                param_data = param_data.narrow(dim, 0, size)
-                        if tensor.shape != param_shape:
-                            logger.warning(
-                                "loading tensor of shape %s into "
-                                "parameter '%s' of shape %s",
-                                tensor.shape,
-                                key,
-                                param_shape,
-                            )
-                        param_data.copy_(tensor)
-                        state_dict.pop(key)
+            for key, tensor in self.iterate_over_files(filepaths):
+                # If loading with LoRA enabled, additional padding may
+                # be added to certain parameters. We only load into a
+                # narrowed view of the parameter data.
+                param_data = state_dict[key].data
+                param_shape = state_dict[key].shape
+                for dim, size in enumerate(tensor.shape):
+                    if size < param_shape[dim]:
+                        param_data = param_data.narrow(dim, 0, size)
+                if tensor.shape != param_shape:
+                    logger.warning(
+                        "loading tensor of shape %s into "
+                        "parameter '%s' of shape %s",
+                        tensor.shape,
+                        key,
+                        param_shape,
+                    )
+                param_data.copy_(tensor)
+                state_dict.pop(key)
             if state_dict:
                 raise ValueError(
                     f"Missing keys {tuple(state_dict)} in loaded state!")
             return model.eval()
 
+    def iterate_over_files(
+            self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        if self.runai_model_streamer:
+            yield from runai_safetensors_weights_iterator(paths, True)
+        else:
+            from safetensors.torch import safe_open
+            for path in paths:
+                with safe_open(path, framework="pt") as f:
+                    for key in f.keys():  # noqa: SIM118
+                        tensor = f.get_tensor(key)
+                        yield key, tensor
+
     @staticmethod
     def save_model(
         model: torch.nn.Module,
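iterate_over_files gives both backends a single (key, tensor) stream: with the streamer enabled it delegates to runai_safetensors_weights_iterator, otherwise it keeps the original safe_open loop. A self-contained consumption sketch with a placeholder shard filename (the file must exist for this to run):

    from vllm.config import LoadConfig
    from vllm.model_executor.model_loader.loader import ShardedStateLoader

    # Sketch: iterate tensors from a local shard file without the streamer.
    loader = ShardedStateLoader(LoadConfig(), runai_model_streamer=False)
    for key, tensor in loader.iterate_over_files(
            ["model-rank-0-part-0.safetensors"]):
        print(key, tuple(tensor.shape))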
@@ -1504,4 +1525,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.RUNAI_STREAMER:
         return RunaiModelStreamerLoader(load_config)
 
+    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
+        return ShardedStateLoader(load_config, True)
+
     return DefaultModelLoader(load_config)
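End to end, the new branch maps the enum value onto a streamer-enabled ShardedStateLoader. A dispatch sketch (assuming the names import as defined in this module):

    from vllm.config import LoadConfig, LoadFormat
    from vllm.model_executor.model_loader.loader import (ShardedStateLoader,
                                                         get_model_loader)

    # The RUNAI_STREAMER_SHARDED format yields a ShardedStateLoader
    # constructed with runai_model_streamer=True, as added above.
    loader = get_model_loader(
        LoadConfig(load_format=LoadFormat.RUNAI_STREAMER_SHARDED))
    assert isinstance(loader, ShardedStateLoader)
    assert loader.runai_model_streamer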