17
17
from .inputs import MultiModalDataDict , MultiModalKwargs , NestedTensors
18
18
from .processing import (BaseMultiModalProcessor , BaseProcessingInfo ,
19
19
ProcessingCache )
20
- from .profiling import BaseDummyInputsBuilder
20
+ from .profiling import BaseDummyInputsBuilder , MultiModalProfiler
21
21
from .utils import cached_get_tokenizer
22
22
from .video import VideoPlugin
23
23
@@ -282,13 +282,13 @@ def get_max_tokens_per_item_by_nonzero_modality(
282
282
This is currently directly used only in V1 for profiling the memory
283
283
usage of a model.
284
284
"""
285
- limits_per_plugin = self ._limits_by_model [ model_config ]
285
+ mm_limits = self .get_mm_limits_per_prompt ( model_config )
286
286
287
287
return {
288
288
key : max_tokens_per_mm_item
289
289
for key , max_tokens_per_mm_item in
290
290
self .get_max_tokens_per_item_by_modality (model_config ).items ()
291
- if limits_per_plugin [key ] > 0
291
+ if mm_limits [key ] > 0
292
292
}
293
293
294
294
def get_max_tokens_by_modality (
@@ -304,10 +304,10 @@ def get_max_tokens_by_modality(
304
304
Note:
305
305
This should be called after :meth:`init_mm_limits_per_prompt`.
306
306
"""
307
- limits_per_plugin = self ._limits_by_model [ model_config ]
307
+ mm_limits = self .get_mm_limits_per_prompt ( model_config )
308
308
309
309
return {
310
- key : limits_per_plugin [key ] * max_tokens_per_mm_item
310
+ key : mm_limits [key ] * max_tokens_per_mm_item
311
311
for key , max_tokens_per_mm_item in
312
312
self .get_max_tokens_per_item_by_modality (model_config ).items ()
313
313
}
@@ -371,6 +371,15 @@ def get_mm_limits_per_prompt(
371
371
Note:
372
372
This should be called after :meth:`init_mm_limits_per_prompt`.
373
373
"""
374
+ if self .has_processor (model_config ):
375
+ tokenizer = cached_get_tokenizer (
376
+ model_config .tokenizer ,
377
+ trust_remote_code = model_config .trust_remote_code ,
378
+ )
379
+ processor = self .create_processor (model_config , tokenizer )
380
+ profiler = MultiModalProfiler (processor )
381
+ return profiler .get_mm_limits ()
382
+
374
383
return self ._limits_by_model [model_config ]
375
384
376
385
def register_processor (
0 commit comments