
Commit fcb3f4b

[Model] Support Skywork-R1V
Signed-off-by: jiacai.liu <[email protected]>
1 parent cec8c7d commit fcb3f4b

12 files changed: +1194 −7 lines

docs/source/models/supported_models.md

Lines changed: 7 additions & 0 deletions
@@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
   * ✅︎
   * ✅︎
+- * `SkyworkR1VChatModel`
+  * Skywork-R1V-38B
+  * T + I
+  * `Skywork/Skywork-R1V-38B`
+  *
+  * ✅︎
+  * ✅︎
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
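
As context for the new table row, a minimal offline-inference sketch for the newly listed model could look like the following. This is an illustration only, assuming a local image file and illustrative sampling settings; it is not part of this commit.

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="Skywork/Skywork-R1V-38B",
          trust_remote_code=True,
          max_model_len=4096)

image = Image.open("example.jpg")  # hypothetical local file
# "<image>" is the placeholder string this commit registers for skywork_chat.
outputs = llm.generate(
    {"prompt": "<image>\nWhat is shown in this image?",
     "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)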

examples/offline_inference/vision_language.py

Lines changed: 36 additions & 0 deletions
@@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
     )


+# SkyworkR1V
+def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    model_name = "Skywork/Skywork-R1V-38B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    messages = [[{
+        'role': 'user',
+        'content': f"<image>\n{question}"
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            tokenize=False,
+                                            add_generation_prompt=True)
+
+    # Stop tokens for SkyworkR1V
+    # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
+    stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+        stop_token_ids=stop_token_ids,
+    )
+
+
 model_example_map = {
     "aria": run_aria,
     "blip-2": run_blip2,
@@ -834,6 +869,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
     "qwen2_5_vl": run_qwen2_5_vl,
+    "skywork_chat": run_skyworkr1v,
 }
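
Once "skywork_chat" is registered in model_example_map, the new entry can presumably be exercised through the example script's existing CLI, along the lines of:

python examples/offline_inference/vision_language.py --model-type skywork_chat --modality image

(The flag names are assumed from the script's argument parser, which is not shown in this diff.)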

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 14 additions & 0 deletions
@@ -474,6 +474,20 @@
         vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
         prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
     ),
+    "skywork_r1v": VLMTestInfo(
+        models=["Skywork/Skywork-R1V-38B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n",  # noqa: E501
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<image>\nWhat's the content in the center of the image?",  # noqa: E501
+            "cherry_blossom": "<image>\nWhat is the season?",
+        }),
+        multi_image_prompt="<image>\n<image>\nDescribe the two images in short.",  # noqa: E501
+        max_model_len=4096,
+        use_tokenizer_eos=True,
+        patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=80)],
+    ),
     ### Tensor parallel / multi-gpu broadcast tests
     "chameleon-broadcast": VLMTestInfo(
         models=["facebook/chameleon-7b"],
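
For reference, the prompt_formatter above yields prompts of this shape (a worked example, not output captured from the test itself):

# Worked example of the prompt_formatter lambda above.
fmt = lambda p: f"<|begin▁of▁sentence|><|User|>\n{p}<|Assistant|><think>\n"
print(fmt("<image>\nWhat is the season?"))
# <|begin▁of▁sentence|><|User|>
# <image>
# What is the season?<|Assistant|><think>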

tests/models/decoder_only/vision_language/vlm_utils/model_utils.py

Lines changed: 57 additions & 0 deletions
@@ -376,6 +376,63 @@ def __call__(self, text: str, images: Union[Image, list[Image]],
     return hf_model


+def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
+
+    class SkyworkR1VProcessor:
+        """A simple processor for SkyworkR1V."""
+
+        def __init__(self, hf_runner: HfRunner):
+            self.num_image_token = hf_runner.model.num_image_token
+            self.tokenizer = hf_runner.tokenizer
+
+            self.config = AutoConfig.from_pretrained(hf_runner.model_name,
+                                                     trust_remote_code=True)
+            self.vision_config = self.config.vision_config
+            self.use_thumbnail = self.config.use_thumbnail
+            self.min_num = self.config.min_dynamic_patch
+            self.max_num = self.config.max_dynamic_patch
+            self.image_size = self.vision_config.image_size
+
+        def __call__(self, text: str, images: Union[Image, list[Image]],
+                     **kwargs):
+            from vllm.model_executor.models.skyworkr1v import (
+                IMG_CONTEXT, IMG_END, IMG_START,
+                image_to_pixel_values_skyworkr1v)
+            images = [images] if isinstance(images, Image) else images
+            pixel_values = [
+                image_to_pixel_values_skyworkr1v(
+                    image,
+                    input_size=self.image_size,
+                    min_num=self.min_num,
+                    max_num=self.max_num,
+                    use_thumbnail=self.use_thumbnail,
+                ) for image in images
+            ]
+            num_patches_list = [
+                pixel_value.shape[0] for pixel_value in pixel_values
+            ]
+            pixel_values = torch.cat(pixel_values, dim=0)
+            for num_patches in num_patches_list:
+                context_tokens = IMG_CONTEXT * self.num_image_token \
+                    * num_patches
+                image_tokens = IMG_START + context_tokens + IMG_END
+                text = text.replace('<image>', image_tokens, 1)
+            prompt = self.tokenizer(text, return_tensors="pt")
+            prompt.update({"pixel_values": pixel_values})
+            return prompt
+
+    img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
+        "<IMG_CONTEXT>")
+    hf_model.model.img_context_token_id = img_context_token_id
+    hf_model.processor = SkyworkR1VProcessor(hf_model)
+    hf_model.model.get_output_embeddings = lambda: \
+        hf_model.model.language_model.get_output_embeddings()
+    hf_model.model.generate = types.MethodType(_internvl_generate,
+                                               hf_model.model)
+    return hf_model
+
+
 def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for InternVL."""

tests/models/multimodal/processing/test_common.py

Lines changed: 5 additions & 4 deletions
@@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
-    "mistralai/Pixtral-12B-2409",
-    "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-Llama3-V-2_5",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
     "allenai/Molmo-7B-D-0924",
     "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
+    "google/paligemma-3b-mix-224",
+    "google/paligemma2-3b-ft-docci-448",
+    "mistralai/Pixtral-12B-2409",
+    "mistral-community/pixtral-12b",
     "Qwen/Qwen-VL-Chat",
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
+    "Skywork/Skywork-R1V-38B",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     "openai/whisper-large-v3",
-    "google/paligemma-3b-mix-224",
-    "google/paligemma2-3b-ft-docci-448",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -294,6 +294,7 @@ def check_available_online(
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
+    "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
                                      trust_remote_code=True),
     # [Encoder-decoder]

vllm/entrypoints/chat_utils.py

Lines changed: 1 addition & 1 deletion
@@ -496,7 +496,7 @@ def _placeholder_str(self, modality: ModalityStr,
             return self._cached_token_str(self._tokenizer,
                                           hf_config.image_token_index)
         if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
-                          "NVLM_D", "h2ovl_chat"):
+                          "skywork_chat", "NVLM_D", "h2ovl_chat"):
             return "<image>"
         if model_type == "mllama":
             return "<|image|>"

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -190,6 +190,7 @@
     # [Encoder-decoder]
     "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
     "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
+    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
     "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
 }
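
A quick way to confirm the new architecture is registered (a sketch; ModelRegistry.get_supported_archs() is vLLM's public listing of registered architectures):

from vllm import ModelRegistry

assert "SkyworkR1VChatModel" in ModelRegistry.get_supported_archs()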
