 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY, InputRegistry
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
+from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -35,7 +36,6 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
-        input_registry: InputRegistry = INPUT_REGISTRY,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -77,7 +77,12 @@ def __init__(
         self.hidden_size = model_config.get_hidden_size()
 
         # Multi-modal data support
-        self.input_registry = input_registry
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        # NOTE: mm_input_mapper is only used for memory profiling.
+        self.mm_input_mapper = MMInputMapperClient(self.model_config)
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
+        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
@@ -599,8 +604,6 @@ def _dummy_run(
         return hidden_states
 
     def profile_run(self) -> None:
-        # TODO(woosuk): Profile the max memory usage of the encoder and
-        # the encoder cache.
         # use an empty tensor instead of `None` to force Dynamo to pass
         # it by reference, rather than by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
@@ -612,6 +615,57 @@ def profile_run(self) -> None:
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
+
+        # Profile with multimodal encoder & encoder cache.
+        # TODO (ywang96): generalize this beyond image modality since
+        # mm_input_mapper only supports image inputs.
+        if self.is_multimodal_model:
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+            dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
+                mm_data=dummy_mm_data,
+                mm_hashes=None,
+                mm_processor_kwargs=None,
+                precomputed_mm_inputs=None)
+
+            # NOTE: Currently the model is profiled with a single non-text
+            # modality even when it supports multiple.
+            max_tokens_per_mm_item = max(
+                self.mm_registry.get_max_tokens_per_item_by_modality(
+                    self.model_config).values())
+
+            max_num_mm_items = min(
+                self.max_num_encoder_input_tokens,
+                self.encoder_cache_size) // max_tokens_per_mm_item
+
+            # Dummy data definition in V0 may contain multiple multimodal
+            # items (e.g., multiple images) for a single request; here we
+            # always replicate the first item max_num_mm_items times, since
+            # in V1 they are scheduled to be processed separately.
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs[0]] * max_num_mm_items)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs, device=self.device)
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+            assert len(dummy_encoder_outputs) == max_num_mm_items, (
+                "Expected dimension 0 of encoder outputs to match the number "
+                f"of multimodal data items: {max_num_mm_items}, got "
+                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
+                "due to the 'get_multimodal_embeddings' method of the model "
+                "not being implemented correctly.")
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                         dummy_kv_caches)
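
As a rough illustration of the budget arithmetic in the hunk above: the number of dummy multimodal items is the smaller of the encoder token budget and the encoder cache size, divided by the per-item token count. All numbers in this sketch are hypothetical and not taken from this change.

    # Illustrative sketch only; the config values here are assumptions.
    max_num_encoder_input_tokens = 8192  # assumed scheduler budget for encoder tokens
    encoder_cache_size = 8192            # assumed encoder cache capacity, in tokens
    max_tokens_per_mm_item = 576         # assume one image maps to 576 encoder tokens

    max_num_mm_items = min(max_num_encoder_input_tokens,
                           encoder_cache_size) // max_tokens_per_mm_item
    print(max_num_mm_items)  # 8192 // 576 == 14
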
@@ -620,6 +674,7 @@ def profile_run(self) -> None:
         # TODO(woosuk): Consider the memory usage of the sampler.
         torch.cuda.synchronize()
         del hidden_states, logits
+        self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None:
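
Taken together, the last two hunks implement a populate-then-clear pattern: the dummy encoder outputs are held in the encoder cache under a throwaway key so they stay resident during the dummy decoder pass, and the cache is cleared once the memory peak has been observed. A minimal standalone sketch of that pattern, with hypothetical names rather than the module's actual API:

    from typing import Callable, Dict, List

    import torch

    # Hypothetical stand-in for the runner's encoder cache:
    # request id -> {item index -> embedding tensor}.
    encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

    def profile_with_encoder_cache(dummy_encoder_outputs: List[torch.Tensor],
                                   run_dummy_decoder_pass: Callable[[], None]) -> None:
        # Park the dummy embeddings under a temporary request id so they
        # count toward the measured peak memory.
        encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
        # Run the dummy decoder forward pass while the cache is populated.
        run_dummy_decoder_pass()
        # Drop the temporary entries so profiling leaves no residue behind.
        encoder_cache.clear()
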