Add check for transformers version to fix some CodeLlama-based models' behaviors #998


Closed
HermitSun wants to merge 2 commits

Conversation

HermitSun
Contributor

Support for CodeLlama was added in transformers 4.33.0 (huggingface/transformers#25740), which introduces the rope_theta param for RoPE scaling.

However, the weights of some CodeLlama-based models were released before that PR was merged. That is to say, those models did not take rope_theta into consideration, and if we directly use the rope_theta from their configs, they will give gibberish outputs.

Fortunately, there is a transformers_version attribute in HF models' config.json, so we can add a check to solve this problem: if we find that the model was saved with a legacy version of transformers, we simply use the default rope_theta value.
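
A rough sketch of the idea (not the exact diff in this PR; the helper name and constant are mine):

from packaging import version

DEFAULT_ROPE_THETA = 10000.0

def resolve_rope_theta(hf_config) -> float:
    # hf_config is the parsed config.json of the model.
    saved_with = getattr(hf_config, "transformers_version", None)
    rope_theta = getattr(hf_config, "rope_theta", DEFAULT_ROPE_THETA)
    if saved_with is not None and version.parse(saved_with) < version.parse("4.33.0"):
        # The checkpoint predates rope_theta support, so the value in its
        # config.json was never used during training; fall back to the default.
        return DEFAULT_ROPE_THETA
    return rope_theta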

After applying this patch, at least the following models work normally (tested on an A100 with CUDA 11.8):

  1. WizardLM/WizardCoder-Python-13B-V1.0
  2. WizardLM/WizardCoder-Python-34B-V1.0

@WoosukKwon WoosukKwon self-requested a review September 9, 2023 22:53
@WoosukKwon
Collaborator

Hi @HermitSun, is setting rope_theta to 10000 a correct implementation of the model? According to the discussion in #883, it seems 1000000 (the value in config.json) is actually correct, though it requires RoPE in FP32 precision (#870).

@HermitSun
Contributor Author

HermitSun commented Sep 10, 2023

Hi @HermitSun, is setting rope_theta to 10000 a correct implementation of the model? According to the discussion in #883, it seems 1000000 (the value in config.json) is actually correct, though it requires RoPE in FP32 precision (#870).

Thank you for pointing this out! It seems this problem is actually caused by FP16 precision overflow in some cases, and I missed it 😭. I will double-check it in my case.

This version-mismatch check is useful for me, because in my work I try many different models trained with different transformers versions. Since transformers before 4.33.0 does not really support CodeLlama, models trained with a legacy version are more likely to have used the default rope_theta value, and this check works as a reminder.

What do you think about adding this version check when using CodeLlama-based models? If you do not think my current change is necessary, I can change it to simply print a warning when a version mismatch is detected, and then I will try to work on the precision problem in a separate PR.

@viktor-ferenczi
Contributor

viktor-ferenczi commented Sep 10, 2023

Related: I'm trying to run WizardLM/WizardCoder-Python-13B-V1.0 (a Code Llama fine-tune). It gives gibberish output (I used the right prompt template), while the original Code Llama 13B works. I guess the config.json is different. Another fine-tune, OpenAssistant/codellama-13b-oasst-sft-v10, works fine.

from vllm import LLM, SamplingParams

llm = LLM(model='WizardLM/WizardCoder-Python-13B-V1.0', tokenizer='hf-internal-testing/llama-tokenizer', tensor_parallel_size=2, seed=42)
sampling_params = SamplingParams(n=10, max_tokens=2000, temperature=0.3, top_p=0.95, stop=['</s>'])

@HermitSun
Contributor Author

Related: I'm trying to run WizardLM/WizardCoder-Python-13B-V1.0 (a Code Llama fine-tune). It gives gibberish output (I used the right prompt template), while the original Code Llama 13B works. I guess the config.json is different. Another fine-tune, OpenAssistant/codellama-13b-oasst-sft-v10, works fine.

from vllm import LLM, SamplingParams

llm = LLM(model='WizardLM/WizardCoder-Python-13B-V1.0', tokenizer='hf-internal-testing/llama-tokenizer', tensor_parallel_size=2, seed=42)
sampling_params = SamplingParams(n=10, max_tokens=2000, temperature=0.3, top_p=0.95, stop=['</s>'])

I noticed that this code uses the hf-internal-testing/llama-tokenizer. What if it had used the model's own tokenizer?

@WoosukKwon
Collaborator

Hi @HermitSun

I will try to work on the precision problem in a separate PR.

No worries. @imoneoi and I made PR #1004 to fix this issue.

Since transformers before 4.33.0 does not really support CodeLlama, models trained with a legacy version are more likely to have used the default rope_theta value, and this check works as a reminder.

Actually, I don't understand. Isn't rope_theta: 1000000 correct for the WizardLM/WizardCoder-Python-13B-V1.0 and WizardLM/WizardCoder-Python-34B-V1.0 models? Do you need to reset this to 10000?

@esmeetu
Member

esmeetu commented Sep 10, 2023

AFAIK, WizardCoder-34B was trained using transformers==4.31.0, which doesn't support rope_theta, so changing to the base 10000 will be OK.
@HermitSun @WoosukKwon

@HermitSun
Contributor Author

HermitSun commented Sep 10, 2023

Hi @WoosukKwon, after applying the changes in PR #1004, I find the model works well (at least it looks fine; maybe I need more tests).

It seems we do not have to change rope_theta if the precision problem is solved. But as @esmeetu says, I think it's OK to use 10000 as the default, since transformers did not support this param when the model was trained.

I originally opened this PR to check for the transformers version mismatch. If the precision problem is fixed, maybe we can give a warning instead of changing rope_theta, or maybe we can find a better way to check the version instead of just leaving a comment in the code.

@viktor-ferenczi
Contributor

viktor-ferenczi commented Sep 10, 2023

The HF repos of the WizardLM/WizardCoder-Python-13B-V1.0 and WizardLM/WizardCoder-Python-34B-V1.0 models have 1000000 in their config.json. If I change the value to 10000 in the downloaded model folder, vLLM starts to work properly with the model (I tested only the 13B variant, because the 34B is too much for my hardware). So a default of 10000 does seem to be correct.
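
For anyone who wants to try the same workaround, a minimal sketch of overriding the value in a downloaded snapshot (the local path is a placeholder):

import json
from pathlib import Path

config_path = Path("/path/to/WizardCoder-Python-13B-V1.0") / "config.json"  # placeholder path

config = json.loads(config_path.read_text())
config["rope_theta"] = 10000.0  # override the released value of 1000000
config_path.write_text(json.dumps(config, indent=2))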

@HermitSun
Contributor Author

HermitSun commented Sep 10, 2023

I additionally did some simple experiments with WizardLM/WizardCoder-Python-34B-V1.0:

rope_theta  RoPE precision      pass@1
10000       fixed (PR #1004)    0.6646
10000       not fixed           0.6585
1000000     fixed (PR #1004)    0.6037
1000000     not fixed           0.0

Because the model gives gibberish outputs when rope_theta=1000000 and the RoPE precision is not fixed, pass@1 for that case is 0.0. As the results show:

  1. rope_theta=10000 gives a better result on the HumanEval benchmark.
  2. The precision problem does affect the result.

I also noticed that WizardLM/WizardCoder-Python-13B-V1.0 and WizardLM/WizardCoder-Python-34B-V1.0 both require vllm==0.1.4, which did not yet support reading rope_theta from config.json.

@esmeetu
Member

esmeetu commented Sep 10, 2023

@HermitSun What is interesting is that increasing rope_theta slightly (e.g. to 16384) might improve the score.

@HermitSun
Contributor Author

@HermitSun What is interesting is that increasing rope_theta slightly (e.g. to 16384) might improve the score.

Amazing! I will update my results below; they suggest that a rope_theta around 10000 (rather than 1000000) performs better:

rope_theta  RoPE precision      pass@1
16384       not fixed           0.7012
16384       fixed (PR #1004)    0.6829
10000       fixed (PR #1004)    0.6646
10000       not fixed           0.6585
1000000     fixed (PR #1004)    0.6037
1000000     not fixed           0.0

I wonder whether this conclusion (i.e., that slightly increasing rope_theta improves the model's performance) holds for all CodeLlama-based models?

@WoosukKwon
Collaborator

@HermitSun Thanks a lot for the evaluation. Could you please share the script you used for it?

And, if possible, could you measure the accuracy on the rope-fp32 branch as well? It uses FP32 not only for initialization, but also for calculating the RoPE (without changing the base value). It requires installing vLLM from source:

git checkout rope-fp32
pip install -e .

@viktor-ferenczi
Contributor

The same issue is discussed at llama.cpp: ggml-org/llama.cpp#3090

@HermitSun
Contributor Author

HermitSun commented Sep 11, 2023

@HermitSun Thanks a lot for the evaluation. Could you please share the script you used for it?

And, if possible, could you measure the accuracy on the rope-fp32 branch as well? It uses FP32 not only for initialization, but also for calculating the RoPE (without changing the base value). It requires installing vLLM from source:

git checkout rope-fp32
pip install -e .

I used the evaluation scripts in WizardLM's official repo and followed their environment setup steps: https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder.
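
For reference, here is a minimal sketch of what such a pass@1 run looks like with vLLM and OpenAI's human-eval package; this is an illustration only, not the WizardCoder scripts I actually used (which also apply their own prompt template and sampling settings):

from human_eval.data import read_problems, write_jsonl
from vllm import LLM, SamplingParams

llm = LLM(model="WizardLM/WizardCoder-Python-34B-V1.0", tensor_parallel_size=4)  # illustrative settings
params = SamplingParams(n=1, temperature=0.0, max_tokens=1024)  # greedy decoding for pass@1

problems = read_problems()
task_ids = list(problems)
outputs = llm.generate([problems[t]["prompt"] for t in task_ids], params)

write_jsonl(
    "samples.jsonl",
    [dict(task_id=t, completion=o.outputs[0].text) for t, o in zip(task_ids, outputs)],
)
# Score the generations with: evaluate_functional_correctness samples.jsonl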

But my environment has two minor differences from the authors', which may be why I got a lower result than they reported:

  1. I used torch 2.0.1 and CUDA 11.8 instead of torch 1.12 and CUDA 11.3.
  2. I used a fork of vLLM 0.1.4 (with some slightly modified impls) instead of the vLLM 0.1.4 release.

I will use the latest code (instead of my fork), apply the FP32 precision patches (comparing the two patches in PR #1004 and the rope-fp32 branch), and then update my test results. I think we'd better discuss the precision problem further in PR #1004, and solve the rope_theta setting problem in this PR.

@HermitSun
Contributor Author

HermitSun commented Sep 11, 2023

@WoosukKwon I used the latest code and applied the FP32 precision patches; the updated results are here:

rope_theta  PR #1004  branch rope-fp32  baseline
10000       0.6646    0.6585            0.6646
16384       0.6890    0.6829            0.6890
1000000     0.6037    0.6098            N/A

It seems that PR #1004 performs better with the default rope_theta value of 10000, while the rope-fp32 branch performs better when rope_theta is larger.

@WoosukKwon
Collaborator

WoosukKwon commented Sep 11, 2023

Hi @HermitSun, thanks for the benchmark results. It seems #1004 is enough for the precision problem.

Then, I believe it should be the users who are responsible for setting the right rope_theta for their models. It's out of vLLM's scope. As the value can be easily configured by modifying the model's config.json file, I think there's nothing we need to do about it here.

I can change it to simply print a warning when a version mismatch is detected

While it makes sense to me, I'm wondering what the exact logic would be. 99% of the models in the HF model hub were trained with transformers < v4.33.1. How can we tell when to warn? If I understand correctly, your PR will warn for every Llama model trained before 4.33.1, including LLaMA V1 and V2. How can we minimize such false positives?

@HermitSun
Contributor Author

HermitSun commented Sep 11, 2023

Hi @WoosukKwon, I think you're right; maybe we should not introduce extra workload into an inference framework.

As for the version-mismatch detection problem, we could add some heuristic rules that check the _name_or_path field of config.json: if CodeLlama or certain other keywords appear in _name_or_path, we raise a warning; otherwise we don't. But maybe we should not do this in the framework, since this is just a requirement of my use case rather than a generic one.
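
To illustrate, a hypothetical sketch of such a heuristic (the keyword list and function name are mine, not part of this PR's diff):

import logging

from packaging import version

logger = logging.getLogger(__name__)

CODELLAMA_KEYWORDS = ("codellama", "code-llama", "code_llama")  # illustrative keyword list

def maybe_warn_legacy_rope_theta(hf_config) -> None:
    name = (getattr(hf_config, "_name_or_path", "") or "").lower()
    saved_with = getattr(hf_config, "transformers_version", None)
    if not any(keyword in name for keyword in CODELLAMA_KEYWORDS):
        return  # avoid warning on every pre-4.33 Llama checkpoint
    if saved_with is not None and version.parse(saved_with) < version.parse("4.33.0"):
        logger.warning(
            "Model %s was saved with transformers %s (< 4.33.0); the rope_theta in its "
            "config.json may not reflect the value used during training.",
            name, saved_with,
        )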

If you feel this PR (including replacing the rope_theta check with rule-based warnings) is not really necessary, I'll close it.

@WoosukKwon WoosukKwon mentioned this pull request Sep 16, 2023
@WoosukKwon WoosukKwon closed this Mar 13, 2024