
Commit c14abd1

youkaichao authored and Isotr0py committed

[torch.compile] store inductor compiled Python file (vllm-project#12182)

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent 682f280 commit c14abd1

File tree

2 files changed: +60 −33 lines changed


vllm/compilation/backends.py

Lines changed: 58 additions & 22 deletions
@@ -25,23 +25,30 @@
 logger = init_logger(__name__)
 
 
+@dataclasses.dataclass
+class InductorArtifact:
+    hash_str: str = ""
+    file_path: str = ""
+
+
 class InductorHashCache:
     """
     Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str)
+    (runtime_shape, graph_index, hash_str, file_path)
     We use list of tuple for readability.
 
     In-memory format: a defaultdict of dict, where the key is
     runtime_shape, and the value is a dict of graph_index to hash_str.
-    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
     we don't use json here because json doesn't support int as key.
 
     TODO: better off-the-shelf solution to serialize the data?
     """
 
     def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: defaultdict = defaultdict(dict)
+        self.cache: Dict[Optional[int],
+                         Dict[int, InductorArtifact]] = defaultdict(dict)
         self.disabled = disabled
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir,
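
For orientation, a hypothetical example of what the new on-disk format described in the docstring could look like; the hash strings and file paths below are illustrative placeholders, not real Inductor output:

# Each entry in the cache file is now a 4-tuple of
# (runtime_shape, graph_index, hash_str, file_path).
# All values below are hypothetical placeholders.
example_cache_entries = [
    # runtime_shape of None: graph not specialized to a particular size
    (None, 0, "hash_for_general_graph_0", "/tmp/inductor_output_general_0.py"),
    # graph compiled for a specific runtime shape of 8
    (8, 0, "hash_for_shape8_graph_0", "/tmp/inductor_output_shape8_0.py"),
]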
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
         # because it is a safe way to parse Python literals.
         # do not use eval(), it is unsafe.
         list_data = ast.literal_eval(data)
-        for runtime_shape, graph_index, hash_str in list_data:
-            self.cache[runtime_shape][graph_index] = hash_str
+        for item in list_data:
+            runtime_shape = item[0]
+            graph_index = item[1]
+            hash_str = item[2]
+            # for compatibility of old version,
+            # where we don't have file_path.
+            # NOTE: after running the new code, the file_path
+            # will be updated.
+            file_path = "" if len(item) == 3 else item[3]
+            self.cache[runtime_shape][graph_index] = InductorArtifact(
+                hash_str=hash_str, file_path=file_path)
 
     def serialize(self) -> str:
         data = []
-        for runtime_shape, graph_index_to_hash_str in self.cache.items():
-            for graph_index, hash_str in graph_index_to_hash_str.items():
-                data.append((runtime_shape, graph_index, hash_str))
+        for runtime_shape, value in self.cache.items():
+            for graph_index, inductor_artifact in value.items():
+                data.append(
+                    (runtime_shape, graph_index, inductor_artifact.hash_str,
+                     inductor_artifact.file_path))
         printer = pprint.PrettyPrinter(indent=4)
         return printer.pformat(data)
 
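
To make the backward-compatibility handling concrete, here is a self-contained sketch of the parse-and-normalize step shown above: old cache files contain 3-tuples without a file_path, new ones contain 4-tuples, and both are read with ast.literal_eval and normalized before being written back in the new format. The entries are hypothetical.

import ast
import pprint

# a cache file mixing an old-format (3-tuple) and a new-format (4-tuple) entry
old_and_new_data = """[
    (None, 0, 'hash_for_graph_0'),
    (None, 1, 'hash_for_graph_1', '/tmp/inductor_output_1.py'),
]"""

normalized = []
for item in ast.literal_eval(old_and_new_data):
    runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
    # old entries simply have no file_path yet; it gets filled in on the next run
    file_path = "" if len(item) == 3 else item[3]
    normalized.append((runtime_shape, graph_index, hash_str, file_path))

# serializing always writes the new 4-tuple format
print(pprint.PrettyPrinter(indent=4).pformat(normalized))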
@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
         return runtime_shape in self.cache and graph_index in self.cache[
             runtime_shape]
 
-    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
         if self.disabled:
             raise KeyError("cannot read from disabled cache")
         runtime_shape, graph_index = key
         return self.cache[runtime_shape][graph_index]
 
-    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+    def __setitem__(self, key: Tuple[Optional[int], int],
+                    value: InductorArtifact):
         # setitem for disabled cache is fine, because we
         # don't actually write to the disk
         runtime_shape, graph_index = key
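
As a usage sketch, the typed accessors above map a (runtime_shape, graph_index) key onto the nested in-memory dict. A minimal standalone version, with the two-field dataclass repeated locally and hypothetical values:

import dataclasses
from collections import defaultdict
from typing import Dict, Optional


@dataclasses.dataclass
class InductorArtifact:
    hash_str: str = ""
    file_path: str = ""


# in-memory layout: Dict[Optional[int], Dict[int, InductorArtifact]]
cache: Dict[Optional[int], Dict[int, InductorArtifact]] = defaultdict(dict)

# what __setitem__ does for the key (None, 0)
runtime_shape, graph_index = None, 0
cache[runtime_shape][graph_index] = InductorArtifact(
    hash_str="hash_for_graph_0", file_path="/tmp/inductor_output_0.py")

# what __getitem__ returns for the same key
artifact = cache[runtime_shape][graph_index]
print(artifact.hash_str, artifact.file_path)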
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
-        hash_str = cache_data[(runtime_shape, graph_index)]
+        inductor_artifact = cache_data[(runtime_shape, graph_index)]
+        hash_str = inductor_artifact.hash_str
         if graph_index == 0:
             # adds some info logging for the first graph
             logger.info(
@@ -199,6 +219,7 @@
             "Inductor cache lookup failed. Please remove"
             f"the cache file {cache_data.cache_file_path} and try again."  # noqa
         )
+        inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
 
         # Inductor calling convention (function signature):
         # f(list) -> tuple
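
The added line recovers the path of the Inductor-generated Python file by reading __code__.co_filename off the compiled callable. A torch-free sketch of why that works, using a hypothetical module written to a temporary file:

# Any Python function records the file it was defined in via
# __code__.co_filename; that is the attribute used above to find the
# generated file. The temporary module here is purely illustrative.
import importlib.util
import tempfile

source = "def current_callable(args):\n    return args\n"

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(source)
    generated_path = f.name

spec = importlib.util.spec_from_file_location("generated_module", generated_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# the callable's code object points back at the file that defined it
assert module.current_callable.__code__.co_filename == generated_path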
@@ -224,19 +245,20 @@ def compiled_graph(*args):
         # the assumption is that we don't have nested Inductor compilation.
         # compiled_fx_graph_hash will only be called once, and we can hook
         # it to get the hash of the compiled graph directly.
-        from torch._inductor.codecache import compiled_fx_graph_hash
+
+        inductor_artifact = InductorArtifact()
+        from torch._inductor.codecache import (FxGraphCache,
+                                               compiled_fx_graph_hash)
+        original_load = FxGraphCache.load
+
+        def hijack_load(*args, **kwargs):
+            inductor_compiled_graph = original_load(*args, **kwargs)
+            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
+            return inductor_compiled_graph
 
         def hijack_compiled_fx_graph_hash(*args, **kwargs):
             out = compiled_fx_graph_hash(*args, **kwargs)
-            # store the hash in the cache
-            nonlocal cache_data
-            cache_data[(runtime_shape, graph_index)] = out[0]
-            if graph_index == 0:
-                # adds some info logging for the first graph
-                logger.info("Cache the graph of shape %s for later use",
-                            str(runtime_shape))
-            logger.debug("store the %s-th graph for shape %s via hash %s",
-                         graph_index, str(runtime_shape), out[0])
+            inductor_artifact.hash_str = out[0]
             return out
 
         def _check_can_cache(*args, **kwargs):
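
Both hijack functions close over the freshly created inductor_artifact and write into it as a side effect, while otherwise passing calls through unchanged. A generic, torch-free sketch of that capture-by-closure pattern; compute_hash and load_compiled are hypothetical stand-ins:

import dataclasses


@dataclasses.dataclass
class Artifact:
    hash_str: str = ""
    file_path: str = ""


def compute_hash(payload: str) -> str:
    return f"hash-of-{payload}"


def load_compiled(path: str) -> str:
    return f"compiled-from-{path}"


def make_hijacks(artifact: Artifact):

    def hijack_hash(payload: str) -> str:
        out = compute_hash(payload)
        artifact.hash_str = out      # record the hash as a side effect
        return out                   # behavior is otherwise unchanged

    def hijack_load(path: str) -> str:
        artifact.file_path = path    # record where the artifact came from
        return load_compiled(path)

    return hijack_hash, hijack_load


artifact = Artifact()
hijack_hash, hijack_load = make_hijacks(artifact)
hijack_hash("graph-0")
hijack_load("/tmp/inductor_output_0.py")
assert artifact == Artifact(hash_str="hash-of-graph-0",
                            file_path="/tmp/inductor_output_0.py")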
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             if not cache_data.disabled:
                 # compilation cache is enabled, patch several functions
 
+                # hijack to get the compiled graph itself
+                stack.enter_context(
+                    patch("torch._inductor.codecache.FxGraphCache.load",
+                          hijack_load))
+
                 # for hijacking the hash of the compiled graph
                 stack.enter_context(
                     patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             compiled_graph = compile_fx(graph,
                                         example_inputs,
                                         config_patches=current_config)
-
+            # store the inductor_artifact in the cache
+            cache_data[(runtime_shape, graph_index)] = inductor_artifact
+            if graph_index == 0:
+                # adds some info logging for the first graph
+                logger.info("Cache the graph of shape %s for later use",
+                            str(runtime_shape))
+            logger.debug(
+                "store the %s-th graph for shape %s via hash %s from file %s",
+                graph_index, str(runtime_shape), inductor_artifact.hash_str,
+                inductor_artifact.file_path)
     # after compiling the last graph, record the end time
     if graph_index == num_graphs - 1:
         now = time.time()

vllm/config.py

Lines changed: 2 additions & 11 deletions
@@ -2862,17 +2862,8 @@ def model_post_init(self, __context: Any) -> None:
                 "vllm.unified_attention_with_output",
             ]
         else:
-            # v0 can use full graph compilation without splitting,
-            # splitting is optional.
-            # right now we still need it. kv cache shape
-            # will be included in the graph if we don't split
-            # the graph.
-            # TODO: hide kv cache in static forward context
-            # so that inductor does not see it.
-            self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
-            ]
+            # v0 uses full graph compilation
+            self.splitting_ops = []
 
         for k, v in self.inductor_passes.items():
             if not isinstance(v, str):
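
A condensed restatement of the updated branch for reference; the use_v1 flag here is a hypothetical stand-in for the surrounding condition, which is not shown in this hunk:

from typing import List


def default_splitting_ops(use_v1: bool) -> List[str]:
    if use_v1:
        # piecewise compilation splits the graph at the attention ops
        return [
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
        ]
    # v0 uses full graph compilation, so nothing is split out
    return []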
