Skip to content

Commit e1ad38d

Browse files
youkaichao and Isotr0py
authored and committed
[core] separate builder init and builder prepare for each batch (vllm-project#12253)
Signed-off-by: youkaichao <[email protected]> Signed-off-by: Isotr0py <[email protected]>
1 parent 7f4a37d commit e1ad38d

10 files changed

+90
-47
lines changed

vllm/attention/backends/abstract.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
6565
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
6666
raise NotImplementedError
6767

68-
@classmethod
69-
def make_metadata_builder(cls, *args,
70-
**kwargs) -> "AttentionMetadataBuilder":
71-
return cls.get_builder_cls()(*args, **kwargs)
72-
7368
@staticmethod
7469
@abstractmethod
7570
def get_kv_cache_shape(
@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
214209

215210
@abstractmethod
216211
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
212+
"""Create the builder, remember some configuration and parameters."""
213+
raise NotImplementedError
214+
215+
@abstractmethod
216+
def prepare(self) -> None:
217+
"""Prepare for one batch."""
217218
raise NotImplementedError
218219

219220
@abstractmethod

vllm/attention/backends/flash_attn.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
375375
AttentionMetadataBuilder[FlashAttentionMetadata]):
376376

377377
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
378+
self.input_builder = input_builder
379+
self.runner = input_builder.runner
380+
self.sliding_window = input_builder.sliding_window
381+
self.block_size = input_builder.block_size
382+
383+
def prepare(self):
378384
self.slot_mapping: List[int] = []
379385
self.prefill_seq_lens: List[int] = []
380386
self.context_lens: List[int] = []
@@ -388,11 +394,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
388394
self.num_decode_tokens = 0
389395
self.has_prefix_cache_hit = False
390396

391-
self.input_builder = input_builder
392-
self.runner = input_builder.runner
393-
self.sliding_window = input_builder.sliding_window
394-
self.block_size = input_builder.block_size
395-
396397
def _add_seq_group(
397398
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
398399
chunked_prefill_enabled: bool, prefix_cache_hit: bool):

vllm/attention/backends/flashinfer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,14 @@ def advance_step(self,
488488
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
489489

490490
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
491+
492+
self.input_builder = input_builder
493+
self.runner = input_builder.runner
494+
495+
self.sliding_window = input_builder.sliding_window
496+
self.block_size = input_builder.block_size
497+
498+
def prepare(self):
491499
self.slot_mapping: List[int] = []
492500
self.prefill_seq_lens: List[int] = []
493501
self.context_lens: List[int] = []
@@ -500,12 +508,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
500508
self.num_prefill_tokens = 0
501509
self.num_decode_tokens = 0
502510

503-
self.input_builder = input_builder
504-
self.runner = input_builder.runner
505-
506-
self.sliding_window = input_builder.sliding_window
507-
self.block_size = input_builder.block_size
508-
509511
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
510512
# for the precise definition of the following fields.
511513
# An example:

vllm/attention/backends/placeholder_attn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
253253
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
254254

255255
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
256+
257+
self.input_builder = input_builder
258+
self.runner = input_builder.runner
259+
260+
def prepare(self):
256261
self.prefill_seq_lens: List[int] = []
257262
self.context_lens: List[int] = []
258263
self.curr_seq_lens: List[int] = []
@@ -263,9 +268,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
263268
self.num_prefill_tokens = 0
264269
self.num_decode_tokens = 0
265270

266-
self.input_builder = input_builder
267-
self.runner = input_builder.runner
268-
269271
def _add_seq_group(
270272
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
271273
chunked_prefill_enabled: bool):

vllm/attention/backends/torch_sdpa.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
282282

283283
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
284284
self.chunked_prefill = input_builder.chunked_prefill
285-
self.input_data = input_builder.input_data
285+
self.input_builder = input_builder
286+
287+
def prepare(self):
288+
self.input_data = self.input_builder.input_data
286289

287290
def build(self, seq_lens: List[int], query_lens: List[int],
288291
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:

vllm/attention/backends/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
122122
_metadata_cls: Type[TAttentionMetadata]
123123

124124
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
125+
self.input_builder = input_builder
126+
self.runner = input_builder.runner
127+
128+
self.sliding_window = input_builder.sliding_window
129+
self.block_size = input_builder.block_size
130+
131+
def prepare(self):
125132
self.slot_mapping: List[int] = []
126133
self.prefill_seq_lens: List[int] = []
127134
self.context_lens: List[int] = []
@@ -134,12 +141,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
134141
self.num_prefill_tokens = 0
135142
self.num_decode_tokens = 0
136143

137-
self.input_builder = input_builder
138-
self.runner = input_builder.runner
139-
140-
self.sliding_window = input_builder.sliding_window
141-
self.block_size = input_builder.block_size
142-
143144
def _add_seq_group(
144145
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
145146
chunked_prefill_enabled: bool):

vllm/worker/cpu_model_runner.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ def __init__(self,
144144
runner: "CPUModelRunner",
145145
finished_requests_ids: Optional[List[str]] = None) -> None:
146146
super().__init__()
147-
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
148147
self.runner = runner
149-
150148
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
151149
or runner.cache_config.enable_prefix_caching)
152150
self.model_input_cls = self.runner._model_input_cls
@@ -156,10 +154,17 @@ def __init__(self,
156154
self.device = self.runner.device
157155
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
158156
self.enable_lora = self.runner.lora_config is not None
157+
if self.runner.attn_backend is not None:
158+
# spec decode (e.g. Medusa) does not have atten backend
159+
attn_backend = self.runner.attn_backend
160+
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
161+
162+
def prepare(self,
163+
finished_requests_ids: Optional[List[str]] = None) -> None:
164+
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
159165
self.input_data = ModelInputForCPUBuilder.ModelInputData(
160166
self.runner.model_config.uses_mrope)
161-
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
162-
self)
167+
self.att_metadata_builder.prepare()
163168

164169
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
165170
self.seq_group_metadata_list.append(seq_group_metadata)
@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
431436
"""
432437
_model_input_cls: Type[TModelInputForCPU]
433438
_builder_cls: Type[ModelInputForCPUBuilder]
439+
builder: ModelInputForCPUBuilder
434440

435441
def __init__(
436442
self,
@@ -477,6 +483,10 @@ def __init__(
477483
# Set after load_model.
478484
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
479485

486+
if hasattr(self, "_builder_cls"):
487+
# multi-step model runner does not have `_builder_cls`
488+
self.builder = self._builder_cls(weakref.proxy(self))
489+
480490
def load_model(self) -> None:
481491
self.model = get_model(vllm_config=self.vllm_config)
482492

@@ -522,10 +532,10 @@ def _prepare_model_input_tensors(
522532
metadata for possible additional steps, e.g., sampling.
523533
524534
"""
525-
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
526-
builder.set_seq_group_list(seq_group_metadata_list)
535+
self.builder.prepare(finished_requests_ids)
536+
self.builder.set_seq_group_list(seq_group_metadata_list)
527537

528-
return builder.build() # type: ignore
538+
return self.builder.build() # type: ignore
529539

530540
# sampler property will be used by spec_decode_worker
531541
@property

vllm/worker/model_runner.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -457,17 +457,13 @@ def __init__(self,
457457
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
458458
is not None)
459459
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
460-
self.finished_requests_ids = finished_requests_ids
461460
self.decode_only = True
462461

463-
# Intermediate data (data in CPU before going to GPU) for
464-
# the current sequence group.
465-
self.inter_data_list: List[
466-
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
467-
468462
# Attention metadata inputs.
469-
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
470-
weakref.proxy(self))
463+
if self.attn_backend is not None:
464+
# spec decode (e.g. Medusa) does not have atten backend
465+
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
466+
weakref.proxy(self))
471467

472468
# Engine/Model configurations.
473469
self.chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ def __init__(self,
479475
self.block_aligned_sliding_window = \
480476
self.sliding_window_blocks * self.block_size
481477

478+
def prepare(self,
479+
finished_requests_ids: Optional[List[str]] = None) -> None:
480+
self.finished_requests_ids = finished_requests_ids
481+
482+
# Intermediate data (data in CPU before going to GPU) for
483+
# the current sequence group.
484+
self.inter_data_list: List[
485+
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
486+
487+
self.attn_metadata_builder.prepare()
488+
482489
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
483490
seq_group_metadata: SequenceGroupMetadata):
484491
"""Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
9931000
"""
9941001
_model_input_cls: Type[TModelInputForGPU]
9951002
_builder_cls: Type[ModelInputForGPUBuilder]
1003+
builder: ModelInputForGPUBuilder
9961004

9971005
def __init__(
9981006
self,
@@ -1093,6 +1101,10 @@ def __init__(
10931101
SamplingMetadataCache() \
10941102
if self.parallel_config.pipeline_parallel_size == 1 else None
10951103

1104+
if hasattr(self, "_builder_cls"):
1105+
# multi-step model runner does not have `_builder_cls`
1106+
self.builder = self._builder_cls(weakref.proxy(self))
1107+
10961108
def load_model(self) -> None:
10971109
logger.info("Starting to load model %s...", self.model_config.model)
10981110
with DeviceMemoryProfiler() as m:
@@ -1226,13 +1238,13 @@ def _prepare_model_input_tensors(
12261238
12271239
If cuda graph is required, this API automatically pads inputs.
12281240
"""
1229-
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
1241+
self.builder.prepare(finished_requests_ids)
12301242
for seq_group_metadata in seq_group_metadata_list:
1231-
builder.add_seq_group(seq_group_metadata)
1243+
self.builder.add_seq_group(seq_group_metadata)
12321244

1233-
builder.reset_cached_inter_data()
1245+
self.builder.reset_cached_inter_data()
12341246

1235-
return builder.build() # type: ignore
1247+
return self.builder.build() # type: ignore
12361248

12371249
@contextmanager
12381250
def set_in_profile_run(self):

vllm/worker/model_runner_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
200200
"""A builder to create ModelRunnerInputBase objects.
201201
"""
202202

203+
@abstractmethod
204+
def prepare(self,
205+
finished_requests_ids: Optional[List[str]] = None) -> None:
206+
raise NotImplementedError
207+
203208
@abstractmethod
204209
def add_seq_group(self, seq_group_metadata):
205210
"""TBA"""

vllm/worker/xpu_model_runner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,17 @@ def __init__(self,
113113
runner: "XPUModelRunner",
114114
finished_requests_ids: Optional[List[str]] = None) -> None:
115115
super().__init__()
116-
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
117116
self.runner = runner
118117
self.model_input_cls = self.runner._model_input_cls
119118
self.attn_backend = self.runner.attn_backend
120119
self.sliding_window = self.runner.sliding_window
121120
self.block_size = self.runner.block_size
122121
self.device = self.runner.device
123122

123+
def prepare(self,
124+
finished_requests_ids: Optional[List[str]] = None) -> None:
125+
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
126+
124127
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
125128
self.seq_group_metadata_list.append(seq_group_metadata)
126129

@@ -408,6 +411,8 @@ def __init__(
408411
SamplingMetadataCache() \
409412
if self.parallel_config.pipeline_parallel_size == 1 else None
410413

414+
self.builder = self._builder_cls(weakref.proxy(self))
415+
411416
def load_model(self) -> None:
412417
with DeviceMemoryProfiler() as m:
413418
self.model = get_model(vllm_config=self.vllm_config)
@@ -517,7 +522,8 @@ def _prepare_model_input_tensors(
517522
metadata for possible additional steps, e.g., sampling.
518523
519524
"""
520-
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
525+
builder = self.builder
526+
builder.prepare(finished_requests_ids)
521527
for seq_group_metadata in seq_group_metadata_list:
522528
builder.add_seq_group(seq_group_metadata)
523529

0 commit comments

Comments (0)