@@ -524,6 +524,7 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
@@ -532,7 +533,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
             # 1. factors come from the vllm_config (it mainly summarizes how the
             # model is created)
-            vllm_config = self.vllm_config
             config_hash = vllm_config.compute_hash()
 
             # 2. factors come from the code files that are traced by Dynamo (
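
For context, `compute_hash()` turns the config "factors" into a stable string. A minimal sketch of that general technique, assuming nothing about vLLM's actual implementation (the function name, fields, and md5 choice here are illustrative only):

```python
# Illustrative sketch only -- not vLLM's compute_hash(). It shows the idea:
# serialize the factors that affect compilation into a canonical string,
# then hash it, so unchanged factors always yield the same cache key.
import hashlib

def compute_config_hash(factors: dict) -> str:
    # sort items so dict ordering cannot change the hash
    canonical = repr(sorted(factors.items()))
    return hashlib.md5(canonical.encode()).hexdigest()

print(compute_config_hash({"model": "llama", "dtype": "bf16", "tp": 2}))
```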
@@ -556,20 +556,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             hash_key = hashlib.md5(
                 f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-        else:
-            cache_dir = self.compilation_config.cache_dir
+                envs.VLLM_CACHE_ROOT,
+                "torch_compile_cache",
+                hash_key,
+            )
+            self.compilation_config.cache_dir = cache_dir
+
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        local_cache_dir = os.path.join(
+            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        self.compilation_config.local_cache_dir = local_cache_dir
 
         disabled = envs.VLLM_DISABLE_COMPILE_CACHE
         self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            cache_dir, disabled=disabled)
+            local_cache_dir, disabled=disabled)
         if disabled:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
-                        cache_dir)
+                        local_cache_dir)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
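
The net effect of this hunk: the rank suffix moves out of `cache_dir` (now shared across ranks and keyed only by the hash) and into a new per-rank `local_cache_dir` below it. A standalone sketch of the resulting layout, with made-up values for the cache root, hashes, and rank:

```python
# Sketch of the directory scheme this hunk introduces; all concrete values
# below (cache root, hashes, rank) are assumptions for illustration.
import hashlib
import os

VLLM_CACHE_ROOT = os.path.expanduser("~/.cache/vllm")  # assumed default
config_hash, code_hash, rank = "abc123", "def456", 0

hash_key = hashlib.md5(f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]

# shared across all ranks of the same compiled model
cache_dir = os.path.join(VLLM_CACHE_ROOT, "torch_compile_cache", hash_key)
# per-rank artifacts (e.g. the inductor hash cache) live one level below
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}")
print(local_cache_dir)  # ~/.cache/vllm/torch_compile_cache/<key>/rank_0
```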
@@ -609,6 +615,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
                 self.vllm_config, self.graph_pool,
                 self).run(*example_inputs)
 
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30  # noqa
+            # use `print_readable` because it can include submodules
+            src = "from __future__ import annotations\nimport torch\n" + \
+                self.split_gm.print_readable(print_output=False)
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+
+        logger.debug("Computation graph saved to %s", graph_path)
+
         self._called = True
 
         if not self.compilation_config.use_cudagraph or \
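
A minimal, self-contained demo of the dump technique used in the last hunk: `GraphModule.print_readable(print_output=False)` returns the module's generated code (including submodules) as a string instead of printing it. The toy module below stands in for `self.split_gm`; the `<lambda>` replacement from the diff handles the class name Dynamo gives split graph modules and is not needed here.

```python
# Trace a toy module and write its readable FX code to a file,
# mirroring the computation_graph.py dump added in this PR.
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(M())
src = ("from __future__ import annotations\nimport torch\n"
       + gm.print_readable(print_output=False))
with open("computation_graph.py", "w") as f:
    f.write(src)
```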