
Commit 1f664ef

youkaichao authored and tjtanaa committed

[torch.compile] decouple compile sizes and cudagraph sizes (vllm-project#12243)
Signed-off-by: youkaichao <[email protected]>
1 parent f10e75d commit 1f664ef

File tree

7 files changed: +95 −58 lines changed


vllm/compilation/backends.py
Lines changed: 5 additions & 5 deletions

@@ -680,7 +680,7 @@ def copy_and_call(*args):
 class ConcreteSizeEntry:
     runtime_shape: int
     need_to_compile: bool  # the size is in compile_sizes
-    use_cudagraph: bool  # the size is in capture_sizes
+    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

     compiled: bool = False
     runnable: Callable = None  # type: ignore
@@ -727,8 +727,8 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,

         self.compile_sizes: Set[int] = set(
             self.compilation_config.compile_sizes)
-        self.capture_sizes: Set[int] = set(
-            self.compilation_config.capture_sizes
+        self.cudagraph_capture_sizes: Set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()

         self.first_run_finished = False
@@ -746,11 +746,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
         self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
-        for shape in self.compile_sizes.union(self.capture_sizes):
+        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
-                use_cudagraph=shape in self.capture_sizes,
+                use_cudagraph=shape in self.cudagraph_capture_sizes,
             )

     def check_for_ending_compilation(self):
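
The net effect of these hunks is that compile sizes and cudagraph capture sizes are tracked as two independent sets: a runtime shape may need inductor compilation, cudagraph capture, both, or neither. A minimal standalone sketch of that bookkeeping, with hypothetical sizes (not the actual backend class):

from dataclasses import dataclass


@dataclass
class Entry:
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool    # the size is in cudagraph_capture_sizes


# hypothetical sizes: compile the chunked-prefill token budget,
# capture cudagraphs only for small decode batches
compile_sizes = {8192}
cudagraph_capture_sizes = {1, 2, 4, 8}

entries = {
    shape: Entry(
        runtime_shape=shape,
        need_to_compile=shape in compile_sizes,
        use_cudagraph=shape in cudagraph_capture_sizes,
    )
    for shape in compile_sizes | cudagraph_capture_sizes
}
# 8192 is compiled but never captured; 1/2/4/8 are captured but not compiled.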

vllm/config.py
Lines changed: 37 additions & 34 deletions

@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
     - use_inductor: whether to use inductor compilation.
         - False: inductor compilation is not used. graph runs in eager.
         - True: inductor compilation is used. one graph for symbolic shape
-            is compiled. In addition, compile for cudagraph sizes that are
-            in candidate_compile_sizes, using configurations
-            in inductor_compile_config.
-    - candidate_compile_sizes: sizes to compile for inductor.
+            is compiled. In addition, compile for compile_sizes,
+            using configurations in inductor_compile_config.
+    - compile_sizes: sizes to compile for inductor. In addition
+        to integers, it also supports "cudagraph_capture_sizes" to
+        specify the sizes for cudagraph capture.
     - inductor_compile_config: additional configurations for inductor.
         - None: use default configurations.
     - inductor_passes: additional passes for inductor. It is a dictionary
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
     splitting_ops: List[str] = Field(default=None)  # type: ignore

     use_inductor: bool = True
-    candidate_compile_sizes: Optional[List[int]] = Field(default=None)
+    compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
     inductor_compile_config: Dict = Field(default_factory=dict)
     inductor_passes: Dict[str, str] = Field(default_factory=dict)

@@ -2790,8 +2791,6 @@ def model_post_init(self, __context: Any) -> None:
     pass_config: PassConfig = Field(default_factory=PassConfig)

     # not configurable, computed after init
-    compile_sizes: List[int] = PrivateAttr
-    capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
     local_cache_dir: str = PrivateAttr  # local cache dir for each rank
     # optimization:
@@ -2918,43 +2917,47 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)

-    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
+    def init_with_cudagraph_sizes(self,
+                                  cudagraph_capture_sizes: List[int]) -> None:
         """To complete the initialization of config,
         we need to know the cudagraph sizes."""

         if self.cudagraph_capture_sizes is None:
-            self.capture_sizes = sizes_to_specialize
+            self.cudagraph_capture_sizes = cudagraph_capture_sizes
         else:
-            self.capture_sizes = self.cudagraph_capture_sizes
+            # de-duplicate the sizes provided by the config
+            self.cudagraph_capture_sizes = list(
+                set(self.cudagraph_capture_sizes))
             logger.info(("cudagraph sizes specified by model runner"
                          " %s is overridden by config %s"),
-                        sizes_to_specialize, self.cudagraph_capture_sizes)
-
-        if self.candidate_compile_sizes is None:
-            self.candidate_compile_sizes = []
-        self.compile_sizes = [
-            x for x in self.candidate_compile_sizes if x in self.capture_sizes
-        ]
-        ignored_sizes = [
-            x for x in self.candidate_compile_sizes
-            if x not in self.capture_sizes
-        ]
-        if ignored_sizes:
-            logger.warning(("candidate_compile_sizes %s are ignored "
-                            "because they are not cudagraph capture sizes."),
-                           ignored_sizes)
+                        cudagraph_capture_sizes, self.cudagraph_capture_sizes)
+
+        computed_compile_sizes = []
+        if self.compile_sizes is not None:
+            # de-duplicate the sizes provided by the config
+            self.compile_sizes = list(set(self.compile_sizes))
+            for x in self.compile_sizes:
+                if isinstance(x, str):
+                    assert x == "cudagraph_capture_sizes", \
+                        "Unrecognized size type in compile_sizes, " \
+                        f"expect 'cudagraph_capture_sizes', got {x}"
+                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
+                else:
+                    assert isinstance(x, int)
+                    computed_compile_sizes.append(x)
+        self.compile_sizes = computed_compile_sizes  # type: ignore

         # sort to make sure cudagraph capture sizes are in descending order
-        self.capture_sizes.sort(reverse=True)
-        self.max_capture_size = self.capture_sizes[
-            0] if self.capture_sizes else 0
+        self.cudagraph_capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.cudagraph_capture_sizes[
+            0] if self.cudagraph_capture_sizes else 0

         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [
             0 for i in range(self.max_capture_size + 1)
         ]
-        for end, start in zip(self.capture_sizes,
-                              self.capture_sizes[1:] + [0]):
+        for end, start in zip(self.cudagraph_capture_sizes,
+                              self.cudagraph_capture_sizes[1:] + [0]):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
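
With this hunk, compile_sizes is resolved at init time: integers pass through, and the special string "cudagraph_capture_sizes" expands to the resolved capture sizes. A standalone sketch of that resolution with hypothetical inputs (not the config class itself):

# hypothetical user input and resolved capture sizes
compile_sizes = [8192, "cudagraph_capture_sizes", 8192]  # note the duplicate
cudagraph_capture_sizes = [8, 4, 2, 1]

resolved = []
for x in set(compile_sizes):  # de-duplicate first
    if isinstance(x, str):
        assert x == "cudagraph_capture_sizes", f"unrecognized size {x!r}"
        resolved.extend(cudagraph_capture_sizes)  # expand the alias
    else:
        assert isinstance(x, int)
        resolved.append(x)

print(sorted(resolved, reverse=True))  # [8192, 8, 4, 2, 1]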
@@ -3225,14 +3228,14 @@ def _set_cudagraph_sizes(self):
         However, if users specify the cudagraph capture sizes through
         compilation config, we will use the specified sizes instead.

-        In the end, `vllm_config.compilation_config.capture_sizes` will be the
-        final sizes to capture cudagraph (in descending order).
+        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
+        will be the final sizes to capture cudagraph (in descending order).

         During runtime, if batchsize is larger than
-        `vllm_config.compilation_config.capture_sizes`,
+        `vllm_config.compilation_config.cudagraph_capture_sizes`,
         no cudagraph will be used.
         If the batch size is no larger than
-        `vllm_config.compilation_config.capture_sizes`,
+        `vllm_config.compilation_config.cudagraph_capture_sizes`,
         we can quickly find the padded graph size for a given batch size by
         looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
         """

vllm/engine/metrics.py
Lines changed: 2 additions & 1 deletion

@@ -120,7 +120,8 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
                                          labelnames=labelnames)
         buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
         if not vllm_config.model_config.enforce_eager:
-            buckets = vllm_config.compilation_config.capture_sizes.copy()
+            buckets = vllm_config.compilation_config.\
+                cudagraph_capture_sizes.copy()
             buckets.sort()
         self.histogram_iteration_tokens = self._histogram_cls(
             name="vllm:iteration_tokens_total",

vllm/v1/worker/gpu_model_runner.py
Lines changed: 10 additions & 8 deletions

@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast

 import numpy as np
 import torch
@@ -128,7 +128,8 @@ def __init__(
         # self.cudagraph_batch_sizes sorts in ascending order.
         # The batch sizes in the config are in descending order.
         self.cudagraph_batch_sizes = list(
-            reversed(self.vllm_config.compilation_config.capture_sizes))
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))

         # Cache the device properties.
         self.device_properties = torch.cuda.get_device_properties(self.device)
@@ -834,10 +835,12 @@ def load_model(self) -> None:
     @torch.inference_mode()
     def _dummy_run(
         self,
-        model: nn.Module,
         num_tokens: int,
-        kv_caches: List[torch.Tensor],
+        kv_caches: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
+        model = self.model
+        if kv_caches is None:
+            kv_caches = self.kv_caches
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -963,8 +966,7 @@ def profile_run(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                        dummy_kv_caches)
+        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
@@ -990,8 +992,8 @@ def capture_model(self) -> None:
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(self.model, num_tokens, self.kv_caches)
-                self._dummy_run(self.model, num_tokens, self.kv_caches)
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
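
The signature change lets callers that only care about the token count (cudagraph capture above, and the compile-size warmup in gpu_worker.py below) call _dummy_run(size) without threading the model and KV caches through. A minimal sketch of the default-to-own-state pattern, using a hypothetical Runner class rather than the vLLM one:

from typing import List, Optional


class Runner:
    def __init__(self) -> None:
        self.kv_caches: List[str] = ["layer0", "layer1"]  # placeholder state

    def _dummy_run(self, num_tokens: int,
                   kv_caches: Optional[List[str]] = None) -> str:
        if kv_caches is None:
            kv_caches = self.kv_caches  # fall back to the runner's own caches
        return f"ran {num_tokens} tokens against {len(kv_caches)} caches"


runner = Runner()
print(runner._dummy_run(8192, ["profiling-cache"]))  # profiling: explicit caches
print(runner._dummy_run(8))                          # warmup/capture: defaults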

vllm/v1/worker/gpu_worker.py
Lines changed: 12 additions & 0 deletions

@@ -206,6 +206,18 @@ def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.model_runner.initialize_kv_cache(kv_cache_config)

     def compile_or_warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
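
This is where the decoupling pays off for the v1 worker: compile sizes that are not cudagraph capture sizes (for example the max-num-batched-tokens size used by chunked prefill) are compiled and warmed up eagerly, while capture sizes are warmed during capture_model(). A standalone sketch of the warmup-list derivation with hypothetical sizes:

# hypothetical resolved config values
compile_sizes = [8192, 8, 4]            # e.g. user-provided compile sizes
cudagraph_capture_sizes = [8, 4, 2, 1]
enforce_eager = False

warmup_sizes = compile_sizes.copy()
if not enforce_eager:
    # sizes that will be captured as cudagraphs get warmed up during capture instead
    warmup_sizes = [s for s in warmup_sizes if s not in cudagraph_capture_sizes]

for size in sorted(warmup_sizes, reverse=True):
    print(f"compile and warm up model for size {size}")  # stands in for _dummy_run(size)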

vllm/worker/model_runner.py
Lines changed: 17 additions & 10 deletions

@@ -1256,13 +1256,19 @@ def set_in_profile_run(self):

     @torch.inference_mode()
     def profile_run(self) -> None:
+        max_num_batched_tokens = \
+            self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+        self._dummy_run(max_num_batched_tokens, max_num_seqs)
+
+    def _dummy_run(self,
+                   max_num_batched_tokens: int,
+                   max_num_seqs: int = 1) -> None:
         with self.set_in_profile_run():
             # Enable top-k sampling to reflect the accurate memory usage.
             sampling_params = \
                 SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-            max_num_batched_tokens = \
-                self.scheduler_config.max_num_batched_tokens
-            max_num_seqs = self.scheduler_config.max_num_seqs
+
             # This represents the maximum number of different requests
             # that will have unique loras, an therefore the max amount of memory
             # consumption create dummy lora request copies from the lora request
@@ -1491,13 +1497,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         for virtual_engine in range(
                 self.parallel_config.pipeline_parallel_size):
             # Only rank 0 should print progress bar during capture
-            capture_sizes = (
-                tqdm(
-                    self.vllm_config.compilation_config.capture_sizes,
-                    desc="Capturing CUDA graph shapes",
-                ) if get_tensor_model_parallel_rank() == 0 else
-                self.vllm_config.compilation_config.capture_sizes)
-            for batch_size in capture_sizes:
+            cudagraph_capture_sizes = (tqdm(
+                self.vllm_config.compilation_config.
+                cudagraph_capture_sizes,
+                desc="Capturing CUDA graph shapes",
+            ) if get_tensor_model_parallel_rank() == 0 else
+                self.vllm_config.compilation_config.
+                cudagraph_capture_sizes)
+            for batch_size in cudagraph_capture_sizes:
                 attn_metadata = (
                     self.attn_state.graph_capture_get_metadata_for_batch(
                         batch_size,

vllm/worker/worker.py
Lines changed: 12 additions & 0 deletions

@@ -323,6 +323,18 @@ def _init_cache_engine(self):
                                 self.gpu_cache)

     def _warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model(self.gpu_cache)
         # Reset the seed to ensure that the random state is not affected by
