Skip to content

Commit a369a47

Browse files
authored
Revert "[core] separate builder init and builder prepare for each batch (#12253)"
This reverts commit 66818e5.
Parent commit: 24b0205 · This commit: a369a47

10 files changed: 47 additions and 90 deletions

vllm/attention/backends/abstract.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
6565
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
6666
raise NotImplementedError
6767

68+
@classmethod
69+
def make_metadata_builder(cls, *args,
70+
**kwargs) -> "AttentionMetadataBuilder":
71+
return cls.get_builder_cls()(*args, **kwargs)
72+
6873
@staticmethod
6974
@abstractmethod
7075
def get_kv_cache_shape(
@@ -213,12 +218,6 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
213218

214219
@abstractmethod
215220
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
216-
"""Create the builder, remember some configuration and parameters."""
217-
raise NotImplementedError
218-
219-
@abstractmethod
220-
def prepare(self) -> None:
221-
"""Prepare for one batch."""
222221
raise NotImplementedError
223222

224223
@abstractmethod

vllm/attention/backends/flash_attn.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -380,12 +380,6 @@ class FlashAttentionMetadataBuilder(
380380
AttentionMetadataBuilder[FlashAttentionMetadata]):
381381

382382
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
383-
self.input_builder = input_builder
384-
self.runner = input_builder.runner
385-
self.sliding_window = input_builder.sliding_window
386-
self.block_size = input_builder.block_size
387-
388-
def prepare(self):
389383
self.slot_mapping: List[int] = []
390384
self.prefill_seq_lens: List[int] = []
391385
self.context_lens: List[int] = []
@@ -399,6 +393,11 @@ def prepare(self):
399393
self.num_decode_tokens = 0
400394
self.has_prefix_cache_hit = False
401395

396+
self.input_builder = input_builder
397+
self.runner = input_builder.runner
398+
self.sliding_window = input_builder.sliding_window
399+
self.block_size = input_builder.block_size
400+
402401
def _add_seq_group(
403402
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
404403
chunked_prefill_enabled: bool, prefix_cache_hit: bool):

vllm/attention/backends/flashinfer.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,6 @@ def advance_step(self,
489489
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
490490

491491
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
492-
493-
self.input_builder = input_builder
494-
self.runner = input_builder.runner
495-
496-
self.sliding_window = input_builder.sliding_window
497-
self.block_size = input_builder.block_size
498-
499-
def prepare(self):
500492
self.slot_mapping: List[int] = []
501493
self.prefill_seq_lens: List[int] = []
502494
self.context_lens: List[int] = []
@@ -509,6 +501,12 @@ def prepare(self):
509501
self.num_prefill_tokens = 0
510502
self.num_decode_tokens = 0
511503

504+
self.input_builder = input_builder
505+
self.runner = input_builder.runner
506+
507+
self.sliding_window = input_builder.sliding_window
508+
self.block_size = input_builder.block_size
509+
512510
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
513511
# for the precise definition of the following fields.
514512
# An example:

vllm/attention/backends/placeholder_attn.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,6 @@ class PlaceholderAttentionMetadataBuilder(
255255
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
256256

257257
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
258-
259-
self.input_builder = input_builder
260-
self.runner = input_builder.runner
261-
262-
def prepare(self):
263258
self.prefill_seq_lens: List[int] = []
264259
self.context_lens: List[int] = []
265260
self.curr_seq_lens: List[int] = []
@@ -270,6 +265,9 @@ def prepare(self):
270265
self.num_prefill_tokens = 0
271266
self.num_decode_tokens = 0
272267

268+
self.input_builder = input_builder
269+
self.runner = input_builder.runner
270+
273271
def _add_seq_group(
274272
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
275273
chunked_prefill_enabled: bool):

vllm/attention/backends/torch_sdpa.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
282282

283283
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
284284
self.chunked_prefill = input_builder.chunked_prefill
285-
self.input_builder = input_builder
286-
287-
def prepare(self):
288-
self.input_data = self.input_builder.input_data
285+
self.input_data = input_builder.input_data
289286

290287
def build(self, seq_lens: List[int], query_lens: List[int],
291288
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:

vllm/attention/backends/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
122122
_metadata_cls: Type[TAttentionMetadata]
123123

124124
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
125-
self.input_builder = input_builder
126-
self.runner = input_builder.runner
127-
128-
self.sliding_window = input_builder.sliding_window
129-
self.block_size = input_builder.block_size
130-
131-
def prepare(self):
132125
self.slot_mapping: List[int] = []
133126
self.prefill_seq_lens: List[int] = []
134127
self.context_lens: List[int] = []
@@ -141,6 +134,12 @@ def prepare(self):
141134
self.num_prefill_tokens = 0
142135
self.num_decode_tokens = 0
143136

137+
self.input_builder = input_builder
138+
self.runner = input_builder.runner
139+
140+
self.sliding_window = input_builder.sliding_window
141+
self.block_size = input_builder.block_size
142+
144143
def _add_seq_group(
145144
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
146145
chunked_prefill_enabled: bool):

vllm/worker/cpu_model_runner.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def __init__(self,
144144
runner: "CPUModelRunner",
145145
finished_requests_ids: Optional[List[str]] = None) -> None:
146146
super().__init__()
147+
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
147148
self.runner = runner
149+
148150
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
149151
or runner.cache_config.enable_prefix_caching)
150152
self.model_input_cls = self.runner._model_input_cls
@@ -154,17 +156,10 @@ def __init__(self,
154156
self.device = self.runner.device
155157
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
156158
self.enable_lora = self.runner.lora_config is not None
157-
if self.runner.attn_backend is not None:
158-
# spec decode (e.g. Medusa) does not have atten backend
159-
attn_backend = self.runner.attn_backend
160-
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
161-
162-
def prepare(self,
163-
finished_requests_ids: Optional[List[str]] = None) -> None:
164-
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
165159
self.input_data = ModelInputForCPUBuilder.ModelInputData(
166160
self.runner.model_config.uses_mrope)
167-
self.att_metadata_builder.prepare()
161+
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
162+
self)
168163

169164
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
170165
self.seq_group_metadata_list.append(seq_group_metadata)
@@ -436,7 +431,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
436431
"""
437432
_model_input_cls: Type[TModelInputForCPU]
438433
_builder_cls: Type[ModelInputForCPUBuilder]
439-
builder: ModelInputForCPUBuilder
440434

441435
def __init__(
442436
self,
@@ -483,10 +477,6 @@ def __init__(
483477
# Set after load_model.
484478
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
485479

486-
if hasattr(self, "_builder_cls"):
487-
# multi-step model runner does not have `_builder_cls`
488-
self.builder = self._builder_cls(weakref.proxy(self))
489-
490480
def load_model(self) -> None:
491481
self.model = get_model(vllm_config=self.vllm_config)
492482

@@ -532,10 +522,10 @@ def _prepare_model_input_tensors(
532522
metadata for possible additional steps, e.g., sampling.
533523
534524
"""
535-
self.builder.prepare(finished_requests_ids)
536-
self.builder.set_seq_group_list(seq_group_metadata_list)
525+
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
526+
builder.set_seq_group_list(seq_group_metadata_list)
537527

538-
return self.builder.build() # type: ignore
528+
return builder.build() # type: ignore
539529

540530
# sampler property will be used by spec_decode_worker
541531
@property

vllm/worker/model_runner.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -455,13 +455,17 @@ def __init__(self,
455455
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
456456
is not None)
457457
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
458+
self.finished_requests_ids = finished_requests_ids
458459
self.decode_only = True
459460

461+
# Intermediate data (data in CPU before going to GPU) for
462+
# the current sequence group.
463+
self.inter_data_list: List[
464+
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
465+
460466
# Attention metadata inputs.
461-
if self.attn_backend is not None:
462-
# spec decode (e.g. Medusa) does not have atten backend
463-
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
464-
weakref.proxy(self))
467+
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
468+
weakref.proxy(self))
465469

466470
# Engine/Model configurations.
467471
self.chunked_prefill_enabled = (
@@ -473,17 +477,6 @@ def __init__(self,
473477
self.block_aligned_sliding_window = \
474478
self.sliding_window_blocks * self.block_size
475479

476-
def prepare(self,
477-
finished_requests_ids: Optional[List[str]] = None) -> None:
478-
self.finished_requests_ids = finished_requests_ids
479-
480-
# Intermediate data (data in CPU before going to GPU) for
481-
# the current sequence group.
482-
self.inter_data_list: List[
483-
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
484-
485-
self.attn_metadata_builder.prepare()
486-
487480
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
488481
seq_group_metadata: SequenceGroupMetadata):
489482
"""Compute context length, sequence length and tokens
@@ -998,7 +991,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
998991
"""
999992
_model_input_cls: Type[TModelInputForGPU]
1000993
_builder_cls: Type[ModelInputForGPUBuilder]
1001-
builder: ModelInputForGPUBuilder
1002994

1003995
def __init__(
1004996
self,
@@ -1099,10 +1091,6 @@ def __init__(
10991091
SamplingMetadataCache() \
11001092
if self.parallel_config.pipeline_parallel_size == 1 else None
11011093

1102-
if hasattr(self, "_builder_cls"):
1103-
# multi-step model runner does not have `_builder_cls`
1104-
self.builder = self._builder_cls(weakref.proxy(self))
1105-
11061094
def load_model(self) -> None:
11071095
logger.info("Starting to load model %s...", self.model_config.model)
11081096
with DeviceMemoryProfiler() as m:
@@ -1208,13 +1196,13 @@ def _prepare_model_input_tensors(
12081196
12091197
If cuda graph is required, this API automatically pads inputs.
12101198
"""
1211-
self.builder.prepare(finished_requests_ids)
1199+
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
12121200
for seq_group_metadata in seq_group_metadata_list:
1213-
self.builder.add_seq_group(seq_group_metadata)
1201+
builder.add_seq_group(seq_group_metadata)
12141202

1215-
self.builder.reset_cached_inter_data()
1203+
builder.reset_cached_inter_data()
12161204

1217-
return self.builder.build() # type: ignore
1205+
return builder.build() # type: ignore
12181206

12191207
@contextmanager
12201208
def set_in_profile_run(self):

vllm/worker/model_runner_base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,6 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
200200
"""A builder to create ModelRunnerInputBase objects.
201201
"""
202202

203-
@abstractmethod
204-
def prepare(self,
205-
finished_requests_ids: Optional[List[str]] = None) -> None:
206-
raise NotImplementedError
207-
208203
@abstractmethod
209204
def add_seq_group(self, seq_group_metadata):
210205
"""TBA"""

vllm/worker/xpu_model_runner.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,14 @@ def __init__(self,
113113
runner: "XPUModelRunner",
114114
finished_requests_ids: Optional[List[str]] = None) -> None:
115115
super().__init__()
116+
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
116117
self.runner = runner
117118
self.model_input_cls = self.runner._model_input_cls
118119
self.attn_backend = self.runner.attn_backend
119120
self.sliding_window = self.runner.sliding_window
120121
self.block_size = self.runner.block_size
121122
self.device = self.runner.device
122123

123-
def prepare(self,
124-
finished_requests_ids: Optional[List[str]] = None) -> None:
125-
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
126-
127124
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
128125
self.seq_group_metadata_list.append(seq_group_metadata)
129126

@@ -413,8 +410,6 @@ def __init__(
413410
SamplingMetadataCache() \
414411
if self.parallel_config.pipeline_parallel_size == 1 else None
415412

416-
self.builder = self._builder_cls(weakref.proxy(self))
417-
418413
def load_model(self) -> None:
419414
with DeviceMemoryProfiler() as m:
420415
self.model = get_model(vllm_config=self.vllm_config)
@@ -524,8 +519,7 @@ def _prepare_model_input_tensors(
524519
metadata for possible additional steps, e.g., sampling.
525520
526521
"""
527-
builder = self.builder
528-
builder.prepare(finished_requests_ids)
522+
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
529523
for seq_group_metadata in seq_group_metadata_list:
530524
builder.add_seq_group(seq_group_metadata)
531525

Comments (0)