FP32 RoPE kernel #1061

Closed
wants to merge 0 commits into from

Conversation

@imoneoi (Contributor) commented Sep 16, 2023

No description provided.

@imoneoi (Contributor, Author) commented Sep 16, 2023

This PR is essential for Code-Llama to generate correct code. Without FP32 RoPE, Code-Llama can't generate the correct indentation, leading to very low HumanEval scores.

Setup              Code-Llama 13B HumanEval pass@1 (greedy)
vLLM               4.9
vLLM + FP32 RoPE   36.0

Example (vLLM):

from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
   for i in range(len(numbers) - 1):   # wrong indentation in this line
        if abs(numbers[i] - numbers[i + 1]) <= threshold:
            return True
    return False

Example (vLLM + FP32 RoPE):

from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for i in range(len(numbers) - 1):
        if abs(numbers[i] - numbers[i + 1]) <= threshold:
            return True
    return False
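
For readers unfamiliar with the change, here is a minimal PyTorch sketch of the idea (illustrative only: this is not the vLLM CUDA kernel, the function name is made up, and the interleaved pairing convention is assumed). The inverse frequencies, the cos/sin terms, and the rotation itself all stay in float32, and the result is cast back to the model dtype only at the end:

import torch

def apply_rope_fp32(q: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # q: [num_tokens, num_heads, head_dim] in float16/bfloat16; positions: [num_tokens]
    head_dim = q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=q.device) / head_dim))
    freqs = positions.to(torch.float32)[:, None] * inv_freq[None, :]   # [num_tokens, head_dim // 2]
    cos, sin = freqs.cos(), freqs.sin()                                # kept in float32
    q32 = q.to(torch.float32)
    q1, q2 = q32[..., 0::2], q32[..., 1::2]
    out = torch.empty_like(q32)
    out[..., 0::2] = q1 * cos[:, None, :] - q2 * sin[:, None, :]
    out[..., 1::2] = q1 * sin[:, None, :] + q2 * cos[:, None, :]
    return out.to(q.dtype)                                             # cast back only at the very end

The key tensor gets the same treatment.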

@WoosukKwon self-requested a review on September 16, 2023 at 21:49
@WoosukKwon (Collaborator)

Hi @imoneoi, thanks for bringing this up. We discussed this issue in #998 (comment), where we observed that using FP32 only for initializing RoPE is enough to preserve accuracy. Your evaluation result seems inconsistent with that. Could you provide a script for the evaluation?
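
For context, a hedged sketch of the alternative discussed in #998 (names and shapes are illustrative, not the actual vLLM code): compute the cos/sin cache in float32 at initialization, then cast it down and apply the rotation in the model dtype.

import torch

def build_cos_sin_cache(max_pos: int, head_dim: int, base: float = 10000.0,
                        cache_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # The frequencies and trig terms are computed in float32 ("FP32 initialization") ...
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                       # [max_pos, head_dim // 2]
    cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_pos, head_dim]
    # ... but the cache is stored in the model dtype, so the rotation that consumes it
    # later runs in fp16/bf16. This PR instead keeps the rotation itself in float32.
    return cache.to(cache_dtype)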

@imoneoi (Contributor, Author) commented Sep 17, 2023

Hi @WoosukKwon, my results were obtained with Code-Llama weights at bfloat16 precision.

vLLM server:

python -m vllm.entrypoints.openai.api_server --model imone/CodeLlama_13B_with_EOT_token --host 127.0.0.1 --port 5000 --max-num-batched-tokens 16384 --worker-use-ray --engine-use-ray

EvalPlus:

python -m codegen.generate --model codegen-16b --bs 1 --temperature 0 --greedy --n_samples 1 --root ./data/codellama_13b_fp32_rope
docker run -v $(pwd):/app ganler/evalplus:latest --dataset humaneval --samples ./data/codellama_13b_fp32_rope/humaneval/codegen-16b_temp_0.0/

@WoosukKwon (Collaborator)

@imoneoi Got it. Could you

  1. Resolve the formatting error? You can use ./format.sh to make your code align with our format requirement.
  2. Fix tests/kernels/test_pos_encoding.py to align with the change?

@imoneoi (Contributor, Author) commented Sep 17, 2023

> @imoneoi Got it. Could you
>
>   1. Resolve the formatting error? You can use ./format.sh to make your code align with our format requirement.
>   2. Fix tests/kernels/test_pos_encoding.py to align with the change?

@WoosukKwon Do you know the difference in rounding methods between PyTorch .to and CUDA type conversion? It seems that rounding differences caused some of the discrepancies in tests.
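
For reference: as far as I know, both PyTorch's float32-to-float16 cast and CUDA's plain __float2half round to nearest even; CUDA only rounds differently when one of the explicit __float2half_rz / _rd / _ru intrinsics is used. A small probe (illustrative, not part of the PR) that exposes which rule the cast uses:

import torch

# 2048.0 and 2050.0 are adjacent float16 values (the spacing there is 2.0),
# so 2049.0 and 2051.0 sit exactly halfway and expose the tie-breaking rule.
x = torch.tensor([2049.0, 2051.0], dtype=torch.float32)
print(x.to(torch.float16))
# Round-to-nearest-even gives [2048., 2052.]; round-toward-zero would give [2048., 2050.].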

@imoneoi (Contributor, Author) commented Sep 17, 2023

We also tested CodeLlama-Instruct 13B. FP32 RoPE gives about a 1% improvement on HumanEval+.

Before PR:

Base
{'pass@1': 0.42073170731707316}
Base + Extra
{'pass@1': 0.36585365853658536}

After PR:

Base
{'pass@1': 0.42073170731707316}
Base + Extra
{'pass@1': 0.3719512195121951}

@WoosukKwon (Collaborator)

@imoneoi

> @WoosukKwon Do you know the difference in rounding methods between PyTorch .to and CUDA type conversion? It seems that rounding differences caused some of the discrepancies in tests.

While I believe this should not be the case, I also experienced the weird precision error when applying this PR...

@WoosukKwon mentioned this pull request Sep 18, 2023

@WoosukKwon (Collaborator)

This PR is stuck because we found a weird error in the test. More specifically, in the rope-fp32 branch I've updated test_pos_encoding according to the precision change, but it didn't pass: the difference from the reference implementation was quite large. I have no idea where the difference came from.

@esmeetu (Member) commented Sep 19, 2023

@imoneoi @WoosukKwon I tested on the same model: codellama-13b-instruct.

Before (latest main branch) and after the PR, the results are the same.

Base
{'pass@1': 0.5121951219512195}
Base + Extra
{'pass@1': 0.4451219512195122}

@Yard1 (Collaborator) commented Sep 19, 2023

I have spent some time trying out different rounding modes in the kernel, and none of them makes all the tests pass. My assumption is that something in the reference implementation is most likely causing it.

Here are the max and min differences between kernel and reference for the first test case (tests/kernels/test_pos_encoding.py::test_rotary_embedding[0-dtype0-None-64-7-2048-True]) from the rope-fp32 branch with different rounding modes:

round to nearest-even (default):
q: tensor(7.6294e-06, device='cuda:0', dtype=torch.float16) tensor(-6.1035e-05, device='cuda:0', dtype=torch.float16)
k: tensor(0.0010, device='cuda:0', dtype=torch.float16) tensor(-0.0005, device='cuda:0', dtype=torch.float16)

round to zero:
q: tensor(0.0039, device='cuda:0', dtype=torch.float16) tensor(-0.0020, device='cuda:0', dtype=torch.float16)
k: tensor(0.0020, device='cuda:0', dtype=torch.float16) tensor(-0.0020, device='cuda:0', dtype=torch.float16)

round down:
q: tensor(0., device='cuda:0', dtype=torch.float16) tensor(-0.0039, device='cuda:0', dtype=torch.float16)
k: tensor(0., device='cuda:0', dtype=torch.float16) tensor(-0.0020, device='cuda:0', dtype=torch.float16)

round up:
q: tensor(0.0039, device='cuda:0', dtype=torch.float16) tensor(0., device='cuda:0', dtype=torch.float16)
k: tensor(0.0020, device='cuda:0', dtype=torch.float16) tensor(0., device='cuda:0', dtype=torch.float16)
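
For anyone reproducing this, the numbers above can be collected with a comparison along these lines (a sketch; out_q, ref_q, out_k, ref_k are stand-ins for the kernel and reference outputs produced inside the test):

import torch

def report_diff(name: str, out: torch.Tensor, ref: torch.Tensor) -> None:
    # Largest positive and negative signed error between kernel output and reference.
    diff = out - ref
    print(f"{name}: {diff.max()} {diff.min()}")

# Inside the test, after running both implementations on identical inputs:
# report_diff("q", out_q, ref_q)
# report_diff("k", out_k, ref_k)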

@WoosukKwon (Collaborator)

@Yard1 Yeah, this one is really, really weird. I've checked both the reference implementation and our kernel multiple times looking for any potential source of the precision error, but totally failed.

@imoneoi (Contributor, Author) commented Sep 20, 2023

> @imoneoi @WoosukKwon I tested on the same model: codellama-13b-instruct.
>
> Before (latest main branch) and after the PR, the results are the same.
>
> Base {'pass@1': 0.5121951219512195} Base + Extra {'pass@1': 0.4451219512195122}

What are your EvalPlus settings? The results seem much higher.

@imoneoi (Contributor, Author) commented Sep 20, 2023

@Yard1 I also cannot figure out the precision issue. Is it because of the non-associativity of floating-point computations?
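
As a toy illustration of that point (unrelated to the actual kernel code): in float16, the order in which terms are accumulated changes the result, so a kernel and a reference that round or accumulate in a different order need not match bit-for-bit.

import torch

a = torch.tensor(2048.0, dtype=torch.float16)  # ulp at 2048 is 2.0 in float16
b = torch.tensor(0.75,   dtype=torch.float16)
c = torch.tensor(0.75,   dtype=torch.float16)

print((a + b) + c)  # tensor(2048., dtype=torch.float16): each 0.75 is rounded away
print(a + (b + c))  # tensor(2050., dtype=torch.float16): the small terms combine first and survive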

@esmeetu (Member) commented Sep 21, 2023

@imoneoi I changed the system prompt and added some post-processing.
