1
1
import re
2
2
from abc import ABC , abstractmethod
3
+ from collections .abc import Callable , ItemsView , Iterable , Mapping , Sequence
3
4
from dataclasses import dataclass
4
5
from functools import lru_cache
5
6
from itertools import groupby
6
- from typing import (Any , Callable , Generic , Iterable , Mapping , NamedTuple ,
7
- Optional , Sequence , TypeVar , Union )
7
+ from typing import Any , Generic , NamedTuple , Optional , Protocol , TypeVar , Union
8
8
9
9
import numpy as np
10
10
from transformers import BatchFeature
18
18
MultiModalInputsV2 , MultiModalKwargs , PlaceholderRange ,
19
19
VideoItem )
20
20
21
-
22
- def _encode (
23
- tokenizer : AnyTokenizer ,
24
- text : str ,
25
- * ,
26
- add_special_tokens : bool = False ,
27
- ) -> list [int ]:
28
- """
29
- Backend-agnostic equivalent of HF's
30
- :code:`tokenizer.encode(text, add_special_tokens=...)`.
31
- """
32
- if isinstance (tokenizer , MistralTokenizer ):
33
- return tokenizer .tokenizer .encode (text ,
34
- bos = add_special_tokens ,
35
- eos = add_special_tokens )
36
-
37
- return tokenizer .encode (text , add_special_tokens = add_special_tokens )
38
-
39
-
40
- @lru_cache (maxsize = 2048 )
41
- def _cached_encode (
42
- tokenizer : AnyTokenizer ,
43
- text : str ,
44
- * ,
45
- add_special_tokens : bool = False ,
46
- ) -> list [int ]:
47
- return _encode (tokenizer , text , add_special_tokens = add_special_tokens )
48
-
49
-
50
- def _decode (
51
- tokenizer : AnyTokenizer ,
52
- token_ids : list [int ],
53
- * ,
54
- skip_special_tokens : bool = False ,
55
- ) -> str :
56
- """
57
- Backend-agnostic equivalent of HF's
58
- :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
59
- """
60
- return tokenizer .decode (token_ids , skip_special_tokens = skip_special_tokens )
61
-
62
-
63
- @lru_cache (maxsize = 2048 )
64
- def _cached_decode (
65
- tokenizer : AnyTokenizer ,
66
- token_ids : tuple [int , ...],
67
- * ,
68
- skip_special_tokens : bool = False ,
69
- ) -> str :
70
- return _decode (tokenizer ,
71
- list (token_ids ),
72
- skip_special_tokens = skip_special_tokens )
73
-
74
-
75
21
PromptSegment : TypeAlias = Union [str , list [int ]]
76
22
77
23
78
24
def bind_segment(
    segment: PromptSegment,
    tokenizer: AnyTokenizer,
) -> "_BoundPromptSegment":
    """
    Bind a text or token prompt to a tokenizer so that it can be
    lazily converted into the other format on demand.
    """
    # Exactly one of (text, token ids) is populated, depending on which
    # form the caller supplied; the other is derived lazily by the
    # bound segment when first requested.
    text: Optional[str] = None
    token_ids: Optional[list[int]] = None
    if isinstance(segment, str):
        text = segment
    elif isinstance(segment, list):
        token_ids = segment

    return _BoundPromptSegment(
        tokenizer=tokenizer,
        _text=text,
        _token_ids=token_ids,
    )
88
37
89
38
@@ -163,6 +112,78 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
163
112
"""
164
113
165
114
115
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.
    """
    # The Mistral backend controls special tokens through explicit
    # bos/eos flags on its underlying tokenizer rather than a single
    # ``add_special_tokens`` switch.
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    return tokenizer.tokenizer.encode(
        text,
        bos=add_special_tokens,
        eos=add_special_tokens,
    )
131
+
132
+
133
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`_encode`.

    All arguments (including the tokenizer instance) form the cache key,
    so repeated encodes of the same text are free.
    """
    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
141
+
142
+
143
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
    """
    decoded = tokenizer.decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
    )
    return decoded
154
+
155
+
156
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`_decode`.

    ``token_ids`` is a tuple (not a list) so that it is hashable and can
    participate in the cache key; it is converted back before delegating.
    """
    return _decode(
        tokenizer,
        [*token_ids],
        skip_special_tokens=skip_special_tokens,
    )
166
+
167
+
168
class _HasModalityAttr(Protocol):
    """Structural type: any object exposing a ``modality`` attribute."""
    modality: str
170
+
171
+
172
class _HasModalityProp(Protocol):
    """Structural type: any object exposing ``modality`` as a property."""

    @property
    def modality(self) -> str:
        ...
177
+
178
+
179
+ _M = TypeVar ("_M" , bound = Union [_HasModalityAttr , _HasModalityProp ])
180
+
181
+
182
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Convenience function to apply :func:`full_groupby` based on modality."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)
185
+
186
+
166
187
@dataclass
167
188
class _BoundPromptSegment :
168
189
tokenizer : AnyTokenizer
@@ -237,7 +258,7 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
237
258
return multi_data
238
259
239
260
240
- class _TokenRun (TypedDict ):
261
+ class _TokenRun (NamedTuple ):
241
262
token_id : int
242
263
243
264
start_idx : int
@@ -257,19 +278,20 @@ def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
257
278
start_idx += length
258
279
259
280
260
class _BoundPlaceholderRange(NamedTuple):
    """A placeholder token run located within a prompt."""
    # Modality (e.g. image/video/audio) the placeholder tokens stand for.
    modality: str
    # Index of the first placeholder token in the token sequence.
    offset: int
    # Number of consecutive placeholder tokens in the run.
    length: int
264
285
265
286
266
287
def iter_placeholders (
267
288
prompt_repls : Sequence [_BoundPromptReplacement [Any ]],
268
- new_token_ids : list [int ],
289
+ token_ids : list [int ],
269
290
* ,
270
291
min_placeholder_count : int ,
271
292
) -> Iterable [_BoundPlaceholderRange ]:
272
- repls_by_modality = full_groupby (prompt_repls , key = lambda x : x .modality )
293
+ """Yield each set of placeholder tokens found in :code:`token_ids`."""
294
+ repls_by_modality = full_groupby_modality (prompt_repls )
273
295
274
296
placeholder_ids_by_modality = {
275
297
modality : {
@@ -280,15 +302,15 @@ def iter_placeholders(
280
302
for modality , repls in repls_by_modality
281
303
}
282
304
283
- for run_info in iter_token_runs (new_token_ids ):
284
- if run_info [ " length" ] > min_placeholder_count :
305
+ for run_info in iter_token_runs (token_ids ):
306
+ if run_info . length > min_placeholder_count :
285
307
for (modality ,
286
308
placeholder_ids ) in placeholder_ids_by_modality .items ():
287
- if run_info [ " token_id" ] in placeholder_ids :
309
+ if run_info . token_id in placeholder_ids :
288
310
yield _BoundPlaceholderRange (
289
311
modality = modality ,
290
- offset = run_info [ " start_idx" ] ,
291
- length = run_info [ " length" ] ,
312
+ offset = run_info . start_idx ,
313
+ length = run_info . length ,
292
314
)
293
315
294
316
@@ -398,6 +420,7 @@ def find_token_matches(
398
420
prompt : list [int ],
399
421
prompt_repls : Sequence [_BoundPromptReplacement [_T ]],
400
422
) -> list [_PromptReplacementTokenMatch [_T ]]:
423
+ """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
401
424
return [
402
425
_PromptReplacementTokenMatch (prompt_repl , match )
403
426
for prompt_repl in prompt_repls
@@ -409,17 +432,22 @@ def find_text_matches(
409
432
def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
    found: list[_PromptReplacementTextMatch[_T]] = []
    for repl in prompt_repls:
        # The target text is matched literally, so escape any regex
        # metacharacters it may contain.
        pattern = re.compile(re.escape(repl.target.text))
        for match in pattern.finditer(prompt):
            found.append(_PromptReplacementTextMatch(repl, match))
    return found
417
441
418
442
419
- def unique_sort_matches (
443
+ def resolve_matches (
420
444
prompt : _S ,
421
445
matches : Sequence [_PromptReplacementMatch [_T , _S ]],
422
446
) -> list [_PromptReplacementMatch [_T , _S ]]:
447
+ """
448
+ Resolve :code:`matches` to ensure that there are no overlapping matches,
449
+ and sort them such that earlier matches take priority over later ones.
450
+ """
423
451
num_matches_by_idx = np .zeros (len (prompt ), dtype = int )
424
452
for match in matches :
425
453
num_matches_by_idx [match .start_idx :match .end_idx ] += 1
@@ -443,8 +471,7 @@ def _replace_matches(
443
471
prev_end_idx = 0
444
472
next_idx_by_modality = {modality : 0 for modality in mm_items_by_modality }
445
473
446
- # Earlier matches take priority over later ones
447
- for match in unique_sort_matches (prompt , matches ):
474
+ for match in resolve_matches (prompt , matches ):
448
475
modality = match .modality
449
476
mm_items = mm_items_by_modality [modality ]
450
477
@@ -471,6 +498,7 @@ def replace_token_matches(
471
498
mm_items_by_modality : Mapping [str , list [_T ]],
472
499
hf_inputs : BatchFeature ,
473
500
) -> list [int ]:
501
+ """Apply :code:`prompt_repls` to :code:`prompt`."""
474
502
if not matches :
475
503
return prompt
476
504
@@ -490,6 +518,7 @@ def replace_text_matches(
490
518
mm_items_by_modality : Mapping [str , list [_T ]],
491
519
hf_inputs : BatchFeature ,
492
520
) -> str :
521
+ """Apply :code:`prompt_repls` to :code:`prompt`."""
493
522
if not matches :
494
523
return prompt
495
524
@@ -581,8 +610,7 @@ def _apply_prompt_replacements(
581
610
582
611
if all (
583
612
len (matches ) >= len (mm_data [modality ])
584
- for modality , matches in full_groupby (token_matches ,
585
- key = lambda x : x .modality )
613
+ for modality , matches in full_groupby_modality (token_matches )
586
614
): # yapf: disable
587
615
token_ids = replace_token_matches (
588
616
token_ids ,
@@ -647,11 +675,10 @@ def apply(
647
675
648
676
mm_placeholders = {
649
677
modality : [
650
- PlaceholderRange (offset = item [ " offset" ] , length = item [ " length" ] )
678
+ PlaceholderRange (offset = item . offset , length = item . length )
651
679
for item in items
652
680
]
653
- for modality , items in full_groupby (all_placeholders ,
654
- key = lambda x : x ["modality" ])
681
+ for modality , items in full_groupby_modality (all_placeholders )
655
682
}
656
683
657
684
return MultiModalInputsV2 (
0 commit comments