Commit c222f47

[core][bugfix] configure env var during import vllm (#12209)
Signed-off-by: youkaichao <[email protected]>
Parent commit: 170eb35

File tree

4 files changed: +37 -45 lines changed


examples/offline_inference/rlhf.py

Lines changed: 1 addition & 6 deletions
@@ -19,7 +19,7 @@
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from transformers import AutoModelForCausalLM
 
-from vllm import LLM, SamplingParams, configure_as_vllm_process
+from vllm import LLM, SamplingParams
 from vllm.utils import get_ip, get_open_port
 from vllm.worker.worker import Worker
 
@@ -98,12 +98,7 @@ def __init__(self, *args, **kwargs):
 """
 Start the training process, here we use huggingface transformers
 as an example to hold a model on GPU 0.
-
-It is important for all the processes outside of vLLM to call
-`configure_as_vllm_process` to set some common environment variables
-the same as vLLM workers.
-"""
-configure_as_vllm_process()
+"""
 
 train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 train_model.to("cuda:0")

vllm/__init__.py

Lines changed: 13 additions & 36 deletions
@@ -1,4 +1,7 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+import os
+
+import torch
 
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -17,43 +20,18 @@
 
 from .version import __version__, __version_tuple__
 
+# set some common config/environment variables that should be set
+# for all processes created by vllm and all processes
+# that interact with vllm workers.
+# they are executed whenever `import vllm` is called.
 
-def configure_as_vllm_process():
-    """
-    set some common config/environment variables that should be set
-    for all processes created by vllm and all processes
-    that interact with vllm workers.
-    """
-    import os
-
-    import torch
-
-    # see https://github.com/NVIDIA/nccl/issues/1234
-    os.environ['NCCL_CUMEM_ENABLE'] = '0'
-
-    # see https://github.com/vllm-project/vllm/issues/10480
-    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
-    # see https://github.com/vllm-project/vllm/issues/10619
-    torch._inductor.config.compile_threads = 1
-
-    from vllm.platforms import current_platform
-
-    if current_platform.is_xpu():
-        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
-        torch._dynamo.config.disable = True
-    elif current_platform.is_hpu():
-        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
-        # does not support torch.compile
-        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
-        # torch.compile support
-        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
-        if is_lazy:
-            torch._dynamo.config.disable = True
-            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
-            # requires enabling lazy collectives
-            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
-            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+# see https://github.com/NVIDIA/nccl/issues/1234
+os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
+# see https://github.com/vllm-project/vllm/issues/10480
+os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
+# see https://github.com/vllm-project/vllm/issues/10619
+torch._inductor.config.compile_threads = 1
 
 __all__ = [
     "__version__",
@@ -80,5 +58,4 @@ def configure_as_vllm_process():
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
-    "configure_as_vllm_process",
 ]
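
Because the settings now run at module import time, their effect can be observed immediately after importing the package. A small sanity-check sketch (not part of the commit), assuming nothing else has overridden these values in the current process:

import os

import torch
import vllm  # noqa: F401 -- triggers the import-time configuration above

# both the environment variables and the inductor setting are applied
assert os.environ['NCCL_CUMEM_ENABLE'] == '0'
assert os.environ['TORCHINDUCTOR_COMPILE_THREADS'] == '1'
assert torch._inductor.config.compile_threads == 1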

vllm/plugins/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -1,6 +1,9 @@
 import logging
+import os
 from typing import Callable, Dict
 
+import torch
+
 import vllm.envs as envs
 
 logger = logging.getLogger(__name__)
@@ -51,6 +54,26 @@ def load_general_plugins():
     if plugins_loaded:
         return
     plugins_loaded = True
+
+    # some platform-specific configurations
+    from vllm.platforms import current_platform
+
+    if current_platform.is_xpu():
+        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
+        torch._dynamo.config.disable = True
+    elif current_platform.is_hpu():
+        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
+        # does not support torch.compile
+        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
+        # torch.compile support
+        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
+        if is_lazy:
+            torch._dynamo.config.disable = True
+            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
+            # requires enabling lazy collectives
+            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
+            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+
     plugins = load_plugins_by_group(group='vllm.general_plugins')
     # general plugins, we only need to execute the loaded functions
     for func in plugins.values():
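
The platform-specific knobs (XPU/HPU) move out of the import path and into `load_general_plugins()`, which the worker invokes during initialization (see the worker_base.py change below). A hedged sketch of how a process running on one of those platforms would pick them up:

from vllm.plugins import load_general_plugins

# applies the XPU/HPU-specific torch._dynamo / env var settings shown above,
# in addition to loading entry points registered under the
# 'vllm.general_plugins' group; repeated calls are no-ops
# because of the plugins_loaded guard
load_general_plugins()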

vllm/worker/worker_base.py

Lines changed: 0 additions & 3 deletions
@@ -535,9 +535,6 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
         kwargs = all_kwargs[self.rpc_rank]
         enable_trace_function_call_for_thread(self.vllm_config)
 
-        from vllm import configure_as_vllm_process
-        configure_as_vllm_process()
-
         from vllm.plugins import load_general_plugins
         load_general_plugins()
 