diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py new file mode 100644 index 00000000000..298f0801900 --- /dev/null +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This is a demo script showing how to use the +PrithviGeospatialMAE model with vLLM +This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa + +Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa + +The requirements for running this script are: +- Installing [terratorch, albumentations, rasterio] in your python environment +- downloading the model weights in a 'model' folder local to the script + (temporary measure until the proper config.json file is uploaded to HF) +- download an input example image (India_900498_S2Hand.tif) and place it in + the same folder with the script (or specify with the --data_file argument) + +Run the example: +python prithvi_geospatial_mae.py + +""" # noqa: E501 +import argparse +import datetime +import os +import re +from typing import List, Union + +import albumentations +import numpy as np +import rasterio +import torch +from einops import rearrange +from terratorch.datamodules import Sen1Floods11NonGeoDataModule + +from vllm import LLM + +NO_DATA = -9999 +NO_DATA_FLOAT = 0.0001 +OFFSET = 0 +PERCENTILE = 99 + +model_config = """{ + "architectures": ["PrithviGeoSpatialMAE"], + "num_classes": 0, + "pretrained_cfg": { + "task_args": { + "task": "SemanticSegmentationTask", + "model_factory": "EncoderDecoderFactory", + "loss": "ce", + "ignore_index": -1, + "lr": 0.001, + "freeze_backbone": false, + "freeze_decoder": false, + "plot_on_val": 10, + "optimizer": "AdamW", + "scheduler": "CosineAnnealingLR" + }, + "model_args": { + "backbone_pretrained": false, + "backbone": "prithvi_eo_v2_300_tl", + "decoder": "UperNetDecoder", + "decoder_channels": 256, + "decoder_scale_modules": true, + "num_classes": 2, + "rescale": true, + "backbone_bands": [ + "BLUE", + "GREEN", + "RED", + "NIR_NARROW", + "SWIR_1", + "SWIR_2" + ], + "head_dropout": 0.1, + "necks": [ + { + "name": "SelectIndices", + "indices": [ + 5, + 11, + 17, + 23 + ] + }, + { + "name": "ReshapeTokensToImage" + } + ] + }, + "optimizer_params" : { + "lr": 5.0e-05, + "betas": [0.9, 0.999], + "eps": [1.0e-08], + "weight_decay": 0.05, + "amsgrad": false, + "maximize": false, + "capturable": false, + "differentiable": false + }, + "scheduler_params" : { + "T_max": 50, + "eta_min": 0, + "last_epoch": -1, + "verbose": "deprecated" + } + }, + + + "torch_dtype": "float32" +} +""" + +# Temporarily creating the "config.json" for the model. 
+# This is going to disappear once the correct config.json is available on HF +with open(os.path.join(os.path.dirname(__file__), "./model/config.json"), + 'w') as config_file: + config_file.write(model_config) + +datamodule_config = { + 'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'], + 'batch_size': + 16, + 'constant_scale': + 0.0001, + 'data_root': + '/dccstor/geofm-finetuning/datasets/sen1floods11', + 'drop_last': + True, + 'no_data_replace': + 0.0, + 'no_label_replace': + -1, + 'num_workers': + 8, + 'test_transform': [ + albumentations.Resize(always_apply=False, + height=448, + interpolation=1, + p=1, + width=448), + albumentations.pytorch.ToTensorV2(transpose_mask=False, + always_apply=True, + p=1.0) + ], +} + + +class PrithviMAE: + + def __init__(self): + print("Initializing PrithviMAE model") + self.model = LLM(model=os.path.join(os.path.dirname(__file__), + "./model"), + skip_tokenizer_init=True, + dtype="float32") + + def run(self, input_data, location_coords): + print("################ Running inference on vLLM ##############") + # merge the inputs into one data structure + mm_data = { + "pixel_values": + torch.empty(0) if input_data is None else input_data, + "location_coords": + torch.empty(0) if location_coords is None else location_coords + } + + prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} + + outputs = self.model.encode(prompt, use_tqdm=False) + print( + "################ Inference done (it took seconds) ##############" + ) + + return outputs[0].outputs.data + + +def generate_datamodule(): + datamodule = Sen1Floods11NonGeoDataModule( + data_root=datamodule_config['data_root'], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform" + ""]) + + return datamodule + + +def process_channel_group(orig_img, channels): + """ + Args: + orig_img: torch.Tensor representing original image (reference) + with shape = (bands, H, W). + channels: list of indices representing RGB channels. + + Returns: + torch.Tensor with shape (num_channels, height, width) for original image + """ + + orig_img = orig_img[channels, ...] + valid_mask = torch.ones_like(orig_img, dtype=torch.bool) + valid_mask[orig_img == NO_DATA_FLOAT] = False + + # Rescale (enhancing contrast) + max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) + min_value = OFFSET + + orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, + 1) + + # No data as zeros + orig_img[~valid_mask] = 0 + + return orig_img + + +def read_geotiff(file_path: str): + """Read all bands from *file_path* and return image + meta info. + + Args: + file_path: path to image file. + + Returns: + np.ndarray with shape (bands, height, width) + meta info dict + """ + + with rasterio.open(file_path) as src: + img = src.read() + meta = src.meta + try: + coords = src.lnglat() + except Exception: + # Cannot read coords + coords = None + + return img, meta, coords + + +def save_geotiff(image, output_path: str, meta: dict): + """Save multi-band image in Geotiff file. + + Args: + image: np.ndarray with shape (bands, height, width) + output_path: path where to save the image + meta: dict with meta info. 
+ """ + + with rasterio.open(output_path, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + return + + +def _convert_np_uint8(float_image: torch.Tensor): + image = float_image.numpy() * 255.0 + image = image.astype(dtype=np.uint8) + + return image + + +def load_example( + file_paths: List[str], + mean: List[float] = None, + std: List[float] = None, + indices: Union[list[int], None] = None, +): + """Build an input example by loading images in *file_paths*. + + Args: + file_paths: list of file paths . + mean: list containing mean values for each band in the images + in *file_paths*. + std: list containing std values for each band in the images + in *file_paths*. + + Returns: + np.array containing created example + list of meta info for each image in *file_paths* + """ + + imgs = [] + metas = [] + temporal_coords = [] + location_coords = [] + + for file in file_paths: + img, meta, coords = read_geotiff(file) + + # Rescaling (don't normalize on nodata) + img = np.moveaxis(img, 0, -1) # channels last for rescaling + if indices is not None: + img = img[..., indices] + if mean is not None and std is not None: + img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std) + + imgs.append(img) + metas.append(meta) + if coords is not None: + location_coords.append(coords) + + try: + match = re.search(r'(\d{7,8}T\d{6})', file) + if match: + year = int(match.group(1)[:4]) + julian_day = match.group(1).split('T')[0][4:] + if len(julian_day) == 3: + julian_day = int(julian_day) + else: + julian_day = datetime.datetime.strptime( + julian_day, '%m%d').timetuple().tm_yday + temporal_coords.append([year, julian_day]) + except Exception as e: + print(f'Could not extract timestamp for {file} ({e})') + + imgs = np.stack(imgs, axis=0) # num_frames, H, W, C + imgs = np.moveaxis(imgs, -1, 0).astype("float32") + imgs = np.expand_dims(imgs, axis=0) # add batch di + + return imgs, temporal_coords, location_coords, metas + + +def run_model(input_data, + temporal_coords, + location_coords, + model, + datamodule, + img_size, + lightning_model=None): + # Reflect pad if not divisible by img_size + original_h, original_w = input_data.shape[-2:] + pad_h = (img_size - (original_h % img_size)) % img_size + pad_w = (img_size - (original_w % img_size)) % img_size + input_data = np.pad(input_data, + ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), + mode="reflect") + + # Build sliding window + batch_size = 1 + batch = torch.tensor(input_data, device="cpu") + windows = (batch.unfold(3, img_size, + img_size).unfold(4, img_size, img_size)) + h1, w1 = windows.shape[3:5] + windows = rearrange(windows, + "b c t h1 w1 h w -> (b h1 w1) c t h w", + h=img_size, + w=img_size) + + # Split into batches if number of windows > batch_size + num_batches = windows.shape[0] // batch_size if windows.shape[ + 0] > batch_size else 1 + windows = torch.tensor_split(windows, num_batches, dim=0) + + if torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + + if temporal_coords: + temporal_coords = torch.tensor(temporal_coords, + device=device).unsqueeze(0) + else: + temporal_coords = None + if location_coords: + location_coords = torch.tensor(location_coords[0], + device=device).unsqueeze(0) + else: + location_coords = None + + # Run model + pred_imgs = [] + for x in windows: + # Apply standardization + x = datamodule.test_transform( + image=x.squeeze().numpy().transpose(1, 2, 0)) + x = datamodule.aug(x)['image'] + + with torch.no_grad(): + x = x.to(device) 
+ pred = model.run(x, location_coords=location_coords) + if lightning_model: + pred_lightning = lightning_model( + x, + temporal_coords=temporal_coords, + location_coords=location_coords) + pred_lightning = pred_lightning.output.detach().cpu() + if not torch.equal(pred, pred_lightning): + print("Inference output is not equal") + y_hat = pred.argmax(dim=1) + + y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), + size=img_size, + mode="nearest") + + pred_imgs.append(y_hat) + + pred_imgs = torch.concat(pred_imgs, dim=0) + + # Build images from patches + pred_imgs = rearrange( + pred_imgs, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h=img_size, + w=img_size, + b=1, + c=1, + h1=h1, + w1=w1, + ) + + # Cut padded area back to original size + pred_imgs = pred_imgs[..., :original_h, :original_w] + + # Squeeze (batch size 1) + pred_imgs = pred_imgs[0] + + return pred_imgs + + +def main( + data_file: str, + output_dir: str, + rgb_outputs: bool, + input_indices: list[int] = None, +): + os.makedirs(output_dir, exist_ok=True) + + # Load model --------------------------------------------------------------- + + model_obj = PrithviMAE() + datamodule = generate_datamodule() + img_size = 256 # Size of Sen1Floods11 + + # Loading data ------------------------------------------------------------- + + input_data, temporal_coords, location_coords, meta_data = load_example( + file_paths=[data_file], + indices=input_indices, + ) + + meta_data = meta_data[0] # only one image + + if input_data.mean() > 1: + input_data = input_data / 10000 # Convert to range 0-1 + + # Running model ------------------------------------------------------------ + + channels = [ + datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"] + ] # BGR -> RGB + + pred = run_model(input_data, temporal_coords, location_coords, model_obj, + datamodule, img_size) + + # Save pred + meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) + pred_file = os.path.join( + output_dir, + f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) + + # Save image + pred + meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0) + + if input_data.mean() < 1: + input_data = input_data * 10000 # Scale to 0-10000 + + rgb_orig = process_channel_group( + orig_img=torch.Tensor(input_data[0, :, 0, ...]), + channels=channels, + ) + + pred[pred == 0.] 
= np.nan + img_pred = rgb_orig * 0.7 + pred * 0.3 + img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] + + img_pred_file = os.path.join( + output_dir, + f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + save_geotiff( + image=_convert_np_uint8(img_pred), + output_path=img_pred_file, + meta=meta_data, + ) + + # Save image rgb + if rgb_outputs: + rgb_file = os.path.join( + output_dir, "original_rgb_" + f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + save_geotiff( + image=_convert_np_uint8(rgb_orig), + output_path=rgb_file, + meta=meta_data, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help= + "0-based indices of the six Prithvi channels to be selected from the " + "input. By default selects [1,2,3,8,11,12] for S2L1C data.", + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. " + "Otherwise, all bands will be saved.", + ) + args = parser.parse_args() + + main(**vars(args)) diff --git a/tests/models/registry.py b/tests/models/registry.py index 20787fe008a..48abfaddc46 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -213,6 +213,10 @@ def check_available_online( "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 + # The model on Huggingface is currently being updated, + # hence I temporarily mark it as not available online + "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + is_available_online=False), } _CROSS_ENCODER_EXAMPLE_MODELS = { diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 9f6e731afd1..d15e0d01917 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -316,9 +316,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) + + # Some input builders such as ModelInputForCPUBuilder do not have the + # "inter_data_list" attribute. + # Let's check inter_data_list exists before we reference it. + if hasattr(self.input_builder, "inter_data_list"): + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 4d8f28cb041..656f2f2b766 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -254,14 +254,18 @@ def _process_multimodal( Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. 
""" - tokenizer_group = self.get_tokenizer_group() - tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + # At the moment on model (PrithviGeoSpatialMAE) requires to be + # initialized without a tokenizer while using also multi-modal + # input. + if not self.tokenizer: + tokenizer = None + else: + tokenizer_group = self.get_tokenizer_group() + tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) - if isinstance(prompt, list): - prompt = tokenizer.decode(prompt) if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -275,9 +279,15 @@ async def _process_multimodal_async( lora_request: Optional[LoRARequest], ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" - tokenizer_group = self.get_tokenizer_group() - tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request - ) + # At the moment on model (PrithviGeoSpatialMAE) requires to be + # initialized without a tokenizer while using also multi-modal + # input. + if not self.tokenizer: + tokenizer = None + else: + tokenizer_group = self.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async( + lora_request) mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py new file mode 100644 index 00000000000..9383cbae11b --- /dev/null +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The vLLM team. +# Copyright 2025 IBM. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only IBM/NASA Prithvi Geospatial model.""" +from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (IsAttentionFree, + SupportsMultiModal) +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import (IntermediateTensors, PoolerOutput, + PoolingSequenceGroupOutput) + + +class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + pass + + +class PrithviGeoSpatialMAEInputBuilder( + BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + return ProcessorInputs( + prompt_text="", + # This model input is fixed and is in the form of a torch Tensor. + # The size of pixel_values might change in the cases where we resize + # the input but never exceeds the dimensions below. 
+            mm_data={
+                "pixel_values": torch.full((1, 6, 512, 512), 1.0),
+                "location_coords": torch.full((1, 2), 1.0)
+            })
+
+
+class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            location_coords=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        pass
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputs:
+        mm_kwargs = {}
+
+        for k, v in mm_data.items():
+            mm_kwargs[k] = v
+
+        return MultiModalInputs(
+            type="multimodal",
+            prompt=prompt,
+            prompt_token_ids=[1],
+            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_placeholders={},
+        )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    PrithviGeoSpatialMAEMultiModalProcessor,
+    info=PrithviGeoSpatialMAEProcessingInfo,
+    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
+    """Prithvi Masked Autoencoder"""
+
+    def _instantiate_model(self, config: dict) -> nn.Module | None:
+
+        # We might be able/need to support different tasks with this same model
+        if config["task_args"]["task"] == "SemanticSegmentationTask":
+            from terratorch.cli_tools import SemanticSegmentationTask
+            task = SemanticSegmentationTask(
+                config["model_args"],
+                config["task_args"]["model_factory"],
+                loss=config["task_args"]["loss"],
+                lr=config["task_args"]["lr"],
+                ignore_index=config["task_args"]["ignore_index"],
+                optimizer=config["task_args"]["optimizer"],
+                optimizer_hparams=config["optimizer_params"],
+                scheduler=config["task_args"]["scheduler"],
+                scheduler_hparams=config["scheduler_params"],
+                plot_on_val=config["task_args"]["plot_on_val"],
+                freeze_decoder=config["task_args"]["freeze_decoder"],
+                freeze_backbone=config["task_args"]["freeze_backbone"])
+
+            return task.model
+        else:
+            return None
+
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        # The actual model is dynamically instantiated using terratorch,
+        # allowing us to change the model architecture at startup time
+        # (e.g., swap the model decoder class).
+        self.model = self._instantiate_model(
+            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
+        if self.model is None:
+            raise ValueError(
+                "Unsupported task. "
+                "Only SemanticSegmentationTask is currently supported "
+                "by PrithviGeoSpatialMAE.")
+
+    def _parse_and_validate_multimodal_data(
+            self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
+
+        pixel_values = kwargs.pop("pixel_values", None)
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError(f"Incorrect type of pixel_values. "
+                             f"Got type: {type(pixel_values)}")
+        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+
+        location_coords = kwargs.pop("location_coords", None)
+        if not isinstance(location_coords, torch.Tensor):
+            raise ValueError(f"Incorrect type of location_coords. "
" + f"Got type: {type(location_coords)}") + location_coords = torch.unbind(location_coords, dim=0)[0] + if location_coords.shape == torch.Size([0]): + location_coords = None + + return pixel_values, location_coords + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + + pixel_values, location_coords = ( + self._parse_and_validate_multimodal_data(**kwargs)) + model_output = self.model(pixel_values, + location_coords=location_coords) + + return model_output.output + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_list = [] + model_buffers = dict(self.named_buffers()) + loaded_buffers = [] + for key, value in weights: + if key == "state_dict": + weights_to_parse = value + for name, weight in weights_to_parse.items(): + if "pos_embed" in name: + continue + + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + + # this model requires a couple of buffers to be loaded + # that are not loadable with the AutoWeightsLoader + if name in model_buffers: + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + buffer = model_buffers[name] + weight_loader = getattr(buffer, "weight_loader", + default_weight_loader) + weight_loader(buffer, weight) + loaded_buffers.append(name) + else: + params_list.append((name, weight)) + break + + # Load the remaining model parameters + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(params_list) + + return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069efc..13b8e90020c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -136,6 +136,10 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 # [Auto-converted (see adapters.py)] "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), + # Technically PrithviGeoSpatialMAE is a model that works on images, both in + # input and output. I am adding it here because it piggy-backs on embedding + # models for the time being. + "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), } _CROSS_ENCODER_MODELS = { diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index f43085b0e96..4cbe5db4453 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -74,7 +74,16 @@ def execute_model( prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata virtual_engine = model_input.virtual_engine - if prefill_meta is None and decode_meta.use_cuda_graph: + # Pooling models are (ab-)used also to integrate non text models that + # are not autoregressive (PrithviGeosaptialMAE). + # These model might not use attention and do not really have a prefill + # and decode phase. The model input is processed in one shot and both + # decode_metadata and prefill_metadata would be None for such models. + # See the PlaceholderAttentionMetadata class. 
+ # TODO: Figure out if cuda_graph is of any use for these models and + # explore how to leverage it. + if (prefill_meta is None and decode_meta is not None + and decode_meta.use_cuda_graph): assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[virtual_engine][