
Commit c503c8c

Author: reidliu41
Commit message: Merge remote-tracking branch 'upstream/main' into fix-broken-links
Parents: 10c72be, 44073a7

File tree

21 files changed: +643 additions, -111 deletions


docs/models/supported_models.md

Lines changed: 4 additions & 1 deletion
@@ -527,7 +527,7 @@ Specified using `--task generate`.
 | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
 | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | |
 | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | |
-| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
+| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
 | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | |
 | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | |
 | `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | |

@@ -577,6 +577,9 @@ Specified using `--task generate`.
 
     This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
 
+!!! note
+    Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
+
 !!! note
     `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
 
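The new "(V<sup>E+</sup>)" entry means InternVL checkpoints with a Qwen2.5 backbone now accept video inputs. For orientation, here is a minimal offline-inference sketch of what that looks like through vLLM's generic multi-modal interface; the frame count, frame size, and the bare `<video>` prompt are illustrative assumptions (the shipped example script, shown in the next file, applies the tokenizer's chat template instead):

    # Sketch only: video inference with an InternVL3 checkpoint via vLLM's
    # standard multi-modal API. Frame count/size and the raw prompt are
    # illustrative; in practice the chat template should be applied.
    import numpy as np
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="OpenGVLab/InternVL3-2B",
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"video": 1},
    )

    # Stand-in video: 8 RGB frames of 448x448 (use real decoded frames in practice).
    video = np.zeros((8, 448, 448, 3), dtype=np.uint8)

    outputs = llm.generate(
        {
            "prompt": "<video>\nDescribe what happens in this video.",
            "multi_modal_data": {"video": video},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)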

examples/offline_inference/vision_language.py

Lines changed: 11 additions & 4 deletions
@@ -330,22 +330,26 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
 
 # InternVL
 def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
-    assert modality == "image"
 
-    model_name = "OpenGVLab/InternVL2-2B"
+    model_name = "OpenGVLab/InternVL3-2B"
 
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
-        max_model_len=4096,
+        max_model_len=8192,
         limit_mm_per_prompt={modality: 1},
     )
 
+    if modality == "image":
+        placeholder = "<image>"
+    elif modality == "video":
+        placeholder = "<video>"
+
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
     messages = [[{
        'role': 'user',
-        'content': f"<image>\n{question}"
+        'content': f"{placeholder}\n{question}"
    }] for question in questions]
    prompts = tokenizer.apply_chat_template(messages,
                                            tokenize=False,

@@ -357,6 +361,9 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+    stop_token_ids = [
+        token_id for token_id in stop_token_ids if token_id is not None
+    ]
 
    return ModelRequestData(
        engine_args=engine_args,
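One thing worth noting about the new placeholder selection: the `if`/`elif` chain has no `else`, so an unexpected modality would leave `placeholder` unbound and fail a few lines later with an `UnboundLocalError`. A slightly more defensive variant (a sketch, not part of this commit) would be:

    # Illustrative only; the committed code relies on the caller passing a
    # supported modality ("image" or "video").
    if modality == "image":
        placeholder = "<image>"
    elif modality == "video":
        placeholder = "<video>"
    else:
        raise ValueError(f"Unsupported modality for the InternVL example: {modality}")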

tests/kernels/quantization/test_block_fp8.py

Lines changed: 7 additions & 7 deletions
@@ -36,16 +36,16 @@
 
 # Test configurations
 DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
-NUM_TOKENS = [7, 83, 2048]
+NUM_TOKENS = [7, 2050]
 D = [512, 4096, 5120, 13824]
-GROUP_SIZE = [64, 128, 256, 512]
-M = [1, 7, 8, 83, 84, 512, 2048, 4096]
-N = [128, 512, 1024, 4096, 7168, 7748, 13824]
-K = [256, 4096, 5120, 3884, 13824, 16384]
+GROUP_SIZE = [64, 128, 512]
+M = [1, 7, 8, 83, 84, 4096]
+N = [128, 512, 7168, 7748, 13824]
+K = [256, 3884, 4096, 13824, 16384]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
-M_moe = [1, 2, 7, 83, 128, 512, 2048]
-M_moe_dg = [128, 192, 512, 1335, 2048]
+M_moe = [1, 2, 7, 83, 128, 2048]
+M_moe_dg = [128, 192, 1335, 2048]
 N_moe = [128, 256, 1024, 4608] # [13824]
 K_moe = [256, 512, 7168] # [13824]
 BLOCK_SIZE = [[128, 128]]

tests/kernels/quantization/test_gguf.py

Lines changed: 2 additions & 2 deletions
@@ -35,11 +35,11 @@ def get_gguf_MoE_tensors(
     return GGUFReader(sample_file).tensors
 
 
-DTYPES = [torch.half, torch.bfloat16, torch.float32]
+DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
 # Hidden_size for testing, must match the sample file in HF repo,
 # we have `hidden_size = 256, 1024` for test in HF repo currently.
 HIDDEN_SIZES = [256, 1024]
-NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing
+NUM_TOKENS = [7, 2050] # Arbitrary values for testing
 SEEDS = [0]
 QUANT_TYPES = [
     # i-matrix

tests/kernels/quantization/test_triton_scaled_mm.py

Lines changed: 8 additions & 16 deletions
@@ -13,8 +13,13 @@
 
 device = "cuda"
 
+triton_scaled_mm_module = importlib.import_module(
+    "vllm.model_executor.layers.quantization.compressed_tensors."
+    "triton_scaled_mm")
+triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
 
-def scaled_mm_torch(a: torch.Tensor,
+
+def torch_scaled_mm(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,

@@ -101,21 +106,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
     if use_bias:
         bias = torch.rand((N, ), device=device, dtype=out_dtype)
 
-    triton_scaled_mm_module = importlib.import_module(
-        "vllm.model_executor.layers.quantization.compressed_tensors."
-        "triton_scaled_mm")
-    triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
-
     c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    a_cpu = a.cpu()
-    b_cpu = b.cpu()
-    scale_a_cpu = scale_a.cpu()
-    scale_b_cpu = scale_b.cpu()
-    bias_cpu = None if bias is None else bias.cpu()
-
-    c_actual = scaled_mm_torch(a_cpu, b_cpu, scale_a_cpu, scale_b_cpu,
-                               out_dtype, bias_cpu)
+    c_actual = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
-    c_check_cpu = c_check.cpu()
-    torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(c_check, c_actual, rtol=1e-1, atol=1e-1)
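The body of the torch reference `torch_scaled_mm` (renamed from `scaled_mm_torch`) is not shown in this hunk. For context, a typical torch reference for a per-row/per-column scaled matmul looks like the sketch below; this is an assumption about the helper's behavior, not the file's exact code. Because the test now keeps both results on the same device, it can drop all the `.cpu()` copies and compare directly.

    from typing import Optional

    import torch

    def torch_scaled_mm_sketch(a: torch.Tensor,
                               b: torch.Tensor,
                               scale_a: torch.Tensor,
                               scale_b: torch.Tensor,
                               out_dtype: torch.dtype,
                               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantize in float32: per-row scales of `a` broadcast over columns,
        # per-column scales of `b` broadcast over rows (scalar scales also work).
        out = torch.mm(a.to(torch.float32), b.to(torch.float32))
        out = out * scale_a.to(torch.float32) * scale_b.to(torch.float32).t()
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)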

tests/models/multimodal/generation/test_common.py

Lines changed: 11 additions & 0 deletions
@@ -349,6 +349,17 @@
         use_tokenizer_eos=True,
         patch_hf_runner=model_utils.internvl_patch_hf_runner,
     ),
+    "intern_vl-video": VLMTestInfo(
+        models=[
+            "OpenGVLab/InternVL3-1B",
+        ],
+        test_type=VLMTestType.VIDEO,
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
+        video_idx_to_prompt=lambda idx: "<video>",
+        max_model_len=8192,
+        use_tokenizer_eos=True,
+        patch_hf_runner=model_utils.internvl_patch_hf_runner,
+    ),
     "kimi_vl": VLMTestInfo(
         models=["moonshotai/Kimi-VL-A3B-Instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),

tests/models/multimodal/generation/vlm_utils/model_utils.py

Lines changed: 66 additions & 20 deletions
@@ -7,6 +7,8 @@
 from pathlib import PosixPath
 from typing import Optional, Union
 
+import numpy as np
+import numpy.typing as npt
 import regex as re
 import torch
 from PIL.Image import Image

@@ -495,30 +497,74 @@ def __init__(self, hf_runner: HfRunner):
         self.max_num = self.config.max_dynamic_patch
         self.image_size = self.vision_config.image_size
 
-    def __call__(self, text: str, images: Union[Image, list[Image]],
-                 **kwargs):
+    def __call__(
+        self,
+        text: str,
+        images: Union[Image, list[Image]] = None,
+        videos: Union[npt.NDArray, list[npt.NDArray]] = None,
+        **kwargs,
+    ):
         from vllm.model_executor.models.internvl import (
             IMG_CONTEXT, IMG_END, IMG_START,
-            image_to_pixel_values_internvl)
+            image_to_pixel_values_internvl, video_to_pixel_values_internvl)
         images = [images] if isinstance(images, Image) else images
-        pixel_values = [
-            image_to_pixel_values_internvl(
-                image,
-                input_size=self.image_size,
-                min_num=self.min_num,
-                max_num=self.max_num,
-                use_thumbnail=self.use_thumbnail,
-            ) for image in images
-        ]
-        num_patches_list = [
-            pixel_value.shape[0] for pixel_value in pixel_values
-        ]
+        videos = [videos] if isinstance(videos, np.ndarray) else videos
+        if images is not None:
+            pixel_values_images = [
+                image_to_pixel_values_internvl(
+                    image,
+                    input_size=self.image_size,
+                    min_num=self.min_num,
+                    max_num=self.max_num,
+                    use_thumbnail=self.use_thumbnail,
+                ) for image in images
+            ]
+            num_patches_images = [
+                pixel_value.shape[0] for pixel_value in pixel_values_images
+            ]
+        else:
+            pixel_values_images, num_patches_images = [], []
+
+        if videos is not None:
+            pixel_values_videos = [
+                video_to_pixel_values_internvl(
+                    video,
+                    input_size=self.image_size,
+                    min_num=1,
+                    max_num=1,
+                    use_thumbnail=False,
+                ) for video in videos
+            ]
+            num_patches_videos = [
+                pixel_value.shape[0] for pixel_value in pixel_values_videos
+            ]
+        else:
+            pixel_values_videos, num_patches_videos = [], []
+
+        pixel_values = []
+        while ("<image>" in text) or ("<video>" in text):
+            image_index = text.find("<image>")
+            video_index = text.find("<video>")
+            if image_index == -1 or (video_index > -1
+                                     and video_index < image_index):
+                num_patches = num_patches_videos.pop(0)
+                pixel_values.append(pixel_values_videos.pop(0))
+                context_tokens = IMG_START + \
+                    IMG_CONTEXT * self.num_image_token + IMG_END
+                video_tokens = ''.join([
+                    f'Frame{i+1}: {context_tokens}'
+                    for i in range(num_patches)
+                ])
+                text = text.replace('<video>', video_tokens, 1)
+            else:
+                num_patches = num_patches_images.pop(0)
+                pixel_values.append(pixel_values_images.pop(0))
+                context_tokens = IMG_CONTEXT * self.num_image_token \
+                    * num_patches
+                image_tokens = IMG_START + context_tokens + IMG_END
+                text = text.replace('<image>', image_tokens, 1)
         pixel_values = torch.cat(pixel_values, dim=0)
-        for num_patches in num_patches_list:
-            context_tokens = IMG_CONTEXT * self.num_image_token \
-                * num_patches
-            image_tokens = IMG_START + context_tokens + IMG_END
-            text = text.replace('<image>', image_tokens, 1)
+
         prompt = self.tokenizer(text, return_tensors="pt")
         prompt.update({"pixel_values": pixel_values})
        return prompt
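The heart of the new `__call__` is the loop that walks the prompt and consumes `<image>` and `<video>` placeholders in the order they appear, so mixed prompts pair each placeholder with the right pixel values. A stripped-down sketch of just that ordering logic (illustrative only, not part of the file):

    def placeholder_order(text: str) -> list[str]:
        """Return "image"/"video" in the order their placeholders appear in `text`."""
        order = []
        while ("<image>" in text) or ("<video>" in text):
            image_index = text.find("<image>")
            video_index = text.find("<video>")
            if image_index == -1 or (video_index > -1 and video_index < image_index):
                order.append("video")
                text = text.replace("<video>", "", 1)
            else:
                order.append("image")
                text = text.replace("<image>", "", 1)
        return order

    assert placeholder_order("<video> then <image> and <image>") == ["video", "image", "image"]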

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
     "ibm-granite/granite-speech-3.3-8b",
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",
+    "OpenGVLab/InternVL3-1B",
     "HuggingFaceM4/Idefics3-8B-Llama3",
     "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
     "moonshotai/Kimi-VL-A3B-Instruct",

tests/models/registry.py

Lines changed: 2 additions & 1 deletion
@@ -334,7 +334,8 @@ def check_available_online(
                                          max_transformers_version="4.48", # noqa: E501
                                          transformers_version_reason="HF model is not compatible."), # noqa: E501
     "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
-                                         extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
+                                         extras={"2B": "OpenGVLab/InternVL2-2B",
+                                                 "3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
                                          trust_remote_code=True),
     "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
                                                         {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501

tests/test_regression.py

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch):
     # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_MODELSCOPE", "True")
+        # Don't use HF_TOKEN for ModelScope repos, otherwise it will fail
+        # with 400 Client Error: Bad Request.
+        m.setenv("HF_TOKEN", "")
         llm = LLM(model="qwen/Qwen1.5-0.5B-Chat")
 
         prompts = [

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 0 deletions
@@ -556,6 +556,8 @@ def _placeholder_str(self, modality: ModalityStr,
                 return "(<audio>./</audio>)"
             raise TypeError(f"Unknown model type: {model_type}")
         elif modality == "video":
+            if model_type == "internvl_chat":
+                return "<video>"
             if model_type in ("qwen2_vl", "qwen2_5_vl"):
                 return "<|vision_start|><|video_pad|><|vision_end|>"
             if model_type == "qwen2_5_omni":

vllm/executor/ray_utils.py

Lines changed: 2 additions & 4 deletions
@@ -87,9 +87,8 @@ def execute_model_spmd(
        # TODO(swang): This is needed right now because Ray Compiled Graph
        # executes on a background thread, so we need to reset torch's
        # current device.
-        import torch
        if not self.compiled_dag_cuda_device_set:
-            torch.cuda.set_device(self.worker.device)
+            current_platform.set_device(self.worker.device)
            self.compiled_dag_cuda_device_set = True
 
        output = self.worker._execute_model_spmd(execute_model_req,

@@ -113,8 +112,7 @@ def setup_device_if_necessary(self):
            # Not needed
            pass
        else:
-            import torch
-            torch.cuda.set_device(self.worker.device)
+            current_platform.set_device(self.worker.device)
 
            self.compiled_dag_cuda_device_set = True
 
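Both hunks swap the hard-coded CUDA call for the platform abstraction, so the same code path can work on non-CUDA backends. A minimal sketch of the pattern is below; the import path is the one vLLM normally uses for `current_platform`, but since the import is not shown in this hunk, treat it as an assumption:

    import torch

    from vllm.platforms import current_platform  # assumed import path; not shown in the hunk

    def bind_worker_device(device: torch.device) -> None:
        # Device-agnostic replacement for torch.cuda.set_device(device): the active
        # platform (CUDA, ROCm, TPU, ...) decides what "setting a device" means.
        current_platform.set_device(device)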

vllm/model_executor/model_loader/default_loader.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
+from vllm import envs
 from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
-from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.utils import (

@@ -64,7 +64,7 @@ def _maybe_download_from_modelscope(
 
        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
-        if VLLM_USE_MODELSCOPE:
+        if envs.VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
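Switching from `from vllm.envs import VLLM_USE_MODELSCOPE` to `envs.VLLM_USE_MODELSCOPE` means the flag is read when `_maybe_download_from_modelscope` actually runs, rather than frozen to whatever the environment held at import time, which is presumably the point of this change. A small generic illustration of the difference (toy code, not vLLM's `envs` implementation):

    import os

    class LazyEnv:
        """Toy stand-in for a lazily resolved settings module."""

        def __getattr__(self, name):
            # Looked up on every attribute access, never cached.
            return os.environ.get(name)

    envs = LazyEnv()

    FROZEN = os.environ.get("MY_FLAG")  # value captured once, at "import" time
    os.environ["MY_FLAG"] = "1"

    print(FROZEN)        # still the old value (None if MY_FLAG was unset before)
    print(envs.MY_FLAG)  # "1" -- read at access time, like envs.VLLM_USE_MODELSCOPE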

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 4 additions & 3 deletions
@@ -319,6 +319,7 @@ def download_safetensors_index_file_from_hf(
 
     Args:
         model_name_or_path (str): The model name or path.
+        index_file (str): The safetensors index file name
         cache_dir (Optional[str]): The cache directory to store the model
             weights. If None, will use HF defaults.
         revision (Optional[str]): The revision of the model.

@@ -337,10 +338,10 @@
        )
    # If file not found on remote or locally, we should not fail since
    # only some models will have index_file.
-    except huggingface_hub.utils.EntryNotFoundError:
-        logger.info("No %s found in remote.", index_file)
    except huggingface_hub.utils.LocalEntryNotFoundError:
        logger.info("No %s found in local cache.", index_file)
+    except huggingface_hub.utils.EntryNotFoundError:
+        logger.info("No %s found in remote.", index_file)
 
 
 # For models like Mistral-7B-v0.3, there are both sharded

@@ -634,7 +635,7 @@ def row_parallel_weight_loader(param: torch.Tensor,
     return default_weight_loader(param, loaded_weight)
 
 
-LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
 
 
 def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
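The corrected `LoaderFunction` alias reflects that weight loaders mutate `param` in place and return nothing. A conforming loader, sketched for illustration (the function name is made up here):

    from typing import Callable

    import torch

    # Matches the corrected alias: loaders mutate `param` in place and return None.
    LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]

    def copy_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # Shape-check, then copy the loaded tensor into the parameter in place.
        assert param.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)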

vllm/model_executor/models/h2ovl.py

Lines changed: 6 additions & 5 deletions
@@ -25,9 +25,10 @@
 
 from .intern_vit import InternVisionModel
 from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
+                       BaseInternVLDummyInputsBuilder,
+                       BaseInternVLMultiModalProcessor,
                        BaseInternVLProcessingInfo, BaseInternVLProcessor,
-                       InternVLChatModel, InternVLDummyInputsBuilder,
-                       InternVLMultiModalProcessor, build_transform,
+                       InternVLChatModel, build_transform,
                        find_closest_aspect_ratio, get_internvl_target_ratios)
 
 

@@ -430,8 +431,8 @@ def get_num_image_tokens(
        )
 
 
-class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
-                               ):
+class H2OVLMultiModalProcessor(
+        BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
 
    def _get_prompt_updates(
        self,

@@ -514,7 +515,7 @@ def _cached_apply_hf_processor(
 @MULTIMODAL_REGISTRY.register_processor(
     H2OVLMultiModalProcessor,
     info=H2OVLProcessingInfo,
-    dummy_inputs=InternVLDummyInputsBuilder)
+    dummy_inputs=BaseInternVLDummyInputsBuilder)
 class H2OVLChatModel(InternVLChatModel):
 
    def _init_vision_model(
