Skip to content

[Bugfix] Adjust mllama to regional compilation #15112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 19, 2025

Conversation

jkaniecki
Copy link
Contributor

@jkaniecki jkaniecki commented Mar 19, 2025

When trying to perform regional compilation with t.compile (compiling layers separately instead of calling t.compile on whole model) on mllama model with Gaudi devices, such an error occures:

ValueError: Unknown decoder layer type <class 'torch._dynamo.eval_frame.OptimizedModule'>
Regional compilation for Gaudi devices has been added with #13213

Cause for this issue is checking layer classes by their names inside mllama code e.g.:
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):

Torch.compile wraps module after compilation with torch._dynamo.eval_frame.OptimizedModule name, that's wht we see mismatch in isinstance function. To resolve that we can distinguish layers basing on self.cross_attention_layers ids - and so do proposed changes. We don't also need raise ValueError in layer instance checking as there is no option for decoder layers to be of different types than desired ones.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Jan Kaniecki <[email protected]>
@jkaniecki jkaniecki force-pushed the mllama_regional_compilation branch from 2f5b299 to 49cd230 Compare March 19, 2025 10:28
Copy link
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for your contribution.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) March 19, 2025 13:00
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 19, 2025
@vllm-bot vllm-bot merged commit 8363cd0 into vllm-project:main Mar 19, 2025
41 of 43 checks passed
gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Apr 1, 2025
jkaniecki added a commit to HabanaAI/vllm-fork that referenced this pull request Apr 2, 2025
afierka-intel pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 3, 2025
This PR involves cherry-pick of
vllm-project#15112 from the upstream and a
fix for cos_sin preparation in emb layers to match regional compilation.

---------

Signed-off-by: Jan Kaniecki <[email protected]>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
imangohari1 pushed a commit to imangohari1/vllm-fork that referenced this pull request Apr 8, 2025
This PR involves cherry-pick of
vllm-project#15112 from the upstream and a
fix for cos_sin preparation in emb layers to match regional compilation.

---------

Signed-off-by: Jan Kaniecki <[email protected]>
tvoas pushed a commit to tvoas/vllm-fork that referenced this pull request Apr 9, 2025
This PR involves cherry-pick of
vllm-project#15112 from the upstream and a
fix for cos_sin preparation in emb layers to match regional compilation.

---------

Signed-off-by: Jan Kaniecki <[email protected]>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants