
Commit a454748

[TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (#16275)
Signed-off-by: Chengji Yao <[email protected]>
Parent: 1bff42c

4 files changed: +165 -124 lines changed


Diff for: tests/tpu/test_compilation.py

+10 -4

@@ -44,23 +44,29 @@ def test_tpu_compilation():
     assert generated_text.startswith(answer)
 
     compiled_codes = sorted(
-        glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
+        glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")))
 
     for i, compiled_code in enumerate(compiled_codes):
         print("{} file: {}".format(i + 1, compiled_code))
 
     # We should only trigger Dynamo compilation 2 times:
     # 1. Forward pass without kv_caches
     # 2. Forward pass with kv_caches
-    # Check we have 4 compiled codes
+    # Check we have 2 compiled codes
     assert len(compiled_codes) == 2
 
     kv_cache_prefix = "kv_cache"
     attn_prefix = "ragged_paged_attention"
 
+    def extract_compiled_index(s):
+        parts = s.replace(".", "_").split("_")
+        numbers = [int(part) for part in parts if part.isdigit()]
+        return numbers[0]
+
     # Check all the compilations are as expected
-    compiled_fns = sorted(
-        glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+    compiled_fns = sorted(glob.glob(
+        os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
+                          key=lambda s: extract_compiled_index(s))
 
     for i, compiled_fn in enumerate(compiled_fns):
         print("{} file: {}".format(i + 1, compiled_fn))

Diff for: tests/v1/tpu/worker/test_tpu_model_runner.py

+19 -6

@@ -7,9 +7,9 @@
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
-from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
-                                             _get_padded_token_len,
-                                             _get_paddings)
+from vllm.v1.worker.tpu_model_runner import (
+    TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
+    _get_padded_token_len, _get_req_paddings, _get_token_paddings)
 
 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -296,16 +296,29 @@ def test_update_states_request_unscheduled(model_runner):
 def test_get_paddings():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
-    actual_paddings = _get_paddings(min_token_size, max_token_size,
-                                    padding_gap)
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
     assert actual_paddings == expected_paddings
 
 
 def test_get_padded_token_len():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
-    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
     assert _get_padded_token_len(paddings, 1) == 16
     assert _get_padded_token_len(paddings, 16) == 16
     assert _get_padded_token_len(paddings, 20) == 32
     assert _get_padded_token_len(paddings, 300) == 320
     assert _get_padded_token_len(paddings, 512) == 512
+
+
+def test_get_padded_num_reqs_with_upper_limit():
+    assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
+    assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
+    assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
+    assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
+
+
+def test_get_req_paddings():
+    assert _get_req_paddings(1, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
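The tests above pin down the padding buckets the runner uses to keep XLA input shapes stable. Purely as an illustration, here is one set of helpers consistent with those asserts; it is inferred from the expected values, not copied from vllm/v1/worker/tpu_model_runner.py, whose actual implementation may differ:

# Sketch inferred from the test expectations above; not the vLLM implementation.
MIN_NUM_SEQS = 8  # assumed lower bound for request padding

def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    # Round up to the next power of two (at least MIN_NUM_SEQS), capped at the limit.
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)

def _get_req_paddings(min_req: int, max_req: int) -> list[int]:
    # Power-of-two request buckets, always ending exactly at max_req.
    paddings = []
    num = _get_padded_num_reqs_with_upper_limit(min_req, max_req)
    while num < max_req:
        paddings.append(num)
        num *= 2
    paddings.append(max_req)
    return paddings

def _get_token_paddings(min_token_size: int, max_token_size: int,
                        padding_gap: int) -> list[int]:
    # Exponential buckets up to padding_gap, then linear steps of padding_gap.
    paddings, num = [], min_token_size
    while num <= padding_gap:
        paddings.append(num)
        num *= 2
    while num <= max_token_size:
        paddings.append(num)
        num += padding_gap
    return paddings

def _get_padded_token_len(paddings: list[int], x: int) -> int:
    # Smallest pre-compiled bucket that fits x tokens.
    return next(p for p in paddings if p >= x)

Padding request and token counts to a small fixed set of buckets is what lets the runner pre-compile one graph per bucket instead of recompiling for every new batch size.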

Diff for: vllm/v1/sample/tpu/metadata.py

+35 -42

@@ -3,7 +3,6 @@
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_model as xm
 
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
@@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
     # on TPU, in particular all arguments should be traceable and no optionals
     # are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor
+    temperature: torch.Tensor = None
 
-    min_p: torch.Tensor
+    min_p: torch.Tensor = None
     # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
     # Greedy sampling flag for compiling single xla graph.
-    all_greedy: torch.Tensor = None
+    all_greedy: bool = True
 
     # Generator not supported by xla
     generators: dict[int,
@@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:
 
     allowed_token_ids_mask = None
     bad_words_token_ids = None
-    indices_do_sample: torch.Tensor = None
 
     @classmethod
     def from_input_batch(
-            cls, input_batch: InputBatch,
-            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
+        cls,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        xla_device: torch.device,
+        generate_params_if_all_greedy: bool = False
+    ) -> "TPUSupportedSamplingMetadata":
         """
         Copy sampling tensors slices from `input_batch` to on device tensors.
 
         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
         slices dynamic shapes on device tensors. This impl moves the dynamic
-        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
-        also reuses the on-device persistent tensors managed in `input_batch`
-        to reduce waste.
-
-        `indices_do_sample` contains the indices to be fed to the Sampler,
-        normally one per request, here padded to the closest pre-compiled shape
-        We expect sampling params tensors to be padded to the same fixed shape.
-
-        Eg. 3 requests, tensors padded to 4
-            temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
-            sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size.
+
+        Args:
+            input_batch: The input batch containing sampling parameters.
+            padded_num_reqs: The padded number of requests.
+            xla_device: The XLA device.
+            generate_params_if_all_greedy: If True, generate sampling parameters
+                even if all requests are greedy. this is useful for cases where
+                we want to pre-compile a graph with sampling parameters, even if
+                they are not strictly needed for greedy decoding.
         """
+        # Early return to avoid unnecessary cpu to tpu copy
+        if (input_batch.all_greedy is True
+                and generate_params_if_all_greedy is False):
+            return cls(all_greedy=True)
+
         num_reqs = input_batch.num_reqs
-        padded_num_reqs = len(indices_do_sample)
 
-        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
-                       fill_val) -> torch.Tensor:
-            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
            # Pad value is the default one.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val
-            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
-            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
 
-        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
-        # consistent. We can't have flags to skip copies or we'll end up
-        # recompiling.
-        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+        fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
-        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
-        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+        # fill_slice(input_batch.top_p_cpu_tensor)
+        # fill_slice(input_batch.top_k_cpu_tensor)
+        fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
 
-        xm.mark_step()
-        xm.wait_device_ops()
-
        # Slice persistent device tensors to a fixed pre-compiled padded shape.
        return cls(
-            temperature=input_batch.temperature[:padded_num_reqs],
-            # Scalar tensor for xla-friendly tracing.
-            all_greedy=torch.tensor(input_batch.all_greedy,
-                                    dtype=torch.bool,
-                                    device=input_batch.device),
+            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
+            to(xla_device),
+            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=None,  # input_batch.top_p[:padded_num_reqs],
            top_k=None,  # input_batch.top_k[:padded_num_reqs],
-            min_p=input_batch.min_p[:padded_num_reqs],
-            generators=input_batch.generators,
-            indices_do_sample=indices_do_sample)
+            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
+            generators=input_batch.generators)
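Taken together, this diff drops `indices_do_sample` and the explicit `xm.mark_step()` / `xm.wait_device_ops()` sync: the caller now supplies the padded request count and the XLA device, padding happens on CPU tensors, and a fully greedy batch skips the CPU-to-TPU copy entirely. A hedged sketch of a call site follows; `input_batch` and `max_num_reqs` stand in for the runner's own state, and the real call in tpu_model_runner.py may differ:

# Illustrative caller; not runnable on its own since `input_batch` and
# `max_num_reqs` are placeholders for the runner's state.
import torch_xla.core.xla_model as xm
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.worker.tpu_model_runner import _get_padded_num_reqs_with_upper_limit

xla_device = xm.xla_device()
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
    input_batch.num_reqs, max_num_reqs)
sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
    input_batch, padded_num_reqs, xla_device)

if sampling_metadata.all_greedy:
    # Early-return path: no sampling tensors were copied to the TPU.
    ...
else:
    # temperature/min_p are fixed-size [padded_num_reqs] tensors already on
    # the XLA device, so the sampling graph traces with a stable shape.
    ...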
