@@ -251,15 +251,27 @@ def _check_can_cache(*args, **kwargs):
251
251
def _get_shape_env () -> AlwaysHitShapeEnv :
252
252
return AlwaysHitShapeEnv ()
253
253
254
- with patch (# for hijacking the hash of the compiled graph
255
- "torch._inductor.codecache.compiled_fx_graph_hash" ,
256
- hijack_compiled_fx_graph_hash ), \
257
- patch (# for providing a dummy shape environment
258
- "torch._inductor.codecache.FxGraphCache._get_shape_env" ,
259
- _get_shape_env ), \
260
- patch (# for forcing the graph to be cached
261
- "torch._inductor.codecache.FxGraphCache._check_can_cache" ,
262
- _check_can_cache ):
254
+ with ExitStack () as stack :
255
+ if not cache_data .disabled :
256
+ # compilation cache is enabled, patch several functions
257
+
258
+ # for hijacking the hash of the compiled graph
259
+ stack .enter_context (
260
+ patch ("torch._inductor.codecache.compiled_fx_graph_hash" ,
261
+ hijack_compiled_fx_graph_hash ))
262
+
263
+ # for providing a dummy shape environment
264
+ stack .enter_context (
265
+ patch (
266
+ "torch._inductor.codecache.FxGraphCache._get_shape_env" ,
267
+ _get_shape_env ))
268
+
269
+ # for forcing the graph to be cached
270
+ stack .enter_context (
271
+ patch (
272
+ "torch._inductor.codecache.FxGraphCache._check_can_cache" ,
273
+ _check_can_cache ))
274
+
263
275
compiled_graph = compile_fx (graph ,
264
276
example_inputs ,
265
277
config_patches = current_config )
0 commit comments