
RoPE should be applied with float32 #863


Closed
imoneoi opened this issue Aug 25, 2023 · 4 comments
Labels: feature request

Comments

@imoneoi
Contributor

imoneoi commented Aug 25, 2023

It seems that the RoPE (sin, cos) cache should be stored and applied in fp32 and then cast back to fp16/bf16, rather than being downcast up front as in:

cache = cache.to(torch_dtype)

Reference implementation from Llama 2 / Code Llama:

from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Frequencies and angles are computed entirely in float32; the cache is complex64.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)  # type: ignore
    freqs = torch.outer(t, freqs)  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Upcast q/k to float32, rotate in complex space, then cast back to the input dtype.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
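
For context, the reshape_for_broadcast helper referenced above is not included in the snippet; a minimal version consistent with the shapes used there (freqs_cis of shape (seqlen, head_dim // 2), xq_ of shape (batch, seqlen, n_heads, head_dim // 2)) would be:

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Expand freqs_cis to (1, seqlen, 1, head_dim // 2) so it broadcasts over
    # the batch and head dimensions of the complex-valued x.
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i in (1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)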
@WoosukKwon
Collaborator

Hi @imoneoi, thanks for pointing this out and submitting a PR to fix it. To my understanding, the data type used in RoPE differs between Meta's original LLaMA implementation (which you attached here) and HF's. If I understand correctly, in HF Transformers the cos and sin embeddings are converted to the data type of the input q and k tensors, which keep their original data type. I've checked that our current kernel implementation passes the unit test (https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_pos_encoding.py), while your implementation in #870 does not.

In #870, could we make it optional to cast the intermediate tensors to float32? Since I believe most people expect vLLM to be compatible with HF Transformers, I'd like to keep the current behavior as the default. However, it would be nice to have float32 casting as an option for advanced users who care about matching the original implementation.
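
To make the two behaviors concrete, here is a rough Python-level sketch (not vLLM's kernel code; the apply_rope name and use_float32 flag are hypothetical, and it uses the rotate_half layout rather than the interleaved complex layout above):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin, use_float32: bool = False):
    # cos/sin: float32 cache, broadcastable to q/k.
    if use_float32:
        # Meta-style: rotate in float32, cast back to the input dtype at the end.
        q32, k32 = q.float(), k.float()
        return (
            ((q32 * cos) + (rotate_half(q32) * sin)).type_as(q),
            ((k32 * cos) + (rotate_half(k32) * sin)).type_as(k),
        )
    # HF-style: downcast cos/sin to the q/k dtype and rotate in that precision.
    cos, sin = cos.to(q.dtype), sin.to(q.dtype)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)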

@imoneoi
Contributor Author

imoneoi commented Aug 25, 2023

Yes, I understand your concern. I'm mostly thinking about the latest Code Llama. When theta is large, the accuracy problem may be exacerbated, so we may need full-precision RoPE.

I can implement the cast option and add an extra kernel for full-precision RoPE. By the way, this PR is editable and I welcome your changes.
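
As a rough, standalone illustration of the precision concern (not vLLM code; 1e6 is the RoPE base used by Code Llama), one can measure the rounding error introduced by downcasting the cos/sin cache:

import torch

def cos_sin(theta: float, dim: int, max_pos: int, dtype: torch.dtype):
    # Angles are always computed in float32; only the final cache is downcast.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return torch.cos(angles).to(dtype), torch.sin(angles).to(dtype)

cos32, sin32 = cos_sin(1_000_000.0, 128, 16384, torch.float32)
cos16, sin16 = cos_sin(1_000_000.0, 128, 16384, torch.bfloat16)
print("max |cos error|:", (cos32 - cos16.float()).abs().max().item())
print("max |sin error|:", (sin32 - sin16.float()).abs().max().item())

Whether an error of this size matters downstream is model-dependent, which is what the optional float32 path is meant to address.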

@imoneoi
Contributor Author

imoneoi commented Aug 25, 2023

For the unit test, I wonder if it's due to some differences in type casting? Is Tensor.to the same as static_cast<>?

Looks like only one item in the whole array has a small rounding error.
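
One way such a single-element mismatch can arise purely from the order of casting (a standalone snippet, not the vLLM kernel or its test): rounding the operands to half precision before the multiply occasionally lands on a different half-precision value than multiplying in float32 and rounding once at the end.

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)
c = torch.cos(torch.randn(1 << 16, dtype=torch.float32))

a = x.half() * c.half()   # round twice: operands first, then the product
b = (x * c).half()        # round once: only the float32 product

diff = (a.float() - b.float()).abs()
print("mismatching elements:", (diff > 0).sum().item(), "max diff:", diff.max().item())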

@hmellor
Member

hmellor commented Mar 8, 2024

Closing as this now appears to be resolved.
