
Commit f121411

Authored by youkaichao
[torch.compile] consider relevant code in compilation cache (#11614)
Signed-off-by: youkaichao <[email protected]>
1 parent cfd3219 commit f121411

4 files changed: +99 / -35 lines

vllm/compilation/backends.py

Lines changed: 62 additions & 8 deletions
@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
+                  vllm_backend: "VllmBackend",
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
     # see https://github.com/pytorch/pytorch/issues/138980
     graph = copy.deepcopy(graph)
 
-    cache_data = compilation_config.inductor_hash_cache
+    cache_data = vllm_backend.inductor_hash_cache
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
             hash_str, example_inputs, True, False)
         assert inductor_compiled_graph is not None, (
             "Inductor cache lookup failed. Please remove"
-            f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa
+            f"the cache file {cache_data.cache_file_path} and try again." # noqa
         )
 
         # Inductor calling convention (function signature):
@@ -354,14 +355,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
 
     def __init__(self, module: torch.fx.GraphModule,
                  compile_submod_names: List[str], vllm_config: VllmConfig,
-                 graph_pool):
+                 graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
         self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.vllm_config = vllm_config
+        self.vllm_backend = vllm_backend
 
     def run(self, *args):
         fake_args = [
@@ -389,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target,
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
@@ -397,7 +400,7 @@ def call_module(self, target: torch.fx.node.Target,
             self.module.__dict__[target] = PiecewiseBackend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
-                compiled_graph_for_general_shape)
+                compiled_graph_for_general_shape, self.vllm_backend)
 
             compilation_counter.num_piecewise_capturable_graphs_seen += 1
 
@@ -430,6 +433,7 @@ class VllmBackend:
     post_grad_passes: Sequence[Callable]
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
+    inductor_hash_cache: InductorHashCache
 
     def __init__(
         self,
@@ -472,6 +476,53 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        if not self.compilation_config.cache_dir:
+            # no provided cache dir, generate one based on the known factors
+            # that affects the compilation. if none of the factors change,
+            # the cache dir will be the same so that we can reuse the compiled
+            # graph.
+
+            # 1. factors come from the vllm_config (it mainly summarizes how the
+            # model is created)
+            vllm_config = self.vllm_config
+            config_hash = vllm_config.compute_hash()
+
+            # 2. factors come from the code files that are traced by Dynamo (
+            # it mainly summarizes how the model is used in forward pass)
+            forward_code_files = list(
+                sorted(self.compilation_config.traced_files))
+            self.compilation_config.traced_files.clear()
+            logger.debug(
+                "Traced files (to be considered for compilation cache):\n%s",
+                "\n".join(forward_code_files))
+            hash_content = []
+            for filepath in forward_code_files:
+                hash_content.append(filepath)
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            import hashlib
+            code_hash = hashlib.md5(
+                "\n".join(hash_content).encode()).hexdigest()
+
+            # combine the two hashes to generate the cache dir
+            hash_key = hashlib.md5(
+                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
+            cache_dir = os.path.join(
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
+                f"rank_{vllm_config.parallel_config.rank}")
+        else:
+            cache_dir = self.compilation_config.cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
+        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
+            cache_dir, disabled=disabled)
+        if disabled:
+            logger.info("vLLM's torch.compile cache is disabled.")
+        else:
+            logger.info("Using cache directory: %s for vLLM's torch.compile",
+                        cache_dir)
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
@@ -507,8 +558,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.vllm_config,
-                                    self.graph_pool).run(*example_inputs)
+                                    self.vllm_config, self.graph_pool,
+                                    self).run(*example_inputs)
 
         self._called = True
 
@@ -577,7 +628,8 @@ class PiecewiseBackend:
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
                  total_piecewise_compiles: int, sym_shape_indices: List[int],
-                 compiled_graph_for_general_shape: Callable):
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
@@ -597,6 +649,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
+        self.vllm_backend = vllm_backend
 
         self.is_first_graph = piecewise_compile_index == 0
         self.is_last_graph = (
@@ -634,7 +687,7 @@ def check_for_ending_compilation(self):
         if self.is_last_graph and not self.to_be_compiled_sizes:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
-            self.compilation_config.inductor_hash_cache.save_to_file()
+            self.vllm_backend.inductor_hash_cache.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)
 
     def __call__(self, *args) -> Any:
@@ -662,6 +715,7 @@ def __call__(self, *args) -> Any:
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
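
Note (not part of the commit): the cache-directory logic added to VllmBackend.__call__ above boils down to hashing the vLLM config together with the path and contents of every file Dynamo traced. A minimal standalone sketch of that scheme follows; the helper name compute_cache_dir and its parameters are chosen purely for illustration.

import hashlib
import os
from typing import Iterable


def compute_cache_dir(config_hash: str, traced_files: Iterable[str],
                      cache_root: str, rank: int) -> str:
    # hash the path and contents of every traced file, so editing any of
    # the model / forward code invalidates the compilation cache
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        with open(filepath) as f:
            hash_content.append(f.read())
    code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()

    # combine the config hash and the code hash into one short key
    hash_key = hashlib.md5(
        f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
    return os.path.join(cache_root, "torch_compile_cache", hash_key,
                        f"rank_{rank}")

In the real code the config hash comes from vllm_config.compute_hash(), the traced files from compilation_config.traced_files, and the cache root from envs.VLLM_CACHE_ROOT.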

vllm/compilation/decorators.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
 import inspect
 from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from unittest.mock import patch
 
 import torch
 import torch.nn as nn
+from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
@@ -196,7 +198,31 @@ def __call__(self, *args, **kwargs):
             # we need to control all the compilation of the model.
             torch._dynamo.eval_frame.remove_from_cache(
                 self.original_code_object)
-            return self.compiled_callable(*args, **kwargs)
+
+            # collect all relevant files traced by Dynamo,
+            # so that the compilation cache can trigger re-compilation
+            # properly when any of these files change.
+
+            # 1. the file containing the top-level forward function
+            self.vllm_config.compilation_config.traced_files.add(
+                self.original_code_object.co_filename)
+
+            # 2. every time Dynamo sees a function call, it will inline
+            # the function by calling InliningInstructionTranslator.inline_call
+            # we hijack this function to know all the functions called
+            # during Dynamo tracing, and their corresponding files
+            inline_call = InliningInstructionTranslator.inline_call
+
+            def patched_inline_call(parent, func, args, kwargs):
+                code = func.get_code()
+                self.vllm_config.compilation_config.traced_files.add(
+                    code.co_filename)
+                return inline_call(parent, func, args, kwargs)
+
+            with patch.object(InliningInstructionTranslator, 'inline_call',
+                              patched_inline_call):
+                output = self.compiled_callable(*args, **kwargs)
+            return output
 
         # usually, capturing the model once is enough, and then we can
         # dispatch to the compiled code directly, without going through
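
Note (not part of the commit): the change above temporarily patches InliningInstructionTranslator.inline_call so that every function Dynamo inlines reports the file it was defined in. The same pattern can be exercised without Dynamo; in the sketch below, Tracer and its inline_call method are stand-ins for the Dynamo translator, and all names are illustrative.

from unittest.mock import patch

traced_files: set = set()


class Tracer:

    def inline_call(self, func):
        # stand-in for the real hook, which recursively traces `func`
        return func()


def helper():
    return 42


original = Tracer.inline_call


def patched_inline_call(self, func):
    # record which file the inlined function comes from,
    # then defer to the original implementation
    traced_files.add(func.__code__.co_filename)
    return original(self, func)


with patch.object(Tracer, "inline_call", patched_inline_call):
    Tracer().inline_call(helper)

print(traced_files)  # contains the path of the file that defines `helper`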

vllm/config.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
 import enum
 import hashlib
 import json
-import os
 import sys
 import warnings
 from contextlib import contextmanager
@@ -2778,9 +2777,8 @@ def model_post_init(self, __context: Any) -> None:
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr
+    traced_files: Set[str] = PrivateAttr
     compilation_time: float = PrivateAttr
-    # should be InductorHashCache, but Pydantic does not support it
-    inductor_hash_cache: Any = PrivateAttr
 
     # Per-model forward context
     # Mainly used to store attention cls
@@ -2818,6 +2816,7 @@ def __repr__(self) -> str:
             "compilation_time",
             "bs_to_padded_graph_size",
             "pass_config",
+            "traced_files",
         }
         return self.model_dump_json(exclude=exclude, exclude_unset=True)
 
@@ -2877,6 +2876,7 @@ def model_post_init(self, __context: Any) -> None:
 
         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
+        self.traced_files = set()
         self.static_forward_context = {}
         self.compilation_time = 0.0
 
@@ -2899,29 +2899,6 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
 
-        if not self.cache_dir:
-            # no provided cache dir, generate one based on the known factors
-            # that affects the compilation. if none of the factors change,
-            # the cache dir will be the same so that we can reuse the compiled
-            # graph.
-            hash_key = vllm_config.compute_hash()
-            cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-            os.makedirs(cache_dir, exist_ok=True)
-            self.cache_dir = cache_dir
-
-        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
-        from vllm.compilation.backends import InductorHashCache
-        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            self.cache_dir, disabled=disabled)
-        if disabled:
-            logger.info("vLLM's torch.compile cache is disabled.")
-        else:
-            logger.info(
-                "Using cache directory: %s for vLLM's torch.compile",
-                self.cache_dir)
-
 
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)

vllm/sequence.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,13 @@ class IntermediateTensors:
 
     tensors: Dict[str, torch.Tensor]
 
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
     def __getitem__(self, key: Union[str, slice]):
         if isinstance(key, str):
             return self.tensors[key]
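
Note (not part of the commit): the hand-written __init__ matters because a dataclass-generated __init__ is produced by exec'ing a string, so its code object carries no usable co_filename and the file would be invisible to the traced_files collection above. The toy classes below (Point and PointManual are illustrative names) show the difference.

import dataclasses


@dataclasses.dataclass
class Point:
    x: int
    y: int


class PointManual:

    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y


# the generated __init__ was compiled from a string, e.g. "<string>"
print(Point.__init__.__code__.co_filename)
# the hand-written __init__ points back at this source file
print(PointManual.__init__.__code__.co_filename)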
