@@ -3,7 +3,6 @@
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_model as xm
 
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
@@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
     # on TPU, in particular all arguments should be traceable and no optionals
     # are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor
+    temperature: torch.Tensor = None
 
-    min_p: torch.Tensor
+    min_p: torch.Tensor = None
     # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
     # Greedy sampling flag for compiling single xla graph.
-    all_greedy: torch.Tensor = None
+    all_greedy: bool = True
 
     # Generator not supported by xla
     generators: dict[int,
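
The hunk above gives every field a default value so the metadata can be constructed without materializing a single tensor, and turns `all_greedy` from a traced tensor into a plain Python bool that never enters the XLA graph. A minimal standalone sketch of that pattern (hypothetical class name, not the vLLM source):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MetadataSketch:
    # Defaults make every field optional: an all-greedy batch can build
    # the metadata with no tensor allocation and no host-to-device copy.
    temperature: Optional[torch.Tensor] = None
    min_p: Optional[torch.Tensor] = None
    all_greedy: bool = True


greedy_meta = MetadataSketch(all_greedy=True)  # greedy fast path
assert greedy_meta.temperature is None
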
@@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:
 
     allowed_token_ids_mask = None
     bad_words_token_ids = None
-    indices_do_sample: torch.Tensor = None
 
     @classmethod
     def from_input_batch(
-            cls, input_batch: InputBatch,
-            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
+        cls,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        xla_device: torch.device,
+        generate_params_if_all_greedy: bool = False
+    ) -> "TPUSupportedSamplingMetadata":
         """
         Copy sampling tensor slices from `input_batch` to on-device tensors.
 
         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
         slices dynamic shapes on device tensors. This impl moves the dynamic
-        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
-        also reuses the on-device persistent tensors managed in `input_batch`
-        to reduce waste.
-
-        `indices_do_sample` contains the indices to be fed to the Sampler,
-        normally one per request, here padded to the closest pre-compiled shape.
-        We expect sampling params tensors to be padded to the same fixed shape.
-
-        Eg. 3 requests, tensors padded to 4:
-            temperature: [0.7, 0.2, 0.9] => [0.7, 0.2, 0.9, 0.0]
-            sample indices: [4, 10, 11] => indices_do_sample: [4, 10, 11, 0]
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size.
+
+        Args:
+            input_batch: The input batch containing sampling parameters.
+            padded_num_reqs: The padded number of requests.
+            xla_device: The XLA device.
+            generate_params_if_all_greedy: If True, generate sampling parameters
+                even if all requests are greedy. This is useful for cases where
+                we want to pre-compile a graph with sampling parameters, even if
+                they are not strictly needed for greedy decoding.
         """
+        # Early return to avoid unnecessary CPU-to-TPU copy.
+        if (input_batch.all_greedy is True
+                and generate_params_if_all_greedy is False):
+            return cls(all_greedy=True)
+
         num_reqs = input_batch.num_reqs
-        padded_num_reqs = len(indices_do_sample)
 
-        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
-                       fill_val) -> torch.Tensor:
-            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
-            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
-            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
 
-        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
-        # consistent. We can't have flags to skip copies or we'll end up
-        # recompiling.
-        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+        fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
-        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
-        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+        # fill_slice(input_batch.top_p_cpu_tensor)
+        # fill_slice(input_batch.top_k_cpu_tensor)
+        fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
 
-        xm.mark_step()
-        xm.wait_device_ops()
-
         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
-            temperature=input_batch.temperature[:padded_num_reqs],
-            # Scalar tensor for xla-friendly tracing.
-            all_greedy=torch.tensor(input_batch.all_greedy,
-                                    dtype=torch.bool,
-                                    device=input_batch.device),
+            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
+            to(xla_device),
+            all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
-            min_p=input_batch.min_p[:padded_num_reqs],
-            generators=input_batch.generators,
-            indices_do_sample=indices_do_sample)
+            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
+            generators=input_batch.generators)
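
Taken together, `fill_slice` and the trailing `.to(xla_device)` calls keep every dynamic-shape operation on the host: the device only ever receives tensors whose length is one of the pre-compiled `padded_num_reqs` values. A standalone sketch of the idea, with hypothetical names and a CPU device standing in for the TPU:

import torch


def pad_and_move(cpu_tensor: torch.Tensor, num_reqs: int,
                 padded_num_reqs: int, fill_val: float,
                 device: torch.device) -> torch.Tensor:
    # Pad the unused tail with the default value on the host, where
    # dynamic slicing is cheap and cannot trigger XLA recompilation.
    cpu_tensor[num_reqs:padded_num_reqs] = fill_val
    # Copy a fixed-size slice to the device; its shape depends only on
    # padded_num_reqs, so the compiled graph sees one stable input shape.
    return cpu_tensor[:padded_num_reqs].to(device)


# Three live requests in a persistent host buffer, padded out to four.
temperature_cpu = torch.zeros(8)
temperature_cpu[:3] = torch.tensor([0.7, 0.2, 0.9])
on_device = pad_and_move(temperature_cpu, num_reqs=3, padded_num_reqs=4,
                         fill_val=1.0, device=torch.device("cpu"))
print(on_device)  # tensor([0.7000, 0.2000, 0.9000, 1.0000])

A caller would presumably pass `generate_params_if_all_greedy=True` only while pre-compiling graphs, so that the sampling graph exists even though runtime all-greedy batches take the tensor-free early return.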