Skip to content

Commit faef77c

Browse files
authored
[Misc] KV cache transfer connector registry (#11481)
Signed-off-by: KuntaiDu <[email protected]>
1 parent dba4d9d commit faef77c

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

vllm/config.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,14 +2559,6 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig":
25592559
return KVTransferConfig.model_validate_json(cli_value)
25602560

25612561
def model_post_init(self, __context: Any) -> None:
2562-
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
2563-
if all([
2564-
self.kv_connector is not None, self.kv_connector
2565-
not in supported_kv_connector
2566-
]):
2567-
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
2568-
f"Supported connectors are "
2569-
f"{supported_kv_connector}.")
25702562

25712563
if self.kv_role is not None and self.kv_role not in [
25722564
"kv_producer", "kv_consumer", "kv_both"
Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import TYPE_CHECKING
1+
import importlib
2+
from typing import TYPE_CHECKING, Callable, Dict, Type
23

34
from .base import KVConnectorBase
45

@@ -7,14 +8,41 @@
78

89

910
class KVConnectorFactory:
11+
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
1012

11-
@staticmethod
12-
def create_connector(rank: int, local_rank: int,
13+
@classmethod
14+
def register_connector(cls, name: str, module_path: str,
15+
class_name: str) -> None:
16+
"""Register a connector with a lazy-loading module and class name."""
17+
if name in cls._registry:
18+
raise ValueError(f"Connector '{name}' is already registered.")
19+
20+
def loader() -> Type[KVConnectorBase]:
21+
module = importlib.import_module(module_path)
22+
return getattr(module, class_name)
23+
24+
cls._registry[name] = loader
25+
26+
@classmethod
27+
def create_connector(cls, rank: int, local_rank: int,
1328
config: "VllmConfig") -> KVConnectorBase:
14-
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
15-
if config.kv_transfer_config.kv_connector in supported_kv_connector:
16-
from .simple_connector import SimpleConnector
17-
return SimpleConnector(rank, local_rank, config)
18-
else:
19-
raise ValueError(f"Unsupported connector type: "
20-
f"{config.kv_connector}")
29+
connector_name = config.kv_transfer_config.kv_connector
30+
if connector_name not in cls._registry:
31+
raise ValueError(f"Unsupported connector type: {connector_name}")
32+
33+
connector_cls = cls._registry[connector_name]()
34+
return connector_cls(rank, local_rank, config)
35+
36+
37+
# Register various connectors here.
38+
# The registration should not be done in each individual file, as we want to
39+
# only load the files corresponding to the current connector.
40+
KVConnectorFactory.register_connector(
41+
"PyNcclConnector",
42+
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
43+
"SimpleConnector")
44+
45+
KVConnectorFactory.register_connector(
46+
"MooncakeConnector",
47+
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
48+
"SimpleConnector")

0 commit comments

Comments
 (0)