
[Model][MiniMaxText01] Support MiniMaxText01 model inference #13454


Merged: 108 commits into vllm-project:main on Apr 1, 2025

Conversation

@ZZBoom (Contributor) commented Feb 18, 2025

Purpose

This PR adds inference support for the MiniMaxText01 model.
It can run on a single machine with 8xH800 or 8xH20 GPUs: a single H800 machine can handle a maximum context of 2 million tokens, and a single H20 machine can handle a maximum context of 5 million tokens.

Modifications

  1. Add the MiniMaxText01 model inference implementation, along with a separate cache manager dedicated to linear attention.
  2. Adapt to the same inputs as the mamba models, including request_ids_to_seq_ids and finished_requests_ids.
  3. Add a temporary fix for the finished_requests_ids issue in consecutive multi-batch inference; this is a state-management workaround for a specific problem (the general idea is sketched below).
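
For readers unfamiliar with the mamba-style interface, the sketch below shows the general idea behind such a per-request cache manager: state slots keyed by request ID, reclaimed from finished_requests_ids on every step. The class, names, and shapes are hypothetical and simplified; this is not the actual MiniMaxText01 cache manager added by this PR.

# Hypothetical, simplified sketch of per-request linear-attention state management;
# names and shapes are illustrative, not the vLLM cache-manager API.
from typing import Dict, List

import torch


class LinearAttnStateCache:
    """Maps request IDs to fixed-size linear-attention state slots."""

    def __init__(self, num_slots: int, num_heads: int, head_size: int,
                 dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
        # Linear attention keeps one running (num_heads, head_size, head_size)
        # state per request instead of a growing per-token KV cache.
        self.states = torch.zeros(num_slots, num_heads, head_size, head_size,
                                  dtype=dtype, device=device)
        self.free_slots: List[int] = list(range(num_slots))
        self.request_to_slot: Dict[str, int] = {}

    def slot_for(self, request_id: str) -> int:
        # Reuse the slot of an ongoing request; otherwise allocate a fresh one.
        if request_id not in self.request_to_slot:
            slot = self.free_slots.pop()
            self.states[slot].zero_()
            self.request_to_slot[request_id] = slot
        return self.request_to_slot[request_id]

    def release_finished(self, finished_request_ids: List[str]) -> None:
        # Called once per step with that batch's finished_requests_ids so that
        # slots are recycled correctly across consecutive multi-batch runs.
        for request_id in finished_request_ids:
            slot = self.request_to_slot.pop(request_id, None)
            if slot is not None:
                self.free_slots.append(slot)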

Deployment

Default Parameter Startup

python3 -m vllm.entrypoints.api_server \
--model ${MiniMaxText01-Model-Path} \
--tensor-parallel-size 8 \
--trust-remote-code \
--quantization experts_int8  \
--max_model_len 1000000 \
--dtype bfloat16
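
Once the server is up, a quick smoke test is to POST a prompt to the demo /generate endpoint; the snippet below is a minimal example and assumes the default port 8000 and the plain api_server (not the OpenAI-compatible frontend).

# Minimal smoke-test client; assumes the demo /generate endpoint of
# vllm.entrypoints.api_server listening on the default port 8000.
import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Summarize the idea behind linear attention in one sentence.",
        "max_tokens": 64,
        "temperature": 0.0,
    },
    timeout=600,
)
response.raise_for_status()
print(response.json())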

H800 TP8, maximum context length 2 million

python3 -m vllm.entrypoints.api_server \
--model ${MiniMax-Text-01-Path} \
--tensor-parallel-size 8 \
--trust-remote-code \
--quantization experts_int8  \
--max_model_len 2048000 \
--gpu_memory_utilization 0.95 \
--max_num_seqs 1 \
--dtype bfloat16

H20 TP8, maximum context length 5 million

python -m vllm.entrypoints.api_server \
--model MiniMaxAI/MiniMax-Text-01 \
--tensor-parallel-size 8 \
--trust-remote-code \
--quantization experts_int8 \
--max_model_len 5120000 \
--gpu_memory_utilization 0.95 \
--max_num_seqs 1 \
--dtype bfloat16


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@heheda12345 (Collaborator)

Why do you introduce minimax_cache.py instead of reusing mamba_cache.py?

vllm/config.py Outdated
Comment on lines 860 to 863
# Handle minimax model
if hasattr(self.hf_config, "attn_type_list"):
# 1 represents flash attention and 0 represents linear attention
return sum(t == 1 for t in self.hf_config.attn_type_list)
Collaborator:

This should be handled in the hybrid model case a few lines down

return hidden_states


class MiniMaxText01ForCausalLM(nn.Module, HasInnerState):  # IsHybrid to be added later
Collaborator:

IIUC, this should be:
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

Did you hit some issue when adding the IsHybrid interface?

Contributor (author):

Yes, there were some issues with an earlier version of vLLM.

Thanks for the suggestion above! We added the logic in the hybrid model case, and it works.
Please review the new commit 530d99a.

@ZZBoom (Contributor, author) commented Feb 21, 2025

Why do you introduce minimax_cache.py instead of reusing mamba_cache.py?

Because the internal data structure self.mamba_cache in mamba_cache.py is not suitable for the MiniMaxText01 linear-attention cache, and that parameter is coupled into the current_run_tensors method.
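
Roughly speaking, the two caches have different layouts, which is why they are managed separately; the shapes below are made-up illustrations, not the exact tensors used by either cache manager.

# Illustrative shapes only (hypothetical sizes); the point is the layout difference.
import torch

batch, num_heads, head_size = 4, 64, 128
conv_dim, conv_kernel, ssm_state_size = 1024, 4, 16

# Mamba-style cache: a (conv_state, ssm_state) pair per layer.
mamba_cache = (
    torch.zeros(batch, conv_dim, conv_kernel - 1),
    torch.zeros(batch, conv_dim, ssm_state_size),
)

# MiniMaxText01 linear attention: a single running KV outer-product per layer.
linear_attn_cache = torch.zeros(batch, num_heads, head_size, head_size)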

@zwc163 commented Feb 24, 2025

Could you please support the MiniMax VL model as well? I would greatly appreciate it

@zifengdexiatian

Sorry, this may be a silly question, but is the model int8-quantized to achieve the 2-million-token context with H800 TP8 inference?


mergify bot commented Feb 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZZBoom.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 25, 2025
@ZZBoom (Contributor, author) commented Feb 25, 2025

Could you please support the MiniMax VL model as well? I would greatly appreciate it

@zwc163
Thank you for your interest. We do not have plans for that in the near future.

@ZZBoom (Contributor, author) commented Feb 25, 2025

Sorry, this may be a silly question, but is the model int8-quantized to achieve the 2-million-token context with H800 TP8 inference?

@zifengdexiatian
Two million tokens is not the goal in itself. To run this model on a single machine with 8xH800, you can only use int8 weight-only quantization or lower precision, and two million tokens is the maximum context that fits in this environment.
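
For a rough sense of the arithmetic (assuming the publicly reported ~456B total parameters and 80 GB of HBM per H800; activations, KV/state cache, and framework overhead are ignored):

# Back-of-the-envelope weight-memory estimate; numbers are approximate.
total_params = 456e9   # publicly reported total parameter count
gpu_mem_gb = 80        # HBM per H800
num_gpus = 8

weights_bf16_gb = total_params * 2 / 1e9   # ~912 GB, exceeds 8 x 80 = 640 GB
weights_int8_gb = total_params * 1 / 1e9   # ~456 GB, leaves headroom for caches

print(f"cluster HBM: {gpu_mem_gb * num_gpus} GB")
print(f"bf16 weights: {weights_bf16_gb:.0f} GB, int8 weights: {weights_int8_gb:.0f} GB")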

@tlrmchlsmth tlrmchlsmth self-assigned this Feb 25, 2025
@zifengdexiatian

Sorry, this may be a silly question, but is the model int8-quantized to achieve the 2-million-token context with H800 TP8 inference?

@zifengdexiatian

Two million tokens is not the goal in itself. To run this model on a single machine with 8xH800, you can only use int8 weight-only quantization or lower precision, and two million tokens is the maximum context that fits in this environment.

Thanks for the answer. I understand that a single machine can only run the quantized version, with a maximum of 2 million tokens of context.

@tlrmchlsmth (Collaborator)

@ZZBoom just checking - are there any blockers on this PR? I plan to review it but it's still marked as draft

@shuxiaobo

Is there any progress?

@tugot17 commented Mar 7, 2025

Can you merge this please?

@mergify mergify bot added the documentation (Improvements or additions to documentation), ci/build, frontend, multi-modality (Related to multi-modality (#4194)), structured-output, speculative-decoding, and v1 labels Mar 13, 2025
@qscqesze qscqesze force-pushed the qinggangying/vllm branch from e863d81 to 1bd32bc Compare March 13, 2025 03:41
@mergify mergify bot removed the needs-rebase label Mar 13, 2025
@tlrmchlsmth (Collaborator) left a comment

I had a couple more small questions and comments, but overall I think the PR is looking pretty good and ready to land once those are addressed.

Will there be a followup to simplify the weight loading?

@tlrmchlsmth tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) Mar 30, 2025
@tlrmchlsmth (Collaborator)

Adding ready to see how the mamba and hybrid integration tests do

@qscqesze (Contributor)

I had a couple more small questions and comments, but overall I think the PR is looking pretty good and ready to land once those are addressed.

Will there be a followup to simplify the weight loading?

Yes. We will simplify the weight loading in follow-up work.

- Removed redundant loops for tensor value assignments in the tests, enhancing readability and maintainability.
- Streamlined the initialization of key-value caches and input tensors, focusing on essential configurations for clarity.

Signed-off-by: qscqesze <[email protected]>
@qscqesze qscqesze requested a review from DarkLight1337 as a code owner March 31, 2025 02:56
…itespace

- Eliminated trailing whitespace in the test file to enhance code cleanliness and maintain consistency in formatting.
- This minor adjustment contributes to overall code quality without affecting functionality.

Signed-off-by: qscqesze <[email protected]>
… functionality

- Removed unused parameter from current_run_tensors method in ConstantSizeCache to simplify its interface.
- Updated slope_rate calculation in MiniMaxText01 to handle single-layer scenarios more clearly, enhancing readability.
- Adjusted calls to current_run_tensors in MiniMaxText01Model to reflect the updated method signature.

Signed-off-by: qscqesze <[email protected]>
@qscqesze (Contributor)

Some gsm8k evals on my end. Do these look good to you @qscqesze and @ZZBoom? (Using experts_int8 to fit on a single 8xA100 machine)

Running the following:

vllm serve MiniMaxAI/MiniMax-Text-01 \
--tensor-parallel-size 8 \
--trust-remote-code \
--quantization experts_int8  \
--max_model_len 1000000 \
--dtype bfloat16

lm_eval --model local-completions --tasks gsm8k --model_args model=MiniMaxAI/MiniMax-Text-01,base_url=http://127.0.0.1:8000/v1/completions --limit 100
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.94|±  |0.0239|
|     |       |strict-match    |     5|exact_match|↑  | 0.94|±  |0.0239|

GSM8K results reported in https://huggingface.co/MiniMaxAI/MiniMax-Text-01#3-evaluation are 0.948, so this looks good to me, especially since quantization will drop accuracy a bit.

Yeah, this looks good to me and aligns with expectations.

@qscqesze (Contributor)

@tlrmchlsmth Hi! I believe our code passes all the tests except for [buildkite/ci/pr/v1-test], which failed due to a torch.OutOfMemoryError: CUDA out of memory. This issue doesn’t seem related to our code. Could you take a look and see if it’s ready to be merged?

Comment on lines 136 to 146
q = torch.zeros(batch_size, num_heads, 1, head_size, dtype=dtype)
k = torch.zeros(batch_size, num_heads, 1, head_size, dtype=dtype)
v = torch.zeros(batch_size, num_heads, 1, head_size, dtype=dtype)

kv_caches = torch.zeros(batch_size,
                        num_heads,
                        head_size,
                        head_size,
                        dtype=dtype,
                        device="cuda")
device="cuda")

Collaborator:

Now that you've removed the old initialization code, these should all be torch.randn instead of torch.zeros. Since these tensors are initialized to all zeros, we're not testing anything.

Ditto for the other unit tests.
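
For illustration, scaled random initialization along these lines might look as follows (shapes, dtype, and scale are placeholders, not the exact test code):

# Sketch of randomized test inputs so the kernel output is actually exercised;
# shapes, dtype, and scale are illustrative.
import torch

batch_size, num_heads, head_size = 2, 8, 64
dtype = torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
scale = head_size ** -0.5

q = torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype, device=device) * scale
k = torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype, device=device) * scale
v = torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype, device=device) * scale

kv_caches = torch.zeros(batch_size, num_heads, head_size, head_size,
                        dtype=dtype, device=device)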

Contributor (author):

Thanks. Fixed it.

qscqesze added 4 commits April 1, 2025 10:21
…om values

- Changed tensor initialization from zeros to random values in the lightning attention test cases to better simulate realistic input scenarios.
- This adjustment enhances the robustness of the tests by ensuring varied input distributions.

Signed-off-by: qscqesze <[email protected]>
… remove scale factor

- Changed the initialization of the key-value cache tensor from random values to zeros for consistency in test scenarios.
- Removed the scale factor from the KV outer product calculation to simplify the implementation and enhance clarity.

Signed-off-by: qscqesze <[email protected]>
…aled random values

- Updated the initialization of query, key, and value tensors in the lightning attention tests to use a base scale factor for random values, enhancing consistency across test scenarios.
- Adjusted the initialization of key-value caches to align with the new scaling approach, improving the robustness of the tests.

Signed-off-by: qscqesze <[email protected]>
…lity

- Adjusted the indentation and formatting of tensor initialization in the lightning attention test cases to enhance code clarity and maintain consistency.
- This change focuses on improving the overall structure of the tests without altering their functionality.

Signed-off-by: qscqesze <[email protected]>
@qscqesze (Contributor)

qscqesze commented Apr 1, 2025

Hi @tlrmchlsmth,
I've addressed the review comments; thank you for the feedback! However, the test failed due to a missing image. Would you mind restarting the test? When you have a moment, could you also take another look at the code to see if it's ready to be merged?
Thanks again!

@tlrmchlsmth (Collaborator)

Hi @tlrmchlsmth, I've addressed the review comments; thank you for the feedback! However, the test failed due to a missing image. Would you mind restarting the test? When you have a moment, could you also take another look at the code to see if it's ready to be merged? Thanks again!

I'll take another look at the code tomorrow morning! In the meantime I think you need to merge in main for the failing docker-build-image test (related to #14549)

@qscqesze (Contributor)

qscqesze commented Apr 1, 2025

Hi @tlrmchlsmth, I've addressed the review comments; thank you for the feedback! However, the test failed due to a missing image. Would you mind restarting the test? When you have a moment, could you also take another look at the code to see if it's ready to be merged? Thanks again!

I'll take another look at the code tomorrow morning! In the meantime I think you need to merge in main for the failing docker-build-image test (related to #14549)

Thanks. I updated the branch already.

@tlrmchlsmth (Collaborator) left a comment

Looks good to me now! Thank you for the contribution!

Running one more sanity check on my end and then ready to merge

@tlrmchlsmth tlrmchlsmth merged commit 9ef98d5 into vllm-project:main Apr 1, 2025
41 checks passed
Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
…oject#13454)

Signed-off-by: qscqesze <[email protected]>
Co-authored-by: qingjun <[email protected]>
Co-authored-by: qscqesze <[email protected]>
Signed-off-by: xinyuxiao <[email protected]>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…oject#13454)

Signed-off-by: qscqesze <[email protected]>
Co-authored-by: qingjun <[email protected]>
Co-authored-by: qscqesze <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…oject#13454)

Signed-off-by: qscqesze <[email protected]>
Co-authored-by: qingjun <[email protected]>
Co-authored-by: qscqesze <[email protected]>
Signed-off-by: Mu Huai <[email protected]>