@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
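The docstring now lists repetition penalties alongside presence and frequency penalties in step 3. As a rough illustration of how the three penalties differ, here is a minimal sketch on a single row of logits with made-up values (not the module's actual code; see _apply_penalties below for the real batched version):

import torch

# Toy values for one sequence: logits over a 4-token vocabulary and the
# number of times each token has already been generated.
logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
counts = torch.tensor([2, 0, 1, 0])
mask = counts > 0
presence_penalty, frequency_penalty, repetition_penalty = 0.5, 0.2, 1.2

# Repetition penalty is multiplicative: previously seen tokens get their
# positive logits divided and their negative logits multiplied by the penalty.
rep = torch.where(mask, torch.full_like(logits, repetition_penalty),
                  torch.ones_like(logits))
logits = torch.where(logits > 0, logits / rep, logits * rep)

# Presence and frequency penalties are additive, per the OpenAI API definition.
logits -= frequency_penalty * counts
logits -= presence_penalty * mask.float()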
@@ -50,14 +50,12 @@ def forward(
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
                                   frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
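With this change, forward() hands input_metadata straight to _apply_penalties instead of precomputing output tokens, and keeps only the shape checks on the per-sequence penalty lists. A minimal sketch of that contract, using toy shapes and illustrative penalty values (not the module's actual code):

import torch

# Toy batch: 3 sequences over a 5-token vocabulary.
logits = torch.randn(3, 5)

# One penalty value per sequence, in the order _get_penalties returns them.
presence_penalties = [0.0, 0.5, 0.0]
frequency_penalties = [0.0, 0.2, 0.0]
repetition_penalties = [1.0, 1.2, 1.0]

# The same sanity checks the forward pass performs before applying penalties.
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]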
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties


-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+    input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask


 def _apply_logits_processors(logits: torch.Tensor,
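The new _get_bin_counts_and_mask helper above builds a per-sequence histogram of token ids by scattering ones into a (num_seqs, vocab_size + 1) tensor, using index vocab_size as a padding bin that is sliced off afterwards. A small self-contained sketch of the same trick with toy data (illustrative values only, not the module's code):

import torch

vocab_size, num_seqs = 4, 2
token_lists = [[1, 1, 3], []]  # the second sequence has no tokens yet

# Pad every list with vocab_size so the padding lands in an extra bin.
max_len = max(len(t) for t in token_lists)
padded = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]
tokens_tensor = torch.tensor(padded, dtype=torch.long)

bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin
mask = bin_counts > 0

print(bin_counts)  # tensor([[0, 2, 0, 1], [0, 0, 0, 0]])
print(mask[0])     # tensor([False,  True, False,  True])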
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,

 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits

-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]

-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)

     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
                                       device=logits.device)

     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~mask] = 1.0
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
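Putting the pieces of this hunk together: the repetition penalty is gated on tokens seen in either the prompt or the output, while the OpenAI-style presence and frequency penalties use only the output histogram. A minimal worked example with one sequence and made-up masks and counts (illustrative values, not the module's code):

import torch

vocab_size = 4
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])

# Toy histograms: token 0 was generated twice, token 3 appears only in the prompt.
output_bin_counts = torch.tensor([[2, 0, 0, 0]])
output_mask = output_bin_counts > 0
prompt_mask = torch.tensor([[False, False, False, True]])

repetition_penalties = torch.tensor([1.2])[:, None].repeat(1, vocab_size)
# Only tokens seen in the prompt or the output keep a penalty != 1.
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties,
                     logits * repetition_penalties)

# Presence/frequency penalties follow the OpenAI API definition and only
# consider tokens generated so far.
frequency_penalties = torch.tensor([0.2])
presence_penalties = torch.tensor([0.5])
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask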