support loading model with user input params (turbomind) #3204

Open · wants to merge 5 commits into base: main
8 changes: 8 additions & 0 deletions lmdeploy/api.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from multiprocessing.queues import Queue as MpQueue
from queue import Queue
from typing import List, Literal, Optional, Union

from .archs import autoget_backend_config, get_task
@@ -92,6 +94,7 @@ def serve(model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
model_params_que: Optional[Union[Queue, MpQueue]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
server_name: str = '0.0.0.0',
server_port: int = 23333,
@@ -122,6 +125,10 @@ def serve(model_path: str,
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
model_params_que (queue.Queue | multiprocessing.queues.Queue): a queue of model parameters.
The first item should be a list of all parameter names of the model (state_dict().keys()),
the following items should each be a part of the state_dict(), and the last item should
be None, indicating the end of the queue.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
server_name (str): host ip for serving
@@ -156,6 +163,7 @@ def serve(model_path: str,
kwargs=dict(model_name=model_name,
backend=backend,
backend_config=backend_config,
model_params_que=model_params_que,
chat_template_config=chat_template_config,
server_name=server_name,
server_port=server_port,
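For reference (not part of the diff), a minimal usage sketch of the new model_params_que argument. The checkpoint path is a placeholder and the one-tensor-per-item chunking is only illustrative; it assumes lmdeploy.api.serve keeps the engine in the calling process, so a plain queue.Queue suffices. A cross-process variant with a multiprocessing queue is sketched after the turbomind.py changes below.

from queue import Queue

import torch
from transformers import AutoModelForCausalLM

from lmdeploy import TurbomindEngineConfig, serve

model_path = '/path/to/hf/model'  # placeholder path
state_dict = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16).state_dict()

que = Queue()
que.put(list(state_dict.keys()))         # 1) all parameter names first
for name, tensor in state_dict.items():
    que.put({name: tensor})              # 2) parts of the state_dict
que.put(None)                            # 3) None marks the end of the queue

client = serve(model_path,
               backend='turbomind',
               backend_config=TurbomindEngineConfig(),
               model_params_que=que)

The model path is still required for the model config and tokenizer; per the changes below, only the weights are taken from the queue.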
8 changes: 8 additions & 0 deletions lmdeploy/serve/async_engine.py
@@ -12,6 +12,7 @@
from copy import deepcopy
from functools import partial
from itertools import count
from multiprocessing.queues import Queue as MpQueue
from queue import Queue
from threading import Thread
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union
@@ -244,6 +245,10 @@ class AsyncEngine(LogitsMixin):
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
model_params_que (queue.Queue | multiprocessing.queues.Queue): a queue of model parameters.
The first item should be a list of all parameter names of the model (state_dict().keys()),
the following items should each be a part of the state_dict(), and the last item should
be None, indicating the end of the queue.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
max_log_len (int): Max number of prompt characters or prompt tokens
@@ -255,6 +260,7 @@ def __init__(self,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
model_params_que: Optional[Union[Queue, MpQueue]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
max_log_len: int = None,
**kwargs) -> None:
@@ -275,6 +281,8 @@ def __init__(self,
self.arch, _ = get_model_arch(model_path)

# build backend engine
assert model_params_que is None or backend == 'turbomind', 'model_params_que is only supported by the turbomind backend'
kwargs.update(model_params_que=model_params_que)
if backend == 'turbomind':
self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs)
elif backend == 'pytorch':
4 changes: 4 additions & 0 deletions lmdeploy/serve/openai/api_server.py
@@ -6,6 +6,8 @@
import time
from functools import partial
from http import HTTPStatus
from multiprocessing.queues import Queue as MpQueue
from queue import Queue
from typing import AsyncGenerator, Dict, List, Literal, Optional, Union

import uvicorn
@@ -925,6 +927,7 @@ def serve(model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig, TurbomindEngineConfig]] = None,
model_params_que: Optional[Union[Queue, MpQueue]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
server_name: str = '0.0.0.0',
server_port: int = 23333,
@@ -1032,6 +1035,7 @@ def serve(model_path: str,
model_name=model_name,
backend=backend,
backend_config=backend_config,
model_params_que=model_params_que,
chat_template_config=chat_template_config,
max_log_len=max_log_len,
**kwargs)
39 changes: 38 additions & 1 deletion lmdeploy/turbomind/deploy/loader.py
@@ -6,7 +6,9 @@
from collections import defaultdict
from functools import partial
from glob import glob
from typing import Iterator, Tuple
from multiprocessing.queues import Queue as MpQueue
from queue import Queue
from typing import Iterator, Tuple, Union

import torch
from safetensors import safe_open
@@ -144,9 +146,44 @@ def items(self):
yield (idx, params.pop(idx))


class StateDictLoader:

def __init__(self, queue: Union[Queue, MpQueue], pattern: str):
self.que = queue
self.pattern = pattern
self.item_count = defaultdict(int)

def items(self):
params = defaultdict(dict)
# the first item should be all keys of weight
keys = self.que.get()
for k in keys:
match = re.findall(self.pattern, k)
if match:
self.item_count[int(match[0])] += 1
# load weights from queue
for state_dict in iter(self.que.get, None):
misc = []
for k, v in state_dict.items():
match = re.findall(self.pattern, k)
if not match:
misc.append((k, v))
else:
idx = int(match[0])
param = params[idx]
param[k] = v
if len(param) == self.item_count[idx]:
yield (idx, params.pop(idx))
if misc:
yield (-1, {k: v for k, v in misc})


def create_loader(model_path: Union[str, Queue, MpQueue], pattern: str) -> BaseLoader:
args = (model_path, pattern)

if isinstance(model_path, (Queue, MpQueue)):
return StateDictLoader(*args)

if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
return SafetensorsLoader(*args, index_name=SAFE_WEIGHT_INDEX_NAME)

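A toy illustration (not from the diff) of how StateDictLoader regroups streamed tensors by layer index. The tensor names, shapes, and regex pattern are made up for the example; real patterns come from the turbomind source-model readers, and a full lmdeploy installation is assumed so the loader module imports cleanly.

from queue import Queue

import torch

from lmdeploy.turbomind.deploy.loader import create_loader

names = [
    'model.layers.0.self_attn.q_proj.weight',
    'model.layers.0.self_attn.k_proj.weight',
    'model.layers.1.self_attn.q_proj.weight',
    'model.norm.weight',        # no layer index -> treated as misc
]

que = Queue()
que.put(names)                              # all names first
for name in names:
    que.put({name: torch.zeros(4)})         # stream partial state_dicts
que.put(None)                               # end-of-queue sentinel

loader = create_loader(que, pattern=r'layers\.([0-9]+)\.')
for idx, params in loader.items():
    print(idx, sorted(params))
# Each layer is yielded as soon as all of its tensors have arrived
# (tracked via item_count); tensors without a layer index come out as idx -1.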
21 changes: 17 additions & 4 deletions lmdeploy/turbomind/turbomind.py
@@ -9,8 +9,9 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from functools import partial
from multiprocessing.queues import Queue as MpQueue
from queue import Queue
from typing import Dict, List
from typing import Dict, List, Optional, Union

import numpy as np
import torch
@@ -76,6 +77,10 @@ class TurboMind:
engine
model_source (int): the source of the model, which is either
turbomind model, or a transformers model
model_params_que (queue.Queue | multiprocessing.queues.Queue): a queue of model parameters.
The first item should be a list of all parameter names of the model (state_dict().keys()),
the following items should each be a part of the state_dict(), and the last item should
be None, indicating the end of the queue.
"""

def __init__(self,
@@ -85,11 +90,12 @@ def __init__(self,
chat_template_name: str = None,
engine_config: TurbomindEngineConfig = None,
model_source: ModelSource = ModelSource.WORKSPACE,
model_params_que: Optional[Union[Queue, MpQueue]] = None,
**kwargs):
self.model_name = model_name
self.chat_template_name = chat_template_name

_engine_config = copy.deepcopy(engine_config)
_engine_config = copy.copy(engine_config)
if _engine_config is None:
_engine_config = TurbomindEngineConfig()
if _engine_config.max_batch_size is None:
@@ -107,7 +113,8 @@ def __init__(self,
model_path = get_model(model_path, _engine_config.download_dir, _engine_config.revision)
self.model_comm = self._from_hf(model_source=model_source,
model_path=model_path,
engine_config=_engine_config)
engine_config=_engine_config,
model_params_que=model_params_que)

with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
@@ -187,7 +194,11 @@ def _postprocess_config(self, tm_config, engine_config):
logger.info(f'turbomind model config:\n\n'
f'{json.dumps(self.config_dict, indent=2)}')

def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: TurbomindEngineConfig):
def _from_hf(self,
model_source: ModelSource,
model_path: str,
engine_config: TurbomindEngineConfig,
model_params_que: Optional[Union[Queue, MpQueue]] = None):
"""Load model which is in hf format."""
assert model_source == ModelSource.HF_MODEL, \
f'{model_source} is not supported'
@@ -212,6 +223,8 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu
tm_params = tm_model.tm_params
self._get_model_params(model_comm, tm_params)
logger.warning(f'get {len(tm_params)} model params')
if model_params_que is not None:
tm_model.input_model.model_path = model_params_que
tm_model.export()
# there should be no left turbomind params.
if len(tm_params) > 0:
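For the cross-process case implied by the multiprocessing.queues.Queue type hint, a sketch (with placeholder paths and illustrative chunking) of a producer process streaming weights into the serving process; tensors are moved to CPU so they can be pickled through the queue.

import multiprocessing as mp

import torch
from transformers import AutoModelForCausalLM

from lmdeploy import TurbomindEngineConfig, serve


def produce_weights(que, model_path):
    """Stream a checkpoint into the queue using the expected protocol."""
    state_dict = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16).state_dict()
    que.put(list(state_dict.keys()))        # names first
    for name, tensor in state_dict.items():
        que.put({name: tensor.cpu()})       # stream parameter shards
    que.put(None)                           # sentinel


if __name__ == '__main__':
    model_path = '/path/to/hf/model'        # placeholder path
    que = mp.Queue()
    producer = mp.Process(target=produce_weights, args=(que, model_path))
    producer.start()
    client = serve(model_path,
                   backend='turbomind',
                   backend_config=TurbomindEngineConfig(),
                   model_params_que=que)
    producer.join()

Because create_loader dispatches on the queue type, _from_hf only needs to swap the input model's model_path for the queue, which is exactly what the diff above does.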