
Commit 29b95c6

youkaichao authored and tjtanaa committed
[torch.compile] transparent compilation with more logging (vllm-project#12246)
Signed-off-by: youkaichao <[email protected]>
1 parent 0572080 commit 29b95c6

File tree: 4 files changed, +50 −7 lines changed

  vllm/compilation/backends.py
  vllm/compilation/decorators.py
  vllm/compilation/wrapper.py
  vllm/config.py

vllm/compilation/backends.py
Lines changed: 25 additions & 7 deletions

@@ -524,6 +524,7 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,

@@ -532,7 +533,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
             # 1. factors come from the vllm_config (it mainly summarizes how the
             # model is created)
-            vllm_config = self.vllm_config
             config_hash = vllm_config.compute_hash()
 
             # 2. factors come from the code files that are traced by Dynamo (

@@ -556,20 +556,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             hash_key = hashlib.md5(
                 f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-        else:
-            cache_dir = self.compilation_config.cache_dir
+                envs.VLLM_CACHE_ROOT,
+                "torch_compile_cache",
+                hash_key,
+            )
+            self.compilation_config.cache_dir = cache_dir
+
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        local_cache_dir = os.path.join(
+            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        self.compilation_config.local_cache_dir = local_cache_dir
 
         disabled = envs.VLLM_DISABLE_COMPILE_CACHE
         self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            cache_dir, disabled=disabled)
+            local_cache_dir, disabled=disabled)
         if disabled:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
-                        cache_dir)
+                        local_cache_dir)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done

@@ -609,6 +615,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
                 self.vllm_config, self.graph_pool,
                 self).run(*example_inputs)
 
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
+            # use `print_readable` because it can include submodules
+            src = "from __future__ import annotations\nimport torch\n" + \
+                self.split_gm.print_readable(print_output=False)
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+
+        logger.debug("Computation graph saved to %s", graph_path)
+
         self._called = True
 
         if not self.compilation_config.use_cudagraph or \
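
The net effect of the backends.py change is that the hash-keyed cache directory is now shared across ranks, while each rank gets its own local_cache_dir subdirectory where the readable FX graph (computation_graph.py) is dumped. A minimal sketch of how one might locate that dump after a run; the default cache root and the hash key below are assumptions for illustration, not values taken from the diff:

# Sketch only: the layout implied by the diff above is
#   <VLLM_CACHE_ROOT>/torch_compile_cache/<hash_key>/            shared cache_dir
#   <VLLM_CACHE_ROOT>/torch_compile_cache/<hash_key>/rank_<N>/   per-rank local_cache_dir
# with computation_graph.py written into the per-rank directory.
import os

cache_root = os.environ.get("VLLM_CACHE_ROOT",
                            os.path.expanduser("~/.cache/vllm"))  # assumed default
hash_key = "0123456789"  # placeholder for the real md5-derived key
rank = 0

local_cache_dir = os.path.join(cache_root, "torch_compile_cache", hash_key,
                               f"rank_{rank}")
graph_path = os.path.join(local_cache_dir, "computation_graph.py")

if os.path.exists(graph_path):
    print("readable computation graph:", graph_path)
else:
    print("no dump found (placeholder hash key / rank)")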

vllm/compilation/decorators.py
Lines changed: 2 additions & 0 deletions

@@ -198,6 +198,8 @@ def __call__(self, *args, **kwargs):
                        f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
+            logger.debug("Start compiling function %s",
+                         self.original_code_object)
 
            # if we don't use custom dispatcher, we can directly call the
            # compiled function and let torch.compile handle the dispatching,
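
The new "Start compiling function ..." message (like the other debug lines added in this commit) only appears when vLLM emits DEBUG-level records. A hedged sketch of turning that on through the VLLM_LOGGING_LEVEL environment variable, assuming the standard vLLM logging setup and that the variable is set before vllm is imported:

# Sketch: raise vLLM's log level so the new debug messages become visible.
import os

os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"  # must be set before importing vllm

from vllm.logger import init_logger  # same helper the patched modules use

logger = init_logger(__name__)
logger.debug("debug logging is now enabled for this process")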

vllm/compilation/wrapper.py
Lines changed: 22 additions & 0 deletions

@@ -9,6 +9,9 @@
 
 import vllm.envs as envs
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class TorchCompileWrapperWithCustomDispatcher:

@@ -82,6 +85,25 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
            return
 
        self.compiled_codes.append(new_code)
+        local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
+        if isinstance(local_cache_dir, str):
+            decompiled_file = os.path.join(local_cache_dir,
+                                           "transformed_code.py")
+            if not os.path.exists(decompiled_file):
+                try:
+                    # usually the decompilation will succeed for most models,
+                    # as we guarantee a full-graph compilation in Dynamo.
+                    # but there's no 100% guarantee, since decompilation is
+                    # not a reversible process.
+                    import depyf
+                    src = depyf.decompile(new_code)
+                    with open(decompiled_file, "w") as f:
+                        f.write(src)
+
+                    logger.debug("Dynamo transformed code saved to %s",
+                                 decompiled_file)
+                except Exception:
+                    pass
 
        if self.vllm_config.compilation_config.use_cudagraph and \
                "update" in new_code.co_names:

vllm/config.py
Lines changed: 1 addition & 0 deletions

@@ -2785,6 +2785,7 @@ def model_post_init(self, __context: Any) -> None:
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
+    local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
     # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
     # since we know all keys are in a range [0, max_capture_size],
