[RFC]: Reward Modelling in OpenAI Compatible Server #8967

Closed
noamgat opened this issue Sep 30, 2024 · 30 comments · Fixed by #9759
@noamgat
Contributor

noamgat commented Sep 30, 2024

Motivation.

Reward models are an important tool in NLP / AI workflows, especially in agentic flows, which use them to verify the quality of intermediate outputs or to rank several attempts at performing a single task.

vLLM just added support for a reward model in #8896 (comment).
This requires a workaround to work with the OpenAI Compatible Server: it piggybacks on the embedding endpoint.
The workaround requires the client to know which tokenizer the server is using, apply the chat template to the conversation itself, and send the resulting string to the embedding endpoint. This isn't ideal, and it breaks the decoupling between the client and server.
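For reference, here is a minimal sketch of that workaround, assuming a vLLM OpenAI-compatible server on localhost:8000 serving the Qwen2.5-Math reward model (the model name and URL are assumptions, not part of the original report):

```python
# Minimal sketch of the current workaround: the client loads the *same* tokenizer the
# server uses, renders the chat template itself, and sends the resulting string to the
# embeddings endpoint. Model name and server URL are assumptions.
import requests
from transformers import AutoTokenizer

MODEL = "Qwen/Qwen2.5-Math-RM-72B"  # assumed reward model served by vLLM
tokenizer = AutoTokenizer.from_pretrained(MODEL)

messages = [
    {"role": "user", "content": "What is 1+1?"},
    {"role": "assistant", "content": "1+1=2"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": MODEL, "input": prompt},
)
print(resp.json()["data"][0]["embedding"])  # the reward model's output comes back here
```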

A short discussion in the same issue led to the creation of the RFC.

Proposed Change.

The reason that no endpoint currently matches the needs of the reward model is as follows:

  • The embedding endpoint receives a string as input, not a conversation.
  • The chat endpoint returns a string, not a series of numbers. Even if you ask for logprobs, they are computed after softmax is applied, which is not a reversible process.

I see several ways to more elegantly support reward models in the OpenAI compatible server, and this RFC will hopefully be the discussion point for them.

Option 1:
Add a conversation object (`List[Dict[str, str]]`) as a potential input to the `EmbeddingRequest` class. Its `input` field already supports a variety of types:
`input: Union[List[int], List[List[int]], str, List[str]]`
Upon detecting that a conversation object was given, the OpenAI Compatible Server would apply the chat template using the tokenizer and proceed as if it had received `str` input.
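A rough sketch of what Option 1 could look like; the class shape and helper name are illustrative assumptions, not vLLM's actual implementation:

```python
# Illustrative sketch of Option 1; field/class/helper names are assumptions.
from typing import Dict, List, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizer


class EmbeddingRequest(BaseModel):
    model: str
    # Existing input types, plus a chat conversation as a new alternative.
    input: Union[List[int], List[List[int]], str, List[str], List[Dict[str, str]]]


def resolve_input(request: EmbeddingRequest, tokenizer: PreTrainedTokenizer):
    """If the input is a conversation, render it with the chat template server-side."""
    if (isinstance(request.input, list) and request.input
            and isinstance(request.input[0], dict)):
        return tokenizer.apply_chat_template(request.input, tokenize=False)
    return request.input  # already a string or token IDs; pass through unchanged
```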

Option 2:
Add a way to get output logits instead of output logprobs from the chat endpoint. This can be either a new per-request parameter (similar to top_logprobs) or a server-side flag to override the behavior of the data returned in the field (--return_logits_instead_of_logprobs flag to the OpenAI server for example).

Option 3:
Add a dedicated endpoint to vLLM.

Option 4:
Do nothing. Since there is a /tokenize endpoint that also accepts a conversation, the sample code in #8896 (comment) could be changed to call the tokenize endpoint, receive the token list, and send that to the embeddings endpoint, which addresses the coupling problem (see the sketch below).
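A sketch of that client flow; the URLs, model name, and exact request/response field names are assumptions:

```python
# Sketch of Option 4: tokenize the conversation server-side, then send the token IDs
# to the embeddings endpoint, so the client never needs the tokenizer itself.
# URLs, model name, and field names are assumptions.
import requests

BASE_URL = "http://localhost:8000"
MODEL = "Qwen/Qwen2.5-Math-RM-72B"  # assumed reward model

messages = [
    {"role": "user", "content": "What is 1+1?"},
    {"role": "assistant", "content": "1+1=2"},
]

tok_resp = requests.post(f"{BASE_URL}/tokenize",
                         json={"model": MODEL, "messages": messages}).json()
emb_resp = requests.post(f"{BASE_URL}/v1/embeddings",
                         json={"model": MODEL, "input": tok_resp["tokens"]}).json()
print(emb_resp["data"][0]["embedding"])
```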

I personally support Option 1, as it feels the least hacky of the bunch, and also does not require a whole lot of new code.

What do you think?

Feedback Period.

It's not my place to say when a conclusion has been reached here, but I don't think it should take more than a couple of weeks.

CC List.

@simon-mo
@DarkLight1337
@zhuzilin

Any Other Things.

vLLM is awesome!

@noamgat noamgat added the RFC label Sep 30, 2024
@DarkLight1337
Member

DarkLight1337 commented Sep 30, 2024

To keep the semantics consistent (where Chat Completions is like Completions API with chat conversation), I prefer having a separate Chat Embeddings API (Embeddings API with chat conversation). So our endpoint map would be something like

  • v1/completions (Completions API)
  • v1/chat/completions (Chat Completions API)
  • v1/embeddings (Embeddings API)
  • v1/chat/embeddings (Chat Embeddings API) [new]

Since most of the logic related to chat conversation parsing is already in chat_utils.py, it should not take that much effort to add this.
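For illustration, a request against the proposed route might look like the following; the /v1/chat/embeddings endpoint does not exist yet, and the payload shape is an assumption that mirrors Chat Completions:

```python
# Hypothetical Chat Embeddings request; the route is a proposal only, not an existing API.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/embeddings",  # proposed, not an existing endpoint
    json={
        "model": "Qwen/Qwen2.5-Math-RM-72B",  # assumed reward model
        "messages": [
            {"role": "user", "content": "What is 1+1?"},
            {"role": "assistant", "content": "1+1=2"},
        ],
    },
)
print(resp.json())
```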

@simon-mo
Collaborator

Thank you for the RFC! Adding @youkaichao as stakeholder and please feel free to add others

@natolambert

I'm not a VLLM contributor (at least heavily, I may have had a PR I don't remember), but I'm a heavy reward model user and a heavy infrastructure builder (you can see my basic pipelines for most public reward models on HuggingFace in RewardBench).

I do not think getting reward models right will be easy, but it is worthwhile and a sign of a maturing ecosystem. There are some things to keep in mind, and I think a dedicated architecture will eventually be worthwhile. This goes hand in hand with a few common use cases for RMs.

  1. LLM-as-a-judge / evals: This is likely the biggest use at first. Rating responses for filtering. In this case, you normally will not be using the text at token level. Outputting a score is all you need. Hacky solutions are fine (iirc option 1 above).
  2. RLHF training (synchronous, e.g. PPO): Here, the reward model scores candidate samples and this value is used to update the LM loss. The weights of the RM are held constant. This likely just works on tokens, and re-applying the chat template is slow. That said, being able to switch chat templates easily is very nice for the open community, who may be training Llama 3.1 but using the Qwen RM (different tokenizer).
  3. RLHF training (asynch, e.g. rejection sampling): Here, the RM likely can be used as an OpenAI server. We just need to pass a big list of texts through the reward model. This is very similar to LLM as a judge substitutes, but the description is different :)

Some comments:

  • Passing a "messages" List[Dict[str, str]] I suspect would be used less than just List[str] as people have normally formatted their messages before passing into the RM, but that may just be me.

Can someone provide more examples on how the embedding API and the existing Qwen model work? I didn't see that much in the PR.

THANKS!

@youkaichao
Member

cc @zhuzilin

@zhuzilin
Contributor

zhuzilin commented Oct 1, 2024

@natolambert please take a look at the PR description: #8896

@noamgat
Contributor Author

noamgat commented Oct 1, 2024

> I'm not a VLLM contributor (at least heavily, I may have had a PR I don't remember), but I'm a heavy reward model user and a heavy infrastructure builder (you can see my basic pipelines for most public reward models on HuggingFace in RewardBench).
>
> I do not think getting reward models right will be easy, but it is worthwhile and a sign of a maturing ecosystem. There are some things to keep in mind, and I think a dedicated architecture will eventually be worthwhile. This goes hand in hand with a few common use cases for RMs.
>
> 1. LLM-as-a-judge / evals: This is likely the biggest use at first. Rating responses for filtering. In this case, you normally will not be using the text at token level. Outputting a score is all you need. Hacky solutions are fine (iirc option 1 above).
>
> 2. RLHF training (synchronous, e.g. PPO): Here, the reward model scores candidate samples and this value is used to update the LM loss. The weights of the RM are held constant. This likely just works on tokens, and re-applying the chat template is slow. That said, being able to switch chat templates easily is very nice for the open community, who may be training Llama 3.1 but using the Qwen RM (different tokenizer).
>
> 3. RLHF training (asynch, e.g. rejection sampling): Here, the RM likely can be used as an OpenAI server. We just need to pass a big list of texts through the reward model. This is very similar to LLM as a judge substitutes, but the description is different :)
>
> Some comments:
>
> * Passing a "messages" `List[Dict[str, str]]` I suspect would be used less than just `List[str]` as people have normally formatted their messages before passing into the RM, but that may just be me.
>
> Can someone provide more examples on how the embedding API and the existing Qwen model work? I didn't see that much in the PR.
>
> THANKS!

Thanks for chipping in! I really appreciate your work on RewardBench!

  • I agree with your division into the three main use cases. I think all of them can be covered by each of the options I listed, so it's a matter of elegance, simplicity, and maintainability IMO.
  • I think `List[dict]` is better than `List[str]` for supporting system turns at the API level (is the first message a user or a system message?). It also follows the current API patterns better.
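For example, a conversation input makes the role of every turn explicit (the contents below are illustrative only):

```python
# Illustrative only: with a list of dicts the server knows the first turn is a system
# prompt; with a bare list of strings that would be ambiguous.
conversation = [
    {"role": "system", "content": "You are a strict math grader."},
    {"role": "user", "content": "What is 1+1?"},
    {"role": "assistant", "content": "1+1=2"},
]
```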

@natolambert

@zhuzilin I think the initial implementation is good at a quick pass. It covers the biggest things.
(mostly acknowledging that I did take a look; without using it I am unlikely to uncover weird corner cases)

@noamgat
Contributor Author

noamgat commented Oct 2, 2024

@zhuzilin @youkaichao - which of the approaches sound best to you?

Note that I also added a fourth option: do nothing, and guide clients to use the tokenize endpoint (with a conversation) and then the embeddings endpoint.

@zankner

zankner commented Oct 13, 2024

Not sure how useful this is, but one thought is that reward models will eventually be generative. I did some work on this along with others (https://arxiv.org/abs/2408.11791, https://arxiv.org/abs/2408.15240). It might be worthwhile to scope out doing both generation and scoring from a single interface.

@noamgai21

> Not sure how useful this is, but one thought is that reward models will eventually be generative. I did some work on this along with others (https://arxiv.org/abs/2408.11791, https://arxiv.org/abs/2408.15240). It might be worthwhile to scope out doing both generation and scoring from a single interface.

Thanks for pitching in! From looking at the paper, that kind of model can be served with today's chat interface, as text generation + logprobs is enough (from what I see) to use the trained model. Am I wrong?

@zankner

zankner commented Oct 15, 2024

That's true for the second paper (https://arxiv.org/abs/2408.15240). For the first paper, it's actually a second linear head that is applied to the hidden state of the EOS token generated by the reward model, so logprobs can't be used, sadly.
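For context, here is a minimal sketch of that kind of scalar head; this is illustrative only, under the assumption of a single linear layer, and is not the paper's actual code:

```python
# Illustrative sketch: a scalar reward head applied to the hidden state of the EOS
# token, which is why logprobs alone are not enough to recover the score.
import torch
import torch.nn as nn


class RewardHead(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor, eos_index: int = -1) -> torch.Tensor:
        # hidden_states: [seq_len, hidden_size] from the base LM's final layer
        return self.score(hidden_states[eos_index]).squeeze(-1)  # scalar reward
```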

@arthrod

arthrod commented Oct 21, 2024

> Not sure how useful this is, but one thought is that reward models will eventually be generative. I did some work on this along with others (https://arxiv.org/abs/2408.11791, https://arxiv.org/abs/2408.15240). It might be worthwhile to scope out doing both generation and scoring from a single interface.

nvidia/Llama-3.1-Nemotron-70B-Reward-HF's architecture is LlamaForCausalLM. I was able to deploy it with torch and inference is working.

@DarkLight1337
Member

> To keep the semantics consistent (where Chat Completions is like Completions API with chat conversation), I prefer having a separate Chat Embeddings API (Embeddings API with chat conversation). So our endpoint map would be something like
>
> • v1/completions (Completions API)
> • v1/chat/completions (Chat Completions API)
> • v1/embeddings (Embeddings API)
> • v1/chat/embeddings (Chat Embeddings API) [new]
>
> Since most of the logic related to chat conversation parsing is already in chat_utils.py, it should not take that much effort to add this.

We will add a Chat Embeddings API soon in order to support multi-modal embeddings in online inference. This will also provide support for embeddings from text-only conversations.

@Went-Liang
Contributor

> > Not sure how useful this is, but one thought is that reward models will eventually be generative. I did some work on this along with others (https://arxiv.org/abs/2408.11791, https://arxiv.org/abs/2408.15240). It might be worthwhile to scope out doing both generation and scoring from a single interface.
>
> nvidia/Llama-3.1-Nemotron-70B-Reward-HF's architecture is LlamaForCausalLM. I was able to deploy it with torch and inference is working.

@arthrod Excuse me, would it be convenient for you to share the script? I encountered an error when testing Llama-3.1-Nemotron-70B-Reward-HF with --task embedding.

@hrdxwandg

> > > Not sure how useful this is, but one thought is that reward models will eventually be generative. I did some work on this along with others (https://arxiv.org/abs/2408.11791, https://arxiv.org/abs/2408.15240). It might be worthwhile to scope out doing both generation and scoring from a single interface.
> >
> > nvidia/Llama-3.1-Nemotron-70B-Reward-HF's architecture is LlamaForCausalLM. I was able to deploy it with torch and inference is working.
>
> @arthrod Excuse me, would it be convenient for you to share the script? I encountered an error when testing Llama-3.1-Nemotron-70B-Reward-HF with --task embedding.

I also hit an error: 500 Internal Server Error, `TypeError: object of type 'NoneType' has no len()`.
Did you solve it?

@DarkLight1337
Member

Can you open a separate issue for this and describe it in more detail?

@hrdxwandg

> Can you open a separate issue for this and describe it in more detail?

Thanks. #10444

@xs1997zju

xs1997zju commented Dec 24, 2024

> #10444

@arthrod How do you get the final reward score? I set `task=embedding` and use `llm.encode('xxx')`, but `output[0].outputs.data` is a tensor of size [8192].

@DarkLight1337
Member

You should now use `task=reward` for reward models.

@xs1997zju

> You should now use `task=reward` for reward models.

@DarkLight1337
I just got the same output. Here is my inference script:

```python
from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'nvidia/Llama-3.1-Nemotron-70B-Reward-HF'
model = LLM(model=model_path,
            tensor_parallel_size=8,
            dtype="float16",
            enable_chunked_prefill=False,
            task="embed",
            )

tokenizer = AutoTokenizer.from_pretrained(model_path)
prompt = "What is 1+1?"
good_response = "1+1=2"
bad_response = "1+1=3"

messages = [
    {'role': "user", "content": prompt},
    {'role': "assistant", "content": good_response}
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

output = model.encode(prompt)

# output[0].outputs.data.size(): [8192]
```

But I want to get the reward score, which should be -3.28125 for good_response, according to the HF example.

@xs1997zju

> > You should now use `task=reward` for reward models.
>
> I just got the same output. Here is my inference script: [...] But I want to get the reward score, which should be -3.28125 for good_response, according to the HF example.

@zhuzilin Any suggestions for this?

@DarkLight1337
Member

Please see LLM.encode() example: https://docs.vllm.ai/en/latest/models/pooling_models.html#llm-encode

The output data should be equivalent to response_token_ids['scores'].

@xs1997zju

> Please see LLM.encode() example: https://docs.vllm.ai/en/latest/models/pooling_models.html#llm-encode
>
> The output data should be equivalent to response_token_ids['scores'].

@DarkLight1337 Yes, I just followed the encode example, but the output data is a tensor of size 8192, not a single score.

@xs1997zju

Does vLLM currently not support this reward model: nvidia/Llama-3.1-Nemotron-70B-Reward-HF?

@DarkLight1337
Member

DarkLight1337 commented Dec 24, 2024

As I said before, you should pass `task="reward"`, not `task="embed"`.

@xs1997zju

xs1997zju commented Dec 24, 2024

> As I said before, you should pass `task="reward"`, not `task="embed"`.

@DarkLight1337

```python
from vllm import LLM
from transformers import AutoTokenizer

model_path = 'nvidia/Llama-3.1-Nemotron-70B-Reward-HF'
model = LLM(model=model_path,
            tensor_parallel_size=8,
            dtype="float16",
            enable_chunked_prefill=False,
            task="reward",
            )

tokenizer = AutoTokenizer.from_pretrained(model_path)

prompt = "What is 1+1?"
good_response = "1+1=2"
bad_response = "1+1=3"

messages = [
    {'role': "user", "content": prompt},
    {'role': "assistant", "content": good_response}
]

prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)

output = model.encode(prompt)
```

Yes, I have set the task to reward, but I get the same output. vLLM version is 0.6.5.

@DarkLight1337
Member

DarkLight1337 commented Dec 24, 2024

OK I think I see the problem now. You should also set `--override-pooler-config '{"pooling_type": "ALL"}'` in this case. I'll update vLLM to default to this if `--task reward` is provided.

@xs1997zju

@DarkLight1337

```python
from vllm import LLM
from vllm.config import PoolerConfig

model_path = 'nvidia/Llama-3.1-Nemotron-70B-Reward-HF'
model = LLM(model=model_path,
            tensor_parallel_size=8,
            task="reward",
            override_pooler_config=PoolerConfig(pooling_type='ALL'))
```

With this setting, the output data is now a tensor of shape [len_of_input_tokens, hidden_dim], for example 8x8192.

@DarkLight1337
Member

I think this should be correct now, since in HF the extra dimension is the batch dimension (i.e. number of prompts).

@git-xp

git-xp commented Mar 13, 2025

Hi @DarkLight1337, I think llm.encode() is just pulling the hidden states of the last layer even when we set task="reward". Here is an example snippet that replicates the issue:

```python
from vllm import LLM

llm = LLM(model="Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", task="reward")
(output,) = llm.encode("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")
```

this will give you

```
Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 42.76it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Data: tensor([[ 1.2734,  2.6875,  1.6484,  ..., -1.8984,  2.7500,  1.8984],
        [-0.3281, -1.9141, -0.1025,  ...,  0.6562, -4.1875, -0.8789],
        [ 0.8711, -3.2500, -0.4277,  ..., -2.2031, -2.4062,  1.2891],
        [ 2.0781, -4.8125, -1.1797,  ..., -0.0869, -2.7500,  2.9688],
        [-2.9375, -1.3828, -0.7852,  ...,  1.4062, -2.5469, -0.5547],
        [-1.5938, -1.1406, -3.9219,  ...,  0.9062, -1.1562, -0.6484]],
       dtype=torch.float32)
```

I also posted in #12791
