
Commit bd7eb9c

Chen-0210 and davidxia authored and committed
[Frontend] Reduce vLLM's import time
This change optimizes the import time of `import vllm` and contributes to #14924. Most of the changes replace eager module-level imports of expensive modules with lazy ones. This change shouldn't affect core functionality.

Co-authored-by: Chen-0210 <[email protected]>
Co-authored-by: David Xia <[email protected]>
Signed-off-by: Chen-0210 <[email protected]>
Signed-off-by: David Xia <[email protected]>
1 parent 6d0df0e commit bd7eb9c
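
For readers skimming the diff: the core trick is the `LazyLoader` helper imported from `vllm.utils`, which replaces several module-level imports below with proxies. The following is a minimal, self-contained sketch of how such a proxy can work. It is illustrative only, not vLLM's actual `LazyLoader` implementation, and the argument names are assumptions for the example.

```python
# Minimal sketch of a lazy module proxy (NOT vLLM's actual LazyLoader;
# names and behavior here are illustrative assumptions).
import importlib
import types


class LazyLoader(types.ModuleType):
    """Defers importing `module_name` until an attribute is first accessed."""

    def __init__(self, local_name: str, parent_globals: dict, module_name: str):
        super().__init__(module_name)
        self._local_name = local_name
        self._parent_globals = parent_globals
        self._module_name = module_name

    def _load(self) -> types.ModuleType:
        # Import the real module and swap it into the caller's namespace so
        # subsequent lookups no longer go through the proxy.
        module = importlib.import_module(self._module_name)
        self._parent_globals[self._local_name] = module
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)


# Usage mirroring the diff: the heavy quantization registry is only imported
# when an attribute such as QUANTIZATION_METHODS is first touched.
me_quant = LazyLoader("me_quant", globals(),
                      "vllm.model_executor.layers.quantization")
```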

File tree

21 files changed: +306 −220 lines

vllm/config.py

Lines changed: 48 additions & 45 deletions
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import ast
 import copy
 import enum
@@ -22,42 +24,45 @@
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
 from torch.distributed import ProcessGroup, ReduceOp
-from transformers import PretrainedConfig
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
-                                                      get_quantization_config)
-from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import CpuArchEnum
 from vllm.sampling_params import GuidedDecodingParams
-from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
     try_get_generation_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
-from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
-                        get_cpu_memory, get_open_port, is_torch_equal_or_newer,
-                        random_uuid, resolve_obj_by_qualname)
+from vllm.utils import (GiB_bytes, LayerBlockType, LazyLoader,
+                        cuda_device_count_stateless, get_cpu_memory,
+                        get_open_port, is_torch_equal_or_newer, random_uuid,
+                        resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
     from ray.util.placement_group import PlacementGroup
+    from transformers import PretrainedConfig
 
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader.loader import BaseModelLoader
 
     ConfigType = type[DataclassInstance]
+    HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
+                                                 PretrainedConfig]]
 else:
-    QuantizationConfig = None
+    HfOverrides = None
     ConfigType = type
 
+me_quant = LazyLoader("model_executor", globals(),
+                      "vllm.model_executor.layers.quantization")
+me_models = LazyLoader("model_executor", globals(),
+                       "vllm.model_executor.models")
 logger = init_logger(__name__)
 
 ConfigT = TypeVar("ConfigT", bound=ConfigType)
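
The `from __future__ import annotations` line added in the first hunk is what lets type-only imports such as `PretrainedConfig` move under `TYPE_CHECKING` and lets the annotations below drop their string quotes. A small sketch of the pattern, with simplified names, not the actual vLLM code:

```python
# Sketch of the deferred-annotations pattern used throughout this diff
# (simplified; not the actual vLLM code).
from __future__ import annotations  # annotations are evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported by type checkers; never at runtime, so `import vllm`
    # no longer pays for transformers here.
    from transformers import PretrainedConfig


def describe(config: PretrainedConfig) -> str:  # no quotes needed
    return type(config).__name__
```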
@@ -89,9 +94,6 @@
     for task in tasks
 }
 
-HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
-                                             PretrainedConfig]]
-
 
 class SupportsHash(Protocol):
 
@@ -365,7 +367,7 @@ def __init__(
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
         disable_mm_preprocessor_cache: bool = False,
         override_neuron_config: Optional[dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None,
+        override_pooler_config: Optional[PoolerConfig] = None,
         logits_processor_pattern: Optional[str] = None,
         generation_config: str = "auto",
         enable_sleep_mode: bool = False,
@@ -548,7 +550,7 @@ def __init__(
 
     @property
     def registry(self):
-        return ModelRegistry
+        return me_models.ModelRegistry
 
     @property
     def architectures(self) -> list[str]:
@@ -581,7 +583,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
 
     def _init_multimodal_config(
         self, limit_mm_per_prompt: Optional[dict[str, int]]
-    ) -> Optional["MultiModalConfig"]:
+    ) -> Optional[MultiModalConfig]:
         if self.registry.is_multimodal_model(self.architectures):
             return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
 
@@ -597,8 +599,8 @@ def _get_encoder_config(self):
 
     def _init_pooler_config(
         self,
-        override_pooler_config: Optional["PoolerConfig"],
-    ) -> Optional["PoolerConfig"]:
+        override_pooler_config: Optional[PoolerConfig],
+    ) -> Optional[PoolerConfig]:
 
         if self.runner_type == "pooling":
             user_config = override_pooler_config or PoolerConfig()
@@ -749,7 +751,8 @@ def _parse_quant_hf_config(self):
         return quant_cfg
 
     def _verify_quantization(self) -> None:
-        supported_quantization = QUANTIZATION_METHODS
+        supported_quantization = me_quant.QUANTIZATION_METHODS
+
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
@@ -766,8 +769,8 @@ def _verify_quantization(self) -> None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name in QUANTIZATION_METHODS:
-                method = get_quantization_config(name)
+            for name in me_quant.QUANTIZATION_METHODS:
+                method = me_quant.get_quantization_config(name)
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
@@ -799,6 +802,8 @@ def _verify_quantization(self) -> None:
                 "non-quantized models.", self.quantization)
 
     def _verify_cuda_graph(self) -> None:
+        from vllm.platforms import current_platform
+
         if self.max_seq_len_to_capture is None:
             self.max_seq_len_to_capture = self.max_model_len
         self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
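
This hunk and the `TracingConfig.__post_init__` change further down use the simplest form of the optimization: moving an import from module scope into the only function that needs it. A minimal illustrative stub, trimmed to the relevant line (not the full vLLM method):

```python
# Function-scoped import: vllm.platforms is resolved on the first call,
# not when vllm.config is imported; later calls hit sys.modules' cache.
# (Illustrative stub, not the full vLLM method.)
def _verify_cuda_graph(self) -> None:
    from vllm.platforms import current_platform

    ...  # rest of the verification logic
```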
@@ -885,7 +890,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
 
     def verify_with_parallel_config(
         self,
-        parallel_config: "ParallelConfig",
+        parallel_config: ParallelConfig,
     ) -> None:
 
         if parallel_config.distributed_executor_backend == "external_launcher":
@@ -1038,7 +1043,7 @@ def get_total_num_kv_heads(self) -> int:
             # equal to the number of attention heads.
             return self.hf_text_config.num_attention_heads
 
-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
         """Returns the number of KV heads per GPU."""
         if self.use_mla:
             # When using MLA during decode it becomes MQA
@@ -1052,13 +1057,12 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         return max(1,
                    total_num_kv_heads // parallel_config.tensor_parallel_size)
 
-    def get_num_attention_heads(self,
-                                parallel_config: "ParallelConfig") -> int:
+    def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
     def get_layers_start_end_indices(
-            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
+            self, parallel_config: ParallelConfig) -> tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
         if self.hf_text_config.model_type == "deepseek_mtp":
             total_num_hidden_layers = getattr(self.hf_text_config,
@@ -1073,13 +1077,13 @@ def get_layers_start_end_indices(
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
         return start, end
 
-    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_layers(self, parallel_config: ParallelConfig) -> int:
         start, end = self.get_layers_start_end_indices(parallel_config)
         return end - start
 
     def get_num_layers_by_block_type(
         self,
-        parallel_config: "ParallelConfig",
+        parallel_config: ParallelConfig,
         block_type: LayerBlockType = LayerBlockType.attention,
     ) -> int:
         # This function relies on 'layers_block_type' in hf_config,
@@ -1132,7 +1136,7 @@ def get_num_layers_by_block_type(
 
         return sum(t == 1 for t in attn_type_list[start:end])
 
-    def get_multimodal_config(self) -> "MultiModalConfig":
+    def get_multimodal_config(self) -> MultiModalConfig:
         """
         Get the multimodal configuration of the model.
 
@@ -1241,7 +1245,7 @@ def runner_type(self) -> RunnerType:
     @property
     def is_v1_compatible(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.is_v1_compatible(architectures)
+        return me_models.ModelRegistry.is_v1_compatible(architectures)
 
     @property
     def is_matryoshka(self) -> bool:
@@ -1392,7 +1396,7 @@ def _verify_prefix_caching(self) -> None:
 
     def verify_with_parallel_config(
         self,
-        parallel_config: "ParallelConfig",
+        parallel_config: ParallelConfig,
     ) -> None:
         total_cpu_memory = get_cpu_memory()
         # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
@@ -1460,7 +1464,7 @@ class LoadConfig:
     """Configuration for loading the model weights."""
 
     load_format: Union[str, LoadFormat,
-                       "BaseModelLoader"] = LoadFormat.AUTO.value
+                       BaseModelLoader] = LoadFormat.AUTO.value
     """The format of the model weights to load:\n
     - "auto" will try to load the weights in the safetensors format and fall
     back to the pytorch bin format if safetensors format is not available.\n
@@ -1582,11 +1586,11 @@ def data_parallel_rank_local(self, value: int) -> None:
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
 
-    placement_group: Optional["PlacementGroup"] = None
+    placement_group: Optional[PlacementGroup] = None
     """ray distributed model workers placement group."""
 
     distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
-                                                 type["ExecutorBase"]]] = None
+                                                 type[ExecutorBase]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product
     of pipeline_parallel_size and tensor_parallel_size is less than
@@ -1629,7 +1633,7 @@ def get_next_dp_init_port(self) -> int:
         self.data_parallel_master_port += 1
         return answer
 
-    def stateless_init_dp_group(self) -> "ProcessGroup":
+    def stateless_init_dp_group(self) -> ProcessGroup:
         from vllm.distributed.utils import (
             stateless_init_torch_distributed_process_group)
 
@@ -1644,7 +1648,7 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
         return dp_group
 
     @staticmethod
-    def has_unfinished_dp(dp_group: "ProcessGroup",
+    def has_unfinished_dp(dp_group: ProcessGroup,
                           has_unfinished: bool) -> bool:
         tensor = torch.tensor([has_unfinished],
                               dtype=torch.int32,
@@ -2227,7 +2231,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     @classmethod
-    def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
+    def from_dict(cls, dict_value: dict) -> SpeculativeConfig:
         """Parse the CLI value for the speculative config."""
         return cls(**dict_value)
 
@@ -2819,7 +2823,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     @staticmethod
-    def from_json(json_str: str) -> "PoolerConfig":
+    def from_json(json_str: str) -> PoolerConfig:
         return PoolerConfig(**json.loads(json_str))
 
 
@@ -3176,6 +3180,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
+        from vllm.tracing import is_otel_available, otel_import_error_traceback
         if not is_otel_available() and self.otlp_traces_endpoint is not None:
             raise ValueError(
                 "OpenTelemetry is not available. Unable to configure "
@@ -3239,7 +3244,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     @classmethod
-    def from_cli(cls, cli_value: str) -> "KVTransferConfig":
+    def from_cli(cls, cli_value: str) -> KVTransferConfig:
         """Parse the CLI value for the kv cache transfer config."""
         return KVTransferConfig.model_validate_json(cli_value)
 
@@ -3476,7 +3481,7 @@ def __repr__(self) -> str:
     __str__ = __repr__
 
     @classmethod
-    def from_cli(cls, cli_value: str) -> "CompilationConfig":
+    def from_cli(cls, cli_value: str) -> CompilationConfig:
         """Parse the CLI value for the compilation config."""
         if cli_value in ["0", "1", "2", "3"]:
             return cls(level=int(cli_value))
@@ -3528,7 +3533,7 @@ def model_post_init(self, __context: Any) -> None:
         self.static_forward_context = {}
         self.compilation_time = 0.0
 
-    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
+    def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
@@ -3744,9 +3749,7 @@ def _get_quantization_config(
         """Get the quantization config."""
        from vllm.platforms import current_platform
        if model_config.quantization is not None:
-            from vllm.model_executor.model_loader.weight_utils import (
-                get_quant_config)
-            quant_config = get_quant_config(model_config, load_config)
+            quant_config = me_quant.get_quant_config(model_config, load_config)
            capability_tuple = current_platform.get_device_capability()
 
            if capability_tuple is not None:
@@ -3770,7 +3773,7 @@ def with_hf_config(
         self,
         hf_config: PretrainedConfig,
         architectures: Optional[list[str]] = None,
-    ) -> "VllmConfig":
+    ) -> VllmConfig:
         if architectures is not None:
             hf_config = copy.deepcopy(hf_config)
             hf_config.architectures = architectures

0 commit comments
