RoPE should be applied with float32 #863
It seems that RoPE (sin, cos) should be stored and applied in fp32 and then cast back to fp16/bf16.

vllm/vllm/model_executor/layers/attention.py, line 273 at commit 791d79d

Reference implementation from Llama 2 / Code Llama:
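A condensed sketch of Meta's `apply_rotary_emb` from `llama/model.py` (simplified: the repo's `reshape_for_broadcast` helper is omitted, so `freqs_cis` is assumed to already broadcast against `xq_`):

```python
from typing import Tuple

import torch


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # .float() promotes fp16/bf16 activations to fp32, so the complex
    # rotation below runs in full precision (complex64).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # Only the final result is cast back to the original fp16/bf16 dtype.
    return xq_out.type_as(xq), xk_out.type_as(xk)
```

Note that the `type_as` cast at the end is what the issue title refers to: the rotation itself never sees fp16/bf16 rounding.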
Hi @imoneoi, thanks for pointing this out and submitting a PR to fix it. To my understanding, the data type used in RoPE differs between Meta's original LLaMA implementation (which you attached here) and HF's. If I understand correctly, in HF Transformers the cos and sin embeddings are converted to the data type of the input q and k tensors, which keep their original data type. I've checked that our current kernel implementation passes the unit test (https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_pos_encoding.py), while your implementation in #870 does not. In #870, could we make casting the intermediate tensors to float32 optional? Since I believe most people expect vLLM to be compatible with HF Transformers, I'd like to keep the current behavior as the default. However, it would be nice to have float32 casting as an option for advanced users who care about the original implementation.
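For contrast, a minimal sketch of the HF-style convention described in the comment above (`apply_rope_hf_style` is an illustrative name, not an actual HF or vLLM function; `rotate_half` mirrors the helper in HF Transformers):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second:
    # mirrors the helper in HF Transformers' rotary-embedding code.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_hf_style(q, k, cos, sin):
    # HF-style: the cos/sin tables are cast down to the activation dtype
    # (e.g. fp16), so the rotation itself runs in reduced precision.
    cos, sin = cos.to(q.dtype), sin.to(q.dtype)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```

The difference from the fp32 path is that any rounding happens inside the rotation itself, not only at a final cast.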
Yes, I understand your concern. I'm mostly thinking about the latest Code Llama. I can implement the cast option and add an extra kernel for full-precision RoPE. By the way, this PR is editable, and I welcome your changes.
As for the unit test, I wonder if the failure is due to some differences in type casting? It looks like only one item in the whole array has small rounding errors.
Closing as this now appears to be resolved. |