 logger = init_logger(__name__)
 
 
+@dataclasses.dataclass
+class InductorArtifact:
+    hash_str: str = ""
+    file_path: str = ""
+
+
 class InductorHashCache:
     """
     Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str)
+    (runtime_shape, graph_index, hash_str, file_path)
     We use list of tuple for readability.
 
     In-memory format: a defaultdict of dict, where the key is
     runtime_shape, and the value is a dict of graph_index to hash_str.
 
-    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
     we don't use json here because json doesn't support int as key.
 
     TODO: better off-the-shelf solution to serialize the data?
     """
 
     def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: defaultdict = defaultdict(dict)
+        self.cache: Dict[Optional[int],
+                         Dict[int, InductorArtifact]] = defaultdict(dict)
         self.disabled = disabled
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir,
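To make the in-memory layout described in the docstring concrete, here is a small standalone sketch; the hashes and the path are made up for illustration, and the dataclass simply mirrors the one introduced in this diff:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class InductorArtifact:  # same fields as the dataclass added above
    hash_str: str = ""
    file_path: str = ""


# runtime_shape (possibly None) -> graph_index -> InductorArtifact
cache: Dict[Optional[int], Dict[int, InductorArtifact]] = defaultdict(dict)
cache[None][0] = InductorArtifact(hash_str="fabc123",
                                  file_path="/tmp/inductor/output_code.py")
cache[16][0] = InductorArtifact(hash_str="f9d8e7")  # file_path filled in later

print(cache[16][0].hash_str)  # -> f9d8e7
```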
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
         # because it is a safe way to parse Python literals.
         # do not use eval(), it is unsafe.
         list_data = ast.literal_eval(data)
-        for runtime_shape, graph_index, hash_str in list_data:
-            self.cache[runtime_shape][graph_index] = hash_str
+        for item in list_data:
+            runtime_shape = item[0]
+            graph_index = item[1]
+            hash_str = item[2]
+            # for compatibility of old version,
+            # where we don't have file_path.
+            # NOTE: after running the new code, the file_path
+            # will be updated.
+            file_path = "" if len(item) == 3 else item[3]
+            self.cache[runtime_shape][graph_index] = InductorArtifact(
+                hash_str=hash_str, file_path=file_path)
 
     def serialize(self) -> str:
         data = []
-        for runtime_shape, graph_index_to_hash_str in self.cache.items():
-            for graph_index, hash_str in graph_index_to_hash_str.items():
-                data.append((runtime_shape, graph_index, hash_str))
+        for runtime_shape, value in self.cache.items():
+            for graph_index, inductor_artifact in value.items():
+                data.append(
+                    (runtime_shape, graph_index, inductor_artifact.hash_str,
+                     inductor_artifact.file_path))
         printer = pprint.PrettyPrinter(indent=4)
         return printer.pformat(data)
 
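A quick standalone illustration of the on-disk literal format and the compatibility branch above: `pformat` writes a Python list of tuples, `literal_eval` reads it back, and an old-style 3-tuple (without file_path) gets padded with an empty string. The hashes and the path below are made-up examples:

```python
import ast
import pprint

# one old-format entry (3-tuple) and one new-format entry (4-tuple)
text = pprint.PrettyPrinter(indent=4).pformat([
    (None, 0, "hash_a"),
    (16, 1, "hash_b", "/tmp/inductor/output_code.py"),
])

for item in ast.literal_eval(text):
    runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
    file_path = "" if len(item) == 3 else item[3]  # old entries lack file_path
    print(runtime_shape, graph_index, hash_str, repr(file_path))
```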
@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
         return runtime_shape in self.cache and graph_index in self.cache[
             runtime_shape]
 
-    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
         if self.disabled:
             raise KeyError("cannot read from disabled cache")
         runtime_shape, graph_index = key
         return self.cache[runtime_shape][graph_index]
 
-    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+    def __setitem__(self, key: Tuple[Optional[int], int],
+                    value: InductorArtifact):
         # setitem for disabled cache is fine, because we
         # don't actually write to the disk
         runtime_shape, graph_index = key
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
-        hash_str = cache_data[(runtime_shape, graph_index)]
+        inductor_artifact = cache_data[(runtime_shape, graph_index)]
+        hash_str = inductor_artifact.hash_str
         if graph_index == 0:
             # adds some info logging for the first graph
             logger.info(
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
                 "Inductor cache lookup failed. Please remove"
                 f"the cache file {cache_data.cache_file_path} and try again."  # noqa
             )
+            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
 
         # Inductor calling convention (function signature):
         # f(list) -> tuple
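The added line relies on a standard CPython detail: a function's `__code__.co_filename` records the file its code was compiled from, which is how the path of the generated module can be recovered from the compiled callable. A minimal standalone illustration, where a temp module stands in for the generated file:

```python
import importlib.util
import tempfile

# write a tiny module to disk and import it, mimicking a generated file
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("def call():\n    return 42\n")
    path = f.name

spec = importlib.util.spec_from_file_location("generated_module", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# the loaded function remembers which file it came from
assert module.call.__code__.co_filename == path
```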
@@ -224,19 +245,20 @@ def compiled_graph(*args):
         # the assumption is that we don't have nested Inductor compilation.
         # compiled_fx_graph_hash will only be called once, and we can hook
         # it to get the hash of the compiled graph directly.
-        from torch._inductor.codecache import compiled_fx_graph_hash
+
+        inductor_artifact = InductorArtifact()
+        from torch._inductor.codecache import (FxGraphCache,
+                                               compiled_fx_graph_hash)
+        original_load = FxGraphCache.load
+
+        def hijack_load(*args, **kwargs):
+            inductor_compiled_graph = original_load(*args, **kwargs)
+            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
+            return inductor_compiled_graph
 
         def hijack_compiled_fx_graph_hash(*args, **kwargs):
             out = compiled_fx_graph_hash(*args, **kwargs)
-            # store the hash in the cache
-            nonlocal cache_data
-            cache_data[(runtime_shape, graph_index)] = out[0]
-            if graph_index == 0:
-                # adds some info logging for the first graph
-                logger.info("Cache the graph of shape %s for later use",
-                            str(runtime_shape))
-            logger.debug("store the %s-th graph for shape %s via hash %s",
-                         graph_index, str(runtime_shape), out[0])
+            inductor_artifact.hash_str = out[0]
             return out
 
         def _check_can_cache(*args, **kwargs):
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             if not cache_data.disabled:
                 # compilation cache is enabled, patch several functions
 
+                # hijack to get the compiled graph itself
+                stack.enter_context(
+                    patch("torch._inductor.codecache.FxGraphCache.load",
+                          hijack_load))
+
                 # for hijacking the hash of the compiled graph
                 stack.enter_context(
                     patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             compiled_graph = compile_fx(graph,
                                         example_inputs,
                                         config_patches=current_config)
-
+            # store the inductor_artifact in the cache
+            cache_data[(runtime_shape, graph_index)] = inductor_artifact
+            if graph_index == 0:
+                # adds some info logging for the first graph
+                logger.info("Cache the graph of shape %s for later use",
+                            str(runtime_shape))
+            logger.debug(
+                "store the %s-th graph for shape %s via hash %s from file %s",
+                graph_index, str(runtime_shape), inductor_artifact.hash_str,
+                inductor_artifact.file_path)
     # after compiling the last graph, record the end time
     if graph_index == num_graphs - 1:
         now = time.time()