@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
+                  vllm_backend: "VllmBackend",
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
     # see https://github.com/pytorch/pytorch/issues/138980
     graph = copy.deepcopy(graph)
 
-    cache_data = compilation_config.inductor_hash_cache
+    cache_data = vllm_backend.inductor_hash_cache
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
                 hash_str, example_inputs, True, False)
             assert inductor_compiled_graph is not None, (
                 "Inductor cache lookup failed. Please remove"
-                f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
+                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
             )
 
     # Inductor calling convention (function signature):
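For reference, here is a minimal, self-contained sketch of the lookup pattern the hunks above rely on: a cache keyed by `(runtime_shape, graph_index)` that maps to an Inductor artifact hash. The `KeyedHashCache` class and the fake hash strings below are illustrative stand-ins, not vLLM's `InductorHashCache`.

```python
# Hypothetical sketch of a (runtime_shape, graph_index)-keyed hash cache.
from typing import Dict, Optional, Tuple


class KeyedHashCache:
    """Stand-in for the hash cache owned by the backend (not vLLM's class)."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[Optional[int], int], str] = {}
        self.cache_file_path = "<in-memory>"  # placeholder, no real file here

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        return key in self._entries

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        return self._entries[key]

    def __setitem__(self, key: Tuple[Optional[int], int], value: str) -> None:
        self._entries[key] = value


def lookup_or_compile(cache: KeyedHashCache, runtime_shape: Optional[int],
                      graph_index: int) -> str:
    # Mirrors the control flow above: cache hit -> reuse the stored hash,
    # cache miss -> "compile" and record the new hash.
    if (runtime_shape, graph_index) in cache:
        return cache[(runtime_shape, graph_index)]
    hash_str = f"compiled-{runtime_shape}-{graph_index}"  # fake artifact hash
    cache[(runtime_shape, graph_index)] = hash_str
    return hash_str


if __name__ == "__main__":
    cache = KeyedHashCache()
    print(lookup_or_compile(cache, None, 0))  # miss: "compiles" and stores
    print(lookup_or_compile(cache, None, 0))  # hit: reuses the stored hash
```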
@@ -354,14 +355,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
 
     def __init__(self, module: torch.fx.GraphModule,
                  compile_submod_names: List[str], vllm_config: VllmConfig,
-                 graph_pool):
+                 graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
         self.compilation_config = vllm_config.compilation_config
         self.graph_pool = graph_pool
         self.vllm_config = vllm_config
+        self.vllm_backend = vllm_backend
 
     def run(self, *args):
         fake_args = [
@@ -389,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target,
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
@@ -397,7 +400,7 @@ def call_module(self, target: torch.fx.node.Target,
             self.module.__dict__[target] = PiecewiseBackend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
-                compiled_graph_for_general_shape)
+                compiled_graph_for_general_shape, self.vllm_backend)
 
             compilation_counter.num_piecewise_capturable_graphs_seen += 1
 
@@ -430,6 +433,7 @@ class VllmBackend:
     post_grad_passes: Sequence[Callable]
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]
+    inductor_hash_cache: InductorHashCache
 
     def __init__(
         self,
@@ -472,6 +476,53 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        if not self.compilation_config.cache_dir:
+            # no provided cache dir, generate one based on the known factors
+            # that affects the compilation. if none of the factors change,
+            # the cache dir will be the same so that we can reuse the compiled
+            # graph.
+
+            # 1. factors come from the vllm_config (it mainly summarizes how the
+            #    model is created)
+            vllm_config = self.vllm_config
+            config_hash = vllm_config.compute_hash()
+
+            # 2. factors come from the code files that are traced by Dynamo (
+            #    it mainly summarizes how the model is used in forward pass)
+            forward_code_files = list(
+                sorted(self.compilation_config.traced_files))
+            self.compilation_config.traced_files.clear()
+            logger.debug(
+                "Traced files (to be considered for compilation cache):\n%s",
+                "\n".join(forward_code_files))
+            hash_content = []
+            for filepath in forward_code_files:
+                hash_content.append(filepath)
+                with open(filepath) as f:
+                    hash_content.append(f.read())
+            import hashlib
+            code_hash = hashlib.md5(
+                "\n".join(hash_content).encode()).hexdigest()
+
+            # combine the two hashes to generate the cache dir
+            hash_key = hashlib.md5(
+                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
+            cache_dir = os.path.join(
+                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
+                f"rank_{vllm_config.parallel_config.rank}")
+        else:
+            cache_dir = self.compilation_config.cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
+        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
+            cache_dir, disabled=disabled)
+        if disabled:
+            logger.info("vLLM's torch.compile cache is disabled.")
+        else:
+            logger.info("Using cache directory: %s for vLLM's torch.compile",
+                        cache_dir)
+
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
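The hunk above derives the cache directory from two hashes: one over the vLLM config and one over the source files Dynamo traced, so any change to either invalidates the old directory. Below is a standalone sketch of that derivation under the same scheme; the `compute_cache_dir` helper, the fake config hash, and the temporary paths are illustrative, not part of vLLM's API.

```python
# Sketch of the cache-key derivation: hash the traced source files, combine
# with a config hash, and build a short per-rank directory key from it.
import hashlib
import os
import tempfile
from typing import Iterable


def compute_cache_dir(config_hash: str, traced_files: Iterable[str],
                      cache_root: str, rank: int) -> str:
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        with open(filepath) as f:
            hash_content.append(f.read())
    code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()
    hash_key = hashlib.md5(
        f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
    return os.path.join(cache_root, "torch_compile_cache", hash_key,
                        f"rank_{rank}")


if __name__ == "__main__":
    # Editing a traced file (or changing the config hash) changes the key,
    # so stale compiled graphs are never reused from the old directory.
    with tempfile.TemporaryDirectory() as root:
        src = os.path.join(root, "model.py")
        with open(src, "w") as f:
            f.write("def forward(x): return x + 1\n")
        print(compute_cache_dir("cfg123", [src], root, rank=0))
        with open(src, "a") as f:
            f.write("# changed\n")
        print(compute_cache_dir("cfg123", [src], root, rank=0))  # differs
```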
@@ -507,8 +558,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.vllm_config,
-                                    self.graph_pool).run(*example_inputs)
+                                    self.vllm_config, self.graph_pool,
+                                    self).run(*example_inputs)
 
         self._called = True
 
@@ -577,7 +628,8 @@ class PiecewiseBackend:
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
                  total_piecewise_compiles: int, sym_shape_indices: List[int],
-                 compiled_graph_for_general_shape: Callable):
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
@@ -597,6 +649,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         self.graph_pool = graph_pool
         self.piecewise_compile_index = piecewise_compile_index
         self.total_piecewise_compiles = total_piecewise_compiles
+        self.vllm_backend = vllm_backend
 
         self.is_first_graph = piecewise_compile_index == 0
         self.is_last_graph = (
@@ -634,7 +687,7 @@ def check_for_ending_compilation(self):
         if self.is_last_graph and not self.to_be_compiled_sizes:
             # no specific sizes to compile
             # save the hash of the inductor graph for the next run
-            self.compilation_config.inductor_hash_cache.save_to_file()
+            self.vllm_backend.inductor_hash_cache.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)
 
     def __call__(self, *args) -> Any:
@@ -662,6 +715,7 @@ def __call__(self, *args) -> Any:
                 args,
                 self.compilation_config.inductor_compile_config,
                 self.compilation_config,
+                self.vllm_backend,
                 graph_index=self.piecewise_compile_index,
                 num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
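Taken together, the diff moves ownership of the hash cache from the compilation config to the backend object and threads the backend through the compile helpers. A compact sketch of that ownership pattern follows; the `DiskBackedCache` and `Backend` names, the JSON file format, and the placeholder artifact hashes are assumptions made for illustration, not vLLM's actual classes.

```python
# Sketch: the backend owns a disk-backed hash cache and passes itself (or the
# cache) into compile helpers; save_to_file() runs once compilation finishes.
import json
import os
from typing import Dict, Optional


class DiskBackedCache:
    def __init__(self, cache_dir: str, disabled: bool = False) -> None:
        self.disabled = disabled
        self.cache_file_path = os.path.join(cache_dir, "hash_cache.json")
        self._entries: Dict[str, str] = {}
        if not disabled and os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self._entries = json.load(f)

    def _key(self, runtime_shape: Optional[int], graph_index: int) -> str:
        return f"{runtime_shape}:{graph_index}"

    def get(self, runtime_shape: Optional[int],
            graph_index: int) -> Optional[str]:
        return self._entries.get(self._key(runtime_shape, graph_index))

    def put(self, runtime_shape: Optional[int], graph_index: int,
            hash_str: str) -> None:
        self._entries[self._key(runtime_shape, graph_index)] = hash_str

    def save_to_file(self) -> None:
        # Persist the mapping so the next run can skip recompilation.
        if self.disabled:
            return
        os.makedirs(os.path.dirname(self.cache_file_path), exist_ok=True)
        with open(self.cache_file_path, "w") as f:
            json.dump(self._entries, f)


class Backend:
    """Owns the cache; compile helpers receive the backend, not the config."""

    def __init__(self, cache_dir: str) -> None:
        self.hash_cache = DiskBackedCache(cache_dir)

    def compile_piece(self, runtime_shape: Optional[int],
                      graph_index: int) -> str:
        cached = self.hash_cache.get(runtime_shape, graph_index)
        if cached is not None:
            return cached
        hash_str = f"artifact-{runtime_shape}-{graph_index}"  # placeholder
        self.hash_cache.put(runtime_shape, graph_index, hash_str)
        return hash_str
```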