
[Bugfix][Model] fix mllama multi-image #14883


Merged 4 commits into vllm-project:main on Apr 1, 2025

Conversation

@yma11 (Contributor) commented Mar 16, 2025

FIX #14551


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

img = pixel_values_unpacked[b][i]
out_images[b, i, :img.shape[0]] = img
return out_images

def _parse_and_validate_image_input(self, **kwargs: object):
# tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be:

Collaborator:

@DarkLight1337 @ywang96 I think it is a common problem that images have different sizes and we need to pad a list of tensors with different shapes into one tensor (see these code comments for details). Is there any utility function for this in vLLM?
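
For context, this kind of padding over a ragged leading dimension can be done with a standard PyTorch utility; the snippet below is only an illustration of the pattern being discussed, not necessarily what vLLM uses internally:

import torch
from torch.nn.utils.rnn import pad_sequence

# Per-image tensors whose first dim (e.g. number of tiles) differs
# but whose trailing dims match.
imgs = [torch.randn(2, 4), torch.randn(5, 4)]

# Stacks along a new batch dim, zero-padding the first dim to the max length.
batched = pad_sequence(imgs, batch_first=True)
print(batched.shape)  # torch.Size([2, 5, 4])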

Member:

This is usually done as part of the HF processor. If the HF processor doesn't do this, you can apply it manually like in Pixtral-HF.

Collaborator:

Thanks! Given that, I think it is OK to implement the unpacking in mllama.py.

@heheda12345 (Collaborator) left a comment:

Thanks for the fix! I left some comments. Also, do you know why test_models_single_leading_image can pass without unpacking the data, given that it contains a test with different numbers of tiles?

        # Multi-size, batched, including text only
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],)

elif isinstance(image_data[0], torch.Tensor):
bsz = len(image_data)
# List[torch.Tensor]
if image_data[0].ndim == 1:

Collaborator:

Some questions:

  1. Can you merge the code paths for image_data[0].ndim == 1 and image_data[0].ndim == 2? I think it can be achieved by something like the snippet below (a quick usage check follows after this list):
def unpack_data(tensor_list, padding_value=0):
    batch_size = len(tensor_list)
    max_length = max(t.size(0) for t in tensor_list)
    trailing_dims = tensor_list[0].shape[1:]
    padded_tensor = torch.full((batch_size, max_length, *trailing_dims),
                               padding_value,
                               dtype=tensor_list[0].dtype,
                               device=tensor_list[0].device)
    for i, t in enumerate(tensor_list):
        padded_tensor[i, :t.size(0)] = t
    return padded_tensor
  2. image_data has type Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], but it seems that only List[torch.Tensor] and torch.Tensor are handled.
  3. Can you add a return type annotation to the function signature?
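
A quick sanity check of how the suggested unpack_data snippet above behaves on made-up shapes (illustrative only; it assumes the definition above is in scope):

import torch

# Two inputs whose first dimension differs but whose trailing dims match,
# as in the multi-image case discussed here.
a = torch.randn(2, 4)
b = torch.randn(5, 4)

padded = unpack_data([a, b])  # unpack_data as defined in the snippet above
print(padded.shape)  # torch.Size([2, 5, 4]); `a` is zero-padded along dim 1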

@yma11 (Contributor, Author) commented Mar 21, 2025

Thanks for the fix! I left some comments. Also, do you know why test_models_single_leading_image can pass without unpacking the data, given that it contains a test with different numbers of tiles?

        # Multi-size, batched, including text only
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],)

It seems the data is already batched into a single tensor, like:

aspect_ratio_ids: tensor([[6, 6],
        [1, 6]], device='cuda:0')
aspect_ratio_mask: tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[1, 0, 0, 0],
         [1, 1, 1, 1]]], device='cuda:0')

aspect_ratio_ids: tensor([[6, 7, 5]], device='cuda:0')
aspect_ratio_mask: tensor([[[1, 1, 1, 1],
         [1, 1, 1, 0],
         [1, 1, 0, 0]]], device='cuda:0')

So maybe we can assume there will be no List[List[torch.Tensor]] type for these data? At least, I never triggered this path, so I have no idea what kind of padding would be correct.

@heheda12345 (Collaborator):

https://github.com/huggingface/transformers/blob/c9d1e5238a752813ba91a8751a638a09b5efbb73/src/transformers/models/mllama/image_processing_mllama.py#L767-L770
I just noticed that num_tiles is always padded to 4 regardless of the real max_num_tiles of the images inside the request (see the code above). Therefore, the code path for List[List[torch.Tensor]] should never be triggered.

@yma11 Can you help with the following?

  1. Simplify the code to only handle torch.Tensor & List[torch.Tensor], and add an assert that the input is not List[List[torch.Tensor]].
  2. Merge unpack_data and unpack_pixel_values into one function, which should be possible now because we no longer need to pad over the num_tiles dimension (a sketch follows after this list).
  3. For the List[torch.Tensor] case, verify that the trailing_dims are the same for all tensors.
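
For illustration, a minimal sketch of what such a merged helper could look like (names and exact checks are assumptions; the actual implementation in mllama.py may differ):

from typing import List, Union

import torch


def unpack_data(image_data: Union[torch.Tensor, List[torch.Tensor]],
                padding_value: int = 0) -> torch.Tensor:
    """Pad tensors that differ only in their first dim into one batched tensor."""
    if isinstance(image_data, torch.Tensor):
        # Already batched into a single tensor by MultiModalKwargs.batch.
        return image_data
    # Only torch.Tensor and List[torch.Tensor] are expected here;
    # List[List[torch.Tensor]] should never be produced because num_tiles
    # is always padded to 4 by the HF image processor.
    assert isinstance(image_data[0], torch.Tensor), \
        "unexpected List[List[torch.Tensor]] input"
    trailing_dims = image_data[0].shape[1:]
    # All tensors must agree on every dimension except the first.
    assert all(t.shape[1:] == trailing_dims for t in image_data)
    max_len = max(t.size(0) for t in image_data)
    padded = torch.full((len(image_data), max_len, *trailing_dims),
                        padding_value,
                        dtype=image_data[0].dtype,
                        device=image_data[0].device)
    for i, t in enumerate(image_data):
        padded[i, :t.size(0)] = t
    return padded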

@yma11 (Contributor, Author) commented Mar 24, 2025

https://github.com/huggingface/transformers/blob/c9d1e5238a752813ba91a8751a638a09b5efbb73/src/transformers/models/mllama/image_processing_mllama.py#L767-L770 I just noticed that num_tiles is always padded to 4 regardless of the real max_num_tiles of the images inside the request (see the code above). Therefore, the code path for List[List[torch.Tensor]] should never be triggered.

@yma11 Can you help with the following?

  1. Simplify the code to only handle torch.Tensor & List[torch.Tensor], and add an assert that the input is not List[List[torch.Tensor]].
  2. Merge unpack_data and unpack_pixel_values into one function, which should be possible now because we no longer need to pad over the num_tiles dimension.
  3. For the List[torch.Tensor] case, verify that the trailing_dims are the same for all tensors.

Updated.

@heheda12345 (Collaborator):

The code is quite clean now! Can you fix the unit tests in tests/models/encoder_decoder/vision_language/test_mllama.py?

@yma11 (Contributor, Author) commented Mar 25, 2025

The code is quite clean now! Can you fix the unit tests in tests/models/encoder_decoder/vision_language/test_mllama.py?

Do you mean a unit test failure? I didn't observe one; do you have a link? Or do you mean reverting the changes in the UT?
[screenshot]

@heheda12345 (Collaborator):

The unit tests in this file fail on my local environment.
[screenshot]

@yma11 (Contributor, Author) commented Mar 26, 2025

The unit tests in this file fail on my local environment. [screenshot]

Oh I see, I can reproduce this issue in my local env but really have no clue about it. With export CUDA_LAUNCH_BLOCKING=1, I got a log like the following:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <vllm.worker.enc_dec_model_runner.EncoderDecoderModelRunner object at 0x7f19f88f04c0>
model_input = <[RuntimeError('CUDA error: an illegal memory access was encountered\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] EncoderDecoderModelInput object at 0x7f19cd9ef220>
kv_caches = [<[RuntimeError('CUDA error: an illegal memory access was encountered\nCompile with `TORCH_USE_CUDA_DSA` to enable dev...ith `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] Tensor object at 0x7f18abffb180>, ...]
intermediate_tensors = None, num_steps = 1

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in "
                             "EncoderDecoderModelRunner")

        if (model_input.attn_metadata is not None
                and model_input.attn_metadata.prefill_metadata is None
                and model_input.attn_metadata.decode_metadata.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[
                model_input.virtual_engine][graph_batch_size]
        else:
            model_executable = self.model

        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
>           hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                encoder_input_ids=model_input.encoder_input_tokens,
                encoder_positions=model_input.encoder_input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **seqlen_agnostic_kwargs)

Do you have any insights?

@heheda12345 (Collaborator):

I think you can run with enforce_eager=True to disable cuda graph. Then, export CUDA_LAUNCH_BLOCKING=1 should help you to find the line that crashes.
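
A minimal sketch of that debugging setup (the model name is illustrative; any mllama checkpoint that reproduces the crash would do):

import os

# Must be set before CUDA is initialized so kernel launches are serialized
# and the Python traceback points at the op that actually faults.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from vllm import LLM

# enforce_eager=True disables CUDA graph capture, which would otherwise
# hide the failing line behind graph replay.
llm = LLM(model="meta-llama/Llama-3.2-11B-Vision-Instruct",
          enforce_eager=True)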

@yma11 (Contributor, Author) commented Mar 29, 2025

I think you can run with enforce_eager=True to disable cuda graph. Then, export CUDA_LAUNCH_BLOCKING=1 should help you to find the line that crashes.

@heheda12345 With such settings, it hints that the error happens at an operation on q, and passing a contiguous Tensor like this works, as verified in my env. But I'm curious why the memory of these tensors becomes inaccessible after being passed to _attention_with_mask? Please help merge this PR if you think the change is okay.

@heheda12345 (Collaborator):

Thanks for your information. After some debugging, I found that the crash comes from torch.ops._C_cache_ops.reshape_and_cache_flash and PagedAttention.write_to_paged_cache. The reason is that kv_range_for_decode and attn_metadata do not match. This bug has been fixed by #15564. Can you wait until that PR is merged into the main branch and then rebase yours? After manually merging that PR, I can pass all tests without the added .contiguous() in my local environment.

@yma11 (Contributor, Author) commented Mar 31, 2025

Verified and rebased.

@heheda12345 (Collaborator) left a comment:

LGTM! All tests pass in my local environment and the code is very clean now.

@heheda12345 (Collaborator):

@yma11 Can you fix the DCO failure?

@yma11 (Contributor, Author) commented Mar 31, 2025

@yma11 Can you fix the DCO failure?

@heheda12345 Done. But there seems to be an OOM in a V1 test that is not related to this PR.

heheda12345 enabled auto-merge (squash) March 31, 2025 05:33
The github-actions bot added the ready label Mar 31, 2025
yma11 added 4 commits March 31, 2025 16:29
Signed-off-by: yan ma <[email protected]>
Signed-off-by: yan ma <[email protected]>
This reverts commit 8f9a1ce.

Signed-off-by: yan ma <[email protected]>
vllm-bot merged commit ff64739 into vllm-project:main Apr 1, 2025
41 of 43 checks passed
yma11 deleted the mllama-fix branch May 27, 2025 02:06
Labels
ready (ONLY add when PR is ready to merge/full CI is needed)

Development

Successfully merging this pull request may close these issues.

[Bug]: Requests with different num_images can't be proceeded by Llama3.2
4 participants