
Commit f70d007

WoosukKwon authored and rasmith committed
[V1] Add BlockTable class (vllm-project#11693)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 44f56fd · commit f70d007

3 files changed: +94 additions, −25 deletions

vllm/v1/worker/block_table.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
from typing import List

import numpy as np
import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class BlockTable:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        pin_memory: bool,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.pin_memory = pin_memory
        self.device = device

        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(
        self,
        row_idx: int,
        start: int,
        block_ids: List[int],
    ) -> None:
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = start + num_blocks

    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
        self.append_row(row_idx, 0, block_ids)

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def commit(self, num_reqs: int) -> None:
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def clear(self) -> None:
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np
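
The class keeps three synchronized views of one table: a device tensor, a pinned CPU staging tensor, and a numpy view of the CPU tensor for cheap row updates. A minimal usage sketch of the API above; the sizes, block IDs, and CPU device are illustrative choices (not from the commit) so it runs without a GPU:

import torch

from vllm.v1.worker.block_table import BlockTable

block_table = BlockTable(
    max_num_reqs=4,
    max_model_len=2048,
    max_num_blocks_per_req=16,
    pin_memory=False,           # pinning requires CUDA; off for CPU-only runs
    device=torch.device("cpu"),
)
block_table.add_row(0, [10, 11, 12])  # request 0 starts with three blocks
block_table.append_row(0, start=3, block_ids=[13])  # one more block later
block_table.move_row(0, 1)      # copy row 0 into row 1, as condense() does
block_table.commit(num_reqs=2)  # stage rows [0, 2) onto the device tensor
print(block_table.get_numpy_array()[:2, :4])  # [[10 11 12 13] [10 11 12 13]]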

vllm/v1/worker/gpu_input_batch.py

Lines changed: 9 additions & 16 deletions
@@ -9,6 +9,7 @@
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange

@@ -70,19 +71,14 @@ def __init__(
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
 
-        # Attention-related.
-        self.block_table = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device=self.device,
-            dtype=torch.int32,
-        )
-        self.block_table_cpu_tensor = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device="cpu",
-            dtype=torch.int32,
+        # Block table.
+        self.block_table = BlockTable(
+            max_num_reqs=max_num_reqs,
+            max_model_len=max_model_len,
+            max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
+            device=device,
         )
-        self.block_table_cpu = self.block_table_cpu_tensor.numpy()
 
         # Sampling-related.
         self.temperature = torch.empty((max_num_reqs, ),

@@ -193,8 +189,7 @@ def add_request(
         self.num_tokens[req_index] = request.num_tokens
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        num_blocks = len(request.block_ids)
-        self.block_table_cpu[req_index, :num_blocks] = request.block_ids
+        self.block_table.add_row(req_index, request.block_ids)
 
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature

@@ -300,9 +295,7 @@ def condense(self, empty_req_indices: List[int]) -> None:
                 self.num_prompt_tokens[last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
-            # TODO(woosuk): Optimize the copy of block_table_cpu.
-            self.block_table_cpu[empty_index] = self.block_table_cpu[
-                last_req_index]
+            self.block_table.move_row(last_req_index, empty_index)
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
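
The condense() hunk keeps active rows contiguous: when a request finishes and leaves a hole, the last active row is moved into it, so all per-request arrays can be sliced by the active count alone. A standalone numpy sketch of that pattern (array contents are hypothetical):

import numpy as np

# Rows 0-3 are active; the request in row 1 finishes, leaving a hole.
table = np.array([[10, 11], [20, 21], [30, 31], [40, 41]], dtype=np.int32)
empty_index, last_req_index = 1, 3
table[empty_index] = table[last_req_index]  # per-row effect of move_row()
num_active = 3
print(table[:num_active])  # [[10 11] [40 41] [30 31]] -- still contiguous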

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 9 deletions
@@ -211,10 +211,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             if num_new_blocks == 0:
                 continue
             start_index = len(req_state.block_ids)
-            end_index = start_index + num_new_blocks
             req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table_cpu[
-                req_index, start_index:end_index] = req_data.new_block_ids
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
@@ -275,9 +274,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
 
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table[:num_reqs].copy_(
-            self.input_batch.block_table_cpu_tensor[:num_reqs],
-            non_blocking=True)
+        self.input_batch.block_table.commit(num_reqs)
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
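
commit() can overlap with the CPU work that follows because the staging tensor is allocated with pin_memory: a non_blocking copy from page-locked memory only enqueues the host-to-device transfer and returns immediately. A hypothetical standalone sketch of that trick (assumes a CUDA device; buffer sizes are illustrative):

import torch

cpu_buf = torch.zeros(4, 16, dtype=torch.int32, pin_memory=True)
gpu_buf = torch.zeros(4, 16, dtype=torch.int32, device="cuda")
gpu_buf[:2].copy_(cpu_buf[:2], non_blocking=True)  # enqueue H2D, return now
# ... CPU-side input preparation runs here, overlapped with the DMA ...
torch.cuda.synchronize()  # transfer has landed before the data is consumed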
@@ -333,8 +330,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
-        block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
-                         [block_table_indices].numpy())
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
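
The hunk above feeds the slot-mapping arithmetic: a token at position p lives in block block_table[p // block_size] at offset p % block_size, so its flat slot is block_id * block_size + offset. A numpy sketch simplified to a single request with made-up block IDs (the real code flattens all rows and gathers with torch.index_select):

import numpy as np

block_size = 16
block_table_row = np.array([7, 3, 9], dtype=np.int32)   # hypothetical IDs
positions = np.array([0, 15, 16, 40], dtype=np.int64)
block_numbers = block_table_row[positions // block_size]  # [7 7 3 9]
block_offsets = positions % block_size                    # [0 15 0 8]
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)  # [112 127  48 152]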
@@ -450,7 +447,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_start_loc=seq_start_loc,
-            block_table=self.input_batch.block_table[:num_reqs],
+            block_table=(
+                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
