 from dataclasses import dataclass
 from functools import lru_cache
-from heapq import nsmallest
 from itertools import groupby
 from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
                     Optional, TypeVar, Union, final)
@@ -147,6 +146,9 @@ def _encode(
     return tokenizer.encode(text, add_special_tokens=add_special_tokens)
 
 
+_cached_encode = lru_cache(_encode)
+
+
 @lru_cache
 def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
     return max(len(token_text) for token_text in tokenizer.get_vocab())
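Aside: the new `_cached_encode` helper wraps `_encode` in `functools.lru_cache`, so repeated encodes of the same placeholder string are memoized (the tokenizer argument only needs to be hashable to take part in the cache key). A minimal sketch of the pattern, using a stand-in encode function rather than vLLM's:

    from functools import lru_cache

    calls = 0

    def fake_encode(text: str) -> list:
        # Stand-in for _encode: count invocations so the caching is observable.
        global calls
        calls += 1
        return [ord(c) for c in text]

    cached_fake_encode = lru_cache(fake_encode)

    cached_fake_encode("<image>")
    cached_fake_encode("<image>")  # second call is served from the cache
    assert calls == 1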
@@ -157,7 +159,10 @@ class _TokenMatch(NamedTuple):
     end_idx: int
 
 
-def find_token_match(token_ids: List[int], match_ids: List[int]):
+def find_token_match(
+    token_ids: List[int],
+    match_ids: List[int],
+) -> Optional[_TokenMatch]:
     """
     Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
     """
@@ -171,25 +176,49 @@ def find_token_match(token_ids: List[int], match_ids: List[int]):
     return None
 
 
-class _Candidate(NamedTuple):
+class _TokenMatchFromTextCandidate(NamedTuple):
     start_idx: int
     end_idx: int
-    distance: int
+
+    match_text_prefix: str
+    match_text_suffix: str
+
+    @property
+    def distance(self) -> int:
+        return len(self.match_text_prefix) + len(self.match_text_suffix)
+
+
+class _TokenMatchFromText(NamedTuple):
+    start_idx: int
+    end_idx: int
+
+    match_prefix: List[int]
+    match_suffix: List[int]
+
+    match_text_prefix: str
+    match_text_suffix: str
 
 
 def find_token_match_by_text(
     tokenizer: AnyTokenizer,
     token_ids: List[int],
     token_text: str,
     match_text: str,
-):
+) -> Optional[_TokenMatchFromText]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`.
     """
-    match_ids = _encode(tokenizer, match_text, add_special_tokens=False)
+    match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False)
     if (match := find_token_match(token_ids, match_ids)):
-        return match
+        return _TokenMatchFromText(
+            match.start_idx,
+            match.end_idx,
+            match_prefix=[],
+            match_suffix=[],
+            match_text_prefix="",
+            match_text_suffix="",
+        )
 
     # When `match_text` is not mapped to a special token ID,
     # it may be tokenized differently based on the surrounding tokens
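The candidate type now derives `distance` from the decoded prefix/suffix instead of storing it; a distance of zero means the token window decodes to exactly `match_text`. A stand-alone copy of that property, just to illustrate the behaviour:

    from typing import NamedTuple

    class Candidate(NamedTuple):  # mirrors _TokenMatchFromTextCandidate
        start_idx: int
        end_idx: int
        match_text_prefix: str
        match_text_suffix: str

        @property
        def distance(self) -> int:
            # Number of extra decoded characters surrounding match_text
            return len(self.match_text_prefix) + len(self.match_text_suffix)

    assert Candidate(3, 5, "", "").distance == 0     # exact window
    assert Candidate(3, 5, "a", "bc").distance == 3  # 3 fused characters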
@@ -202,37 +231,41 @@ def find_token_match_by_text(
     text_end_idx = text_start_idx + len(match_text)
 
     # In case the left/right side of `match_text` is fused with the
-    # string immediately before/after it during tokenization
+    # string immediately before/after it as a single token
     text_buffer = _max_vocab_token_len(tokenizer) - 1
     left_text = token_text[:max(0, text_start_idx - text_buffer)]
     right_text = token_text[:text_end_idx + text_buffer]
 
     left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
     right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
-    avg_idx = (left_idx + right_idx) // 2
     window_size = len(match_ids)
 
-    valid_candidates = list[_Candidate]()
-    for start_idx in sorted(range(left_idx, right_idx - window_size + 1),
-                            key=lambda x: abs(x - avg_idx)):
+    best_distance = len(token_text)
+    best_candidate = None
+
+    for start_idx in range(left_idx, right_idx - window_size + 1):
         end_idx = start_idx + window_size
         candidate_text = tokenizer.decode(
             token_ids[start_idx:end_idx],
+            # In case match_text is a special token
             skip_special_tokens=False,
         )
 
         if match_text in candidate_text:
-            candidate = _Candidate(
-                start_idx=start_idx,
-                end_idx=end_idx,
-                distance=len(candidate_text) - len(match_text),
+            candidate = _TokenMatchFromTextCandidate(
+                start_idx,
+                end_idx,
+                *candidate_text.split(match_text, 1),
             )
-            valid_candidates.append(candidate)
 
-            if candidate.distance == 0:
+            if candidate.distance < best_distance:
+                best_candidate = candidate
+                best_distance = candidate.distance
+
+            if best_distance == 0:
                 break
 
-    assert len(valid_candidates) > 0, dict(
+    assert best_candidate is not None, dict(
         # To facilitate debugging
         token_ids=token_ids,
         match_ids=match_ids,
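The scan keeps only the lowest-distance window, deriving the prefix/suffix by splitting the decoded window on the first occurrence of `match_text`, and stops early once the distance reaches zero. A toy example of that split, assuming `<image>` as the placeholder text:

    candidate_text = "ab<image>s"   # decoded token window, with fused neighbours
    match_text = "<image>"
    prefix, suffix = candidate_text.split(match_text, 1)
    assert (prefix, suffix) == ("ab", "s")
    # distance == len(prefix) + len(suffix) == 3 for this window; a window that
    # decodes to exactly "<image>" would have distance 0 and end the scan early.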
@@ -242,8 +275,25 @@ def find_token_match_by_text(
         right_idx=right_idx,
     )
 
-    best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance)
-    return best_candidate.start_idx, best_candidate.end_idx
+    match_token_prefix = _cached_encode(
+        tokenizer,
+        best_candidate.match_text_prefix,
+        add_special_tokens=False,
+    )
+    match_token_suffix = _cached_encode(
+        tokenizer,
+        best_candidate.match_text_suffix,
+        add_special_tokens=False,
+    )
+
+    return _TokenMatchFromText(
+        start_idx=best_candidate.start_idx,
+        end_idx=best_candidate.end_idx,
+        match_prefix=match_token_prefix,
+        match_suffix=match_token_suffix,
+        match_text_prefix=best_candidate.match_text_prefix,
+        match_text_suffix=best_candidate.match_text_suffix,
+    )
 
 
 def apply_placeholders(
@@ -253,7 +303,7 @@ def apply_placeholders(
     match_text: str,
     replacement_id: int,
     replacement_count: int,
-) -> Optional[PlaceholderRange]:
+) -> tuple[List[int], str, Optional[PlaceholderRange]]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`, and replace it with
@@ -269,13 +319,25 @@ def apply_placeholders(
     )
 
     if match is None:
-        return None
+        return token_ids, token_text, None
+
+    start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match
 
-    # TODO(youkaichao): Don't update new_token_ids
-    start_idx, end_idx = match
-    token_ids[start_idx:end_idx] = [replacement_id] * replacement_count
+    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
+                       suffix_ids)
+    replacement_text = tokenizer.decode(
+        replacement_ids,
+        # In case match_text is a special token
+        skip_special_tokens=False,
+    )
+
+    token_ids[start_idx:end_idx] = replacement_ids
+    token_text = token_text.replace(prefix_str + match_text + suffix_str,
+                                    replacement_text, 1)
 
-    return PlaceholderRange(offset=start_idx, length=replacement_count)
+    return (token_ids, token_text,
+            PlaceholderRange(offset=start_idx + len(prefix_ids),
+                             length=replacement_count))
 
 
 class MultiModalProcessor:
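`apply_placeholders` now returns the updated `(token_ids, token_text, placeholder)` triple, and the placeholder offset skips over any re-encoded prefix tokens. A worked example of the offset arithmetic with toy IDs (32000 standing in for the replacement token ID):

    prefix_ids, suffix_ids = [7, 8], [9]     # re-encoded fused prefix/suffix
    replacement_id, replacement_count = 32000, 3
    start_idx = 4                            # where the matched window began

    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
                       suffix_ids)
    assert replacement_ids == [7, 8, 32000, 32000, 32000, 9]

    # The PlaceholderRange offset points at the first replacement_id, i.e. the
    # window start plus however many prefix tokens were re-inserted.
    assert start_idx + len(prefix_ids) == 6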
@@ -318,6 +380,7 @@ def apply(
         new_token_ids, = processed_inputs.pop("input_ids").tolist()
         mm_kwargs = MultiModalKwargs(processed_inputs)
 
+        new_prompt = prompt
         mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
 
         for modality, orig_inputs in to_multi_format(mm_data).items():
@@ -337,8 +400,9 @@ def apply(
                 if new_token_id in repl_token_ids:
                     modality_placeholders.append(run_info)
 
-            # Otherwise, we insert them ourselves
-            if not modality_placeholders:
+            if modality_placeholders:
+                new_prompt = tokenizer.decode(new_token_ids)
+            else:  # Otherwise, we insert them ourselves
                 for item_idx, orig_item in enumerate(orig_inputs):
                     for match_str, replacement in placeholder_repls.items():
                         replacement_count = replacement["count"]
@@ -349,10 +413,14 @@ def apply(
                                 item_idx,
                             )
 
-                        placeholders = apply_placeholders(
+                        (
+                            new_token_ids,
+                            new_prompt,
+                            placeholders,
+                        ) = apply_placeholders(
                             tokenizer,
                             new_token_ids,
-                            prompt,
+                            new_prompt,
                             match_str,
                             replacement["token_id"],
                             replacement_count,
@@ -365,7 +433,7 @@ def apply(
         return MultiModalInputsV2(
             type="multimodal",
-            prompt=prompt,
+            prompt=new_prompt,
             prompt_token_ids=new_token_ids,
             mm_kwargs=mm_kwargs,
             mm_placeholders=mm_placeholders,