@@ -3,7 +3,9 @@
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 
+from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -19,27 +21,50 @@
 
 class RejectionSampler(nn.Module):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for rejection sampling.")
+                    self.forward_method = self.flashinfer_sample
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "rejection sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward_method = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of rejection sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward_method = self.forward_native
+        else:
+            self.forward_method = self.forward_native
+
     def forward(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
-                "Only greedy sampling is supported by rejection sampler.")
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
+        return self.forward_method(logits, sampling_metadata)
 
-        if is_flashinfer_available:
-            logger.info("User FlashInfer for rejection sampling.")
-            return RejectionSampler.flashinfer_sample(logits,
-                                                      sampling_metadata)
-        else:
-            logger.warning(
-                "FlashInfer is not available. Falling back to the PyTorch-"
-                "native implementation of rejection sampling.")
-            return RejectionSampler.greedy_sample_native(
-                logits, sampling_metadata)
-
-    @staticmethod
     def flashinfer_sample(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         # NOTE: The following input preparation can be moved
         # to the model runner in a persistent manner for better
         # performance.
@@ -71,10 +96,10 @@ def flashinfer_sample(
         vocab_size = logits.size(-1)
         # NOTE: CPU <-> GPU synchronization happens here.
         draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = RejectionSampler._create_greedy_token_probs(
-            draft_token_ids, vocab_size, logits.device)
-        target_probs = RejectionSampler._create_greedy_token_probs(
-            target_token_ids, vocab_size, logits.device)
+        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
+                                                 logits.device)
+        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
+                                                  logits.device)
         uniform_samples = torch.zeros(batch_size,
                                       max_spec_len + 1,
                                       device=logits.device)
@@ -89,10 +114,11 @@ def flashinfer_sample(
                              logprobs_tensors=None)
 
     # TODO: The following method can be optimized for better performance.
-    @staticmethod
-    def greedy_sample_native(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]
@@ -137,24 +163,27 @@ def greedy_sample_native(
         return SamplerOutput(sampled_token_ids=output_token_ids,
                              logprobs_tensors=None)
 
-    @staticmethod
-    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
-                                   out_device: torch.device) -> torch.Tensor:
-        batch_size, num_tokens = token_ids.shape
 
-        token_probs = torch.zeros(batch_size,
-                                  num_tokens,
-                                  vocab_size,
-                                  dtype=torch.float,
-                                  device=out_device)
+def _create_greedy_token_probs(
+    token_ids: torch.Tensor,
+    vocab_size: int,
+    out_device: torch.device,
+) -> torch.Tensor:
+    batch_size, num_tokens = token_ids.shape
+
+    token_probs = torch.zeros(batch_size,
+                              num_tokens,
+                              vocab_size,
+                              dtype=torch.float,
+                              device=out_device)
 
-        # Ignore INVALID_TOKEN_ID.
-        valid_mask = (token_ids != INVALID_TOKEN_ID)
-        valid_indices = token_ids.clone()
-        valid_indices[~valid_mask] = 0
+    # Ignore INVALID_TOKEN_ID.
+    valid_mask = (token_ids != INVALID_TOKEN_ID)
+    valid_indices = token_ids.clone()
+    valid_indices[~valid_mask] = 0
 
-        token_probs.scatter_(dim=2,
-                             index=valid_indices.unsqueeze(-1),
-                             src=valid_mask.unsqueeze(-1).float())
+    token_probs.scatter_(dim=2,
+                         index=valid_indices.unsqueeze(-1),
+                         src=valid_mask.unsqueeze(-1).float())
 
-        return token_probs
+    return token_probs
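
For reference, `_create_greedy_token_probs` builds one-hot "probability" rows: each valid token id gets probability 1.0 at its own index, and positions padded with INVALID_TOKEN_ID stay all-zero. Below is a self-contained sketch of the same scatter logic on a toy input; it assumes INVALID_TOKEN_ID == -1, which is an illustration-only value not shown in this diff.

import torch

INVALID_TOKEN_ID = -1  # assumed padding id, not taken from the diff
token_ids = torch.tensor([[2, 5, INVALID_TOKEN_ID]])  # [batch=1, num_tokens=3]
vocab_size = 8

token_probs = torch.zeros(1, 3, vocab_size)
valid_mask = (token_ids != INVALID_TOKEN_ID)
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0  # any in-range index; the mask zeroes it out below
token_probs.scatter_(dim=2,
                     index=valid_indices.unsqueeze(-1),
                     src=valid_mask.unsqueeze(-1).float())
print(token_probs[0, 0])  # one-hot at index 2
print(token_probs[0, 2])  # all zeros for the INVALID_TOKEN_ID position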
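
A note on the `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` check in the new `__init__` above: per the NOTE comment, the env var defaults to None and V1 interprets that unset default as "enabled". A minimal sketch of the three-valued check (plain Python, not part of the diff):

for value in (None, True, False):
    use_flashinfer = value is not False
    print(f"VLLM_USE_FLASHINFER_SAMPLER={value!r} -> use FlashInfer: {use_flashinfer}")
# None  -> True   (unset: FlashInfer enabled by default in V1)
# True  -> True   (explicit opt-in)
# False -> False  (explicit opt-out)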