Commit 3682e33

[v1] fix compilation cache (#11598)
Signed-off-by: youkaichao <[email protected]>
1 parent 0aa38d1 commit 3682e33

4 files changed: +69 −14 lines changed

tests/compile/piecewise/test_toy_llama.py

Lines changed: 13 additions & 2 deletions
@@ -7,7 +7,7 @@
 initialized randomly with a fixed seed.
 """
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -54,6 +54,16 @@ class LlamaConfig:
     tractable_init: bool = False
     random_seed: int = 0
 
+    def compute_hash(self) -> str:
+        factors: List[Any] = []
+        for k, v in self.__dict__.items():
+            if k == "random_seed":
+                continue
+            factors.append((k, v))
+        factors.sort()
+        import hashlib
+        return hashlib.md5(str(factors).encode()).hexdigest()
+
     def __post_init__(self):
         assert self.mlp_size >= self.hidden_size
 
@@ -263,7 +273,8 @@ def run_model(llama_config,
     compilation_config = CompilationConfig(
         level=CompilationLevel.NO_COMPILATION, )
 
-    vllm_config = VllmConfig(compilation_config=compilation_config)
+    vllm_config = VllmConfig(compilation_config=compilation_config,
+                             additional_config=llama_config)
     with set_current_vllm_config(vllm_config):
         model = LlamaModel(config=llama_config,
                            vllm_config=vllm_config,
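
The new compute_hash deliberately skips random_seed, so two toy configs that differ only in their seed map to the same cache key, while any architectural change produces a different one. A minimal standalone sketch of the same pattern, using a hypothetical SimpleConfig rather than the real LlamaConfig:

    import hashlib
    from dataclasses import dataclass
    from typing import Any, List


    @dataclass
    class SimpleConfig:
        # hypothetical stand-in for LlamaConfig
        hidden_size: int = 128
        mlp_size: int = 256
        random_seed: int = 0

        def compute_hash(self) -> str:
            # hash every field except the seed, which should not
            # invalidate the compilation cache
            factors: List[Any] = []
            for k, v in self.__dict__.items():
                if k == "random_seed":
                    continue
                factors.append((k, v))
            factors.sort()
            return hashlib.md5(str(factors).encode()).hexdigest()


    # same architecture, different seed -> same hash
    assert SimpleConfig(random_seed=0).compute_hash() == \
        SimpleConfig(random_seed=42).compute_hash()
    # different architecture -> different hash
    assert SimpleConfig().compute_hash() != SimpleConfig(mlp_size=512).compute_hash()

Passing the config as additional_config, as run_model now does, folds this per-test hash into VllmConfig.compute_hash, so each toy architecture gets its own cache entry.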

vllm/compilation/backends.py

Lines changed: 13 additions & 9 deletions
@@ -619,21 +619,28 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
-            self.capture_sizes)
+
+        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # and updates during the compilation process, so we need to copy it
+        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
                 use_cudagraph=shape in self.capture_sizes,
             )
 
+    def check_for_ending_compilation(self):
+        if self.is_last_graph and not self.to_be_compiled_sizes:
+            # no specific sizes to compile
+            # save the hash of the inductor graph for the next run
+            self.compilation_config.inductor_hash_cache.save_to_file()
+            end_monitoring_torch_compile(self.vllm_config)
+
     def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
-            # no specific sizes to compile
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.vllm_config)
+            self.check_for_ending_compilation()
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -662,10 +669,7 @@ def __call__(self, *args) -> Any:
 
         # finished compilations for all required shapes
         if self.is_last_graph and not self.to_be_compiled_sizes:
-
-            # save the hash of the inductor graph for the next run
-            self.compilation_config.inductor_hash_cache.save_to_file()
-            end_monitoring_torch_compile(self.vllm_config)
+            self.check_for_ending_compilation()
 
         if not entry.use_cudagraph:
             return entry.runnable(*args)
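
The cache fix is easiest to see in the set bookkeeping: when to_be_compiled_sizes starts as the union of compile and capture sizes, any cudagraph-only size can never be marked as compiled, so the set never empties and inductor_hash_cache.save_to_file() is never reached. A toy illustration of the corrected logic, using made-up sizes rather than the real vLLM objects:

    # hypothetical example sizes
    compile_sizes = {1, 2, 4}
    capture_sizes = {1, 2, 4, 8, 16}  # 8 and 16 are cudagraph-only

    # buggy: 8 and 16 are waited on but never get compiled
    to_be_compiled_buggy = compile_sizes.union(capture_sizes)

    # fixed: only track sizes that actually need compilation
    to_be_compiled_fixed = compile_sizes.copy()

    for shape in compile_sizes:  # pretend every compile size finished
        to_be_compiled_buggy.discard(shape)
        to_be_compiled_fixed.discard(shape)

    assert not to_be_compiled_fixed           # empty -> cache gets saved
    assert to_be_compiled_buggy == {8, 16}    # never empties -> cache never saved

Factoring the save into check_for_ending_compilation also means the cache file is written on the first run when there are no specific sizes to compile, not only after the per-size compilation loop.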

vllm/config.py

Lines changed: 42 additions & 3 deletions
@@ -9,8 +9,8 @@
 from dataclasses import dataclass, field, replace
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
-                    Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
-                    Union)
+                    Final, List, Literal, Mapping, Optional, Protocol, Set,
+                    Tuple, Type, Union)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -75,6 +75,12 @@
                                     PretrainedConfig]]
 
 
+class SupportsHash(Protocol):
+
+    def compute_hash(self) -> str:
+        ...
+
+
 class ModelConfig:
     """Configuration for the model.
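
SupportsHash is a structural type: anything with a compute_hash() -> str method satisfies it, no inheritance required, which is what lets the toy LlamaConfig in the test double as an additional_config. A small sketch of how the protocol behaves, with DummyConfig as a made-up example:

    from typing import Protocol


    class SupportsHash(Protocol):

        def compute_hash(self) -> str:
            ...


    class DummyConfig:
        """Satisfies SupportsHash purely by shape; it does not inherit from it."""

        def compute_hash(self) -> str:
            return "0123456789abcdef"


    def describe(cfg: SupportsHash) -> str:
        return cfg.compute_hash()


    print(describe(DummyConfig()))  # accepted by static type checkers as well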
@@ -2969,6 +2975,10 @@ class VllmConfig:
                                               init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    # some opaque config, only used to provide additional information
+    # for the hash computation, mainly used for testing and debugging.
+    additional_config: SupportsHash = field(default=None,
+                                            init=True)  # type: ignore
     instance_id: str = ""
 
     def compute_hash(self) -> str:
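
Because additional_config only has to satisfy SupportsHash, a caller can thread arbitrary extra cache-key information through it. A hedged sketch of that usage; ExtraFactors is hypothetical, not part of this change:

    import hashlib
    from dataclasses import dataclass


    @dataclass
    class ExtraFactors:
        """Anything that should invalidate the compilation cache when it changes."""
        kernel_backend: str = "triton"

        def compute_hash(self) -> str:
            return hashlib.md5(self.kernel_backend.encode()).hexdigest()


    # vllm_config = VllmConfig(compilation_config=..., additional_config=ExtraFactors())
    # vllm_config.compute_hash() then changes whenever kernel_backend changes.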
@@ -3000,33 +3010,62 @@ def compute_hash(self) -> str:
         vllm_factors.append(__version__)
         if self.model_config:
             vllm_factors.append(self.model_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.cache_config:
             vllm_factors.append(self.cache_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.parallel_config:
             vllm_factors.append(self.parallel_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.scheduler_config:
             vllm_factors.append(self.scheduler_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.device_config:
             vllm_factors.append(self.device_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.load_config:
             vllm_factors.append(self.load_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.speculative_config:
             vllm_factors.append(self.speculative_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.decoding_config:
             vllm_factors.append(self.decoding_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.observability_config:
             vllm_factors.append(self.observability_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.prompt_adapter_config:
             vllm_factors.append(self.prompt_adapter_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
             vllm_factors.append(self.compilation_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.kv_transfer_config:
             vllm_factors.append(self.kv_transfer_config.compute_hash())
-
+        else:
+            vllm_factors.append("None")
+        if self.additional_config:
+            vllm_factors.append(self.additional_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         factors.append(vllm_factors)
 
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
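
The "None" placeholders keep the factor list positional: every config contributes exactly one entry whether or not it is set, so two VllmConfig instances with different combinations of missing configs cannot collapse onto the same factor string. A tiny illustration of why skipping absent configs would be ambiguous (the literal "a" stands in for some sub-config hash):

    # with placeholders, position encodes which config produced the hash
    present_model_only = ["a", "None"]   # model_config hashed, cache_config absent
    present_cache_only = ["None", "a"]   # model_config absent, cache_config hashed
    assert present_model_only != present_cache_only

    # without placeholders, both situations would flatten to the same list
    assert ["a"] == ["a"]  # indistinguishable -> identical cache key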

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def __init__(
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
+        self.parallel_config.rank = rank
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
