Skip to content

Commit 8360979

Browse files
authored
[Model] Add Qwen2 PRM model support (#12202)
Signed-off-by: Isotr0py <[email protected]>
1 parent 0974c9b commit 8360979

File tree

5 files changed

+45
-13
lines changed

5 files changed

+45
-13
lines changed

docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
470470
- `Qwen/Qwen2.5-Math-RM-72B`, etc.
471471
- ✅︎
472472
- ✅︎
473+
* - `Qwen2ForProcessRewardModel`
474+
- Qwen2-based
475+
- `Qwen/Qwen2.5-Math-PRM-7B`, `Qwen/Qwen2.5-Math-PRM-72B`, etc.
476+
- ✅︎
477+
- ✅︎
473478
```
474479

475480
If your model is not in the above list, we will try to automatically convert the model using

tests/models/embedding/language/test_embedding.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
1818
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
1919
pytest.param("intfloat/multilingual-e5-large"),
20-
# [Encoder-decoder]
21-
pytest.param("intfloat/e5-mistral-7b-instruct",
22-
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
20+
# [Decoder-only]
2321
pytest.param("BAAI/bge-multilingual-gemma2",
2422
marks=[pytest.mark.core_model]),
25-
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
23+
pytest.param("intfloat/e5-mistral-7b-instruct",
24+
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
2625
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
2726
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
27+
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
28+
# [Encoder-decoder]
2829
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
2930
],
3031
)

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ class _HfExamplesInfo:
155155
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
156156
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
157157
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
158+
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
158159
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
159160
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
160161
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501

vllm/model_executor/models/qwen2_rm.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vllm.config import VllmConfig
1313
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
1414
RowParallelLinear)
15-
from vllm.model_executor.layers.pooler import Pooler, PoolingType
15+
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
1616
from vllm.model_executor.pooling_metadata import PoolingMetadata
1717
from vllm.sequence import IntermediateTensors, PoolerOutput
1818

@@ -32,7 +32,7 @@ def forward(self, input):
3232
return self.activation(input)
3333

3434

35-
class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
35+
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
3636
packed_modules_mapping = {
3737
"qkv_proj": [
3838
"q_proj",
@@ -60,7 +60,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
6060
config = vllm_config.model_config.hf_config
6161
quant_config = vllm_config.quant_config
6262
lora_config = vllm_config.lora_config
63-
pooler_config = vllm_config.model_config.pooler_config
6463

6564
self.config = config
6665
self.lora_config = lora_config
@@ -74,14 +73,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
7473
config.hidden_size,
7574
quant_config=quant_config),
7675
ReLU(),
77-
RowParallelLinear(config.hidden_size, 1,
76+
RowParallelLinear(config.hidden_size,
77+
config.num_labels,
7878
quant_config=quant_config),
7979
)
80-
self._pooler = Pooler.from_config_with_defaults(
81-
pooler_config,
82-
pooling_type=PoolingType.ALL,
83-
normalize=False,
84-
softmax=False)
80+
self._pooler: SimplePooler
8581
self.make_empty_intermediate_tensors = (
8682
self.model.make_empty_intermediate_tensors)
8783

@@ -115,3 +111,31 @@ def load_weights(self, weights: Iterable[Tuple[str,
115111
loader = AutoWeightsLoader(self,
116112
ignore_unexpected_prefixes=["lm_head."])
117113
return loader.load_weights(weights)
114+
115+
116+
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
117+
118+
def __init__(self, *, vllm_config, prefix=""):
119+
vllm_config.model_config.hf_config.num_labels = 1
120+
super().__init__(vllm_config=vllm_config, prefix=prefix)
121+
pooler_config = vllm_config.model_config.pooler_config
122+
self._pooler = Pooler.from_config_with_defaults(
123+
pooler_config,
124+
pooling_type=PoolingType.ALL,
125+
normalize=False,
126+
softmax=False)
127+
128+
129+
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
130+
131+
def __init__(self, *, vllm_config, prefix=""):
132+
vllm_config.model_config.hf_config.num_labels = 2
133+
super().__init__(vllm_config=vllm_config, prefix=prefix)
134+
pooler_config = vllm_config.model_config.pooler_config
135+
self._pooler = Pooler.from_config_with_defaults(
136+
pooler_config,
137+
pooling_type=PoolingType.STEP,
138+
normalize=False,
139+
softmax=True,
140+
step_tag_id=151651,
141+
)

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
128128
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
129129
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
130+
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
130131
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
131132
# [Multimodal]
132133
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501

0 commit comments

Comments
 (0)