
Commit de8305a

noooop authored and nishith-fujitsu committed
[New Model]: jinaai/jina-embeddings-v3 (vllm-project#16120)
1 parent cca9ca7 commit de8305a

File tree

6 files changed (+297, -86 lines)
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@ (new file)
# SPDX-License-Identifier: Apache-2.0

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass task="embed" for embedding models
    model = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    # Only text matching task is supported for now. See #16120
    outputs = model.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:")
    print("Only text matching task is supported for now. See #16120")
    print("-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        embeds_trimmed = ((str(embeds[:16])[:-1] +
                           ", ...]") if len(embeds) > 16 else embeds)
        print(f"Prompt: {prompt!r} \n"
              f"Embeddings for text matching: {embeds_trimmed} "
              f"(size={len(embeds)})")
        print("-" * 60)


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jinaai/jina-embeddings-v3",
                        task="embed",
                        trust_remote_code=True)
    args = parser.parse_args()
    main(args)
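Since the text-matching adapter of jina-embeddings-v3 produces embeddings meant for symmetric similarity, a natural follow-up to the script above is to score prompt pairs with cosine similarity. A minimal sketch, not part of the commit — the toy vectors stand in for real `output.outputs.embedding` lists:

import numpy as np


def cosine_similarity(a: list[float], b: list[float]) -> float:
    # Dot product of the two vectors, divided by their norms.
    x, y = np.asarray(a), np.asarray(b)
    return float(x @ y / (np.linalg.norm(x) * np.linalg.norm(y)))


# Translations of the same sentence should score close to 1.0.
print(cosine_similarity([0.6, 0.8, 0.0], [0.55, 0.83, 0.1]))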

tests/conftest.py

Lines changed: 3 additions & 2 deletions
@@ -671,8 +671,9 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
 
-    def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]:
-        return self.model.encode(prompts)
+    def encode(self, prompts: list[str], *args,
+               **kwargs) -> list[list[torch.Tensor]]:
+        return self.model.encode(prompts, *args, **kwargs)
 
     def predict(self, prompts: list[list[str]]) -> torch.Tensor:
         return self.model.predict(prompts, convert_to_tensor=True)
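The widened signature exists so model-specific keyword arguments can flow through to `SentenceTransformer.encode`; the renamed test below relies on it to pass `task="text-matching"` for jina-embeddings-v3. A minimal sketch of the pass-through pattern, with a hypothetical stand-in class for illustration:

class Wrapper:

    def __init__(self, model):
        self.model = model

    def encode(self, prompts, *args, **kwargs):
        # Forward everything verbatim so callers can use model-specific
        # kwargs such as task="text-matching" without changing this shim.
        return self.model.encode(prompts, *args, **kwargs)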

tests/models/embedding/language/test_jina_reranker_v2.py renamed to tests/models/embedding/language/test_jina.py

Lines changed: 61 additions & 3 deletions
@@ -2,13 +2,15 @@
 # ruff: noqa: E501
 """Compare the scoring outputs of HF and vLLM models.
 
-Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`.
+Run `pytest tests/models/embedding/language/test_jina.py`.
 """
 import math
 
 import pytest
 
-MODELS = [
+from tests.models.embedding.utils import check_embeddings_close
+
+SCORING_MODELS = [
     "jinaai/jina-reranker-v2-base-multilingual",  # Roberta
 ]
 
@@ -27,8 +29,21 @@
     "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
 ]
 
+EMBEDDING_MODELS = [
+    "jinaai/jina-embeddings-v3",
+]
+
+EMBEDDING_PROMPTS = [
+    "Follow the white rabbit.",  # English
+    "Sigue al conejo blanco.",  # Spanish
+    "Suis le lapin blanc.",  # French
+    "跟着白兔走。",  # Chinese
+    "اتبع الأرنب الأبيض.",  # Arabic
+    "Folge dem weißen Kaninchen.",  # German
+]
+
 
-@pytest.fixture(scope="module", params=MODELS)
+@pytest.fixture(scope="module", params=SCORING_MODELS)
 def model_name(request):
     yield request.param
 
@@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
 
     assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
     assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
+
+
+@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
+def emb_model_name(request):
+    yield request.param
+
+
+def test_is_matryoshka(vllm_runner, emb_model_name):
+    with vllm_runner(emb_model_name, task="embed",
+                     max_model_len=None) as vllm_model:
+        assert vllm_model.model.llm_engine.model_config.is_matryoshka
+
+
+@pytest.mark.parametrize("model", EMBEDDING_MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_embeddings(
+    hf_runner,
+    vllm_runner,
+    model,
+    dtype: str,
+    monkeypatch,
+) -> None:
+
+    example_prompts = EMBEDDING_PROMPTS
+
+    with hf_runner(
+            model,
+            dtype=dtype,
+            is_sentence_transformer=True,
+    ) as hf_model:
+        hf_outputs = hf_model.encode(example_prompts, task="text-matching")
+
+    with vllm_runner(model, task="embed", dtype=dtype,
+                     max_model_len=None) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
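`check_embeddings_close` is the existing helper from `tests.models.embedding.utils`; conceptually it asserts that each HF/vLLM embedding pair points in nearly the same direction. A rough, hypothetical re-implementation of that kind of check (not the actual helper):

import torch
import torch.nn.functional as F


def assert_embeddings_close(embeddings_0, embeddings_1, tol: float) -> None:
    for e0, e1 in zip(embeddings_0, embeddings_1):
        # Cosine similarity of 1.0 means the vectors are parallel.
        sim = F.cosine_similarity(torch.tensor(e0), torch.tensor(e1), dim=0)
        assert 1.0 - sim.item() <= tol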

vllm/config.py

Lines changed: 5 additions & 0 deletions
@@ -1130,6 +1130,11 @@ def is_v1_compatible(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_v1_compatible(architectures)
 
+    @property
+    def is_matryoshka(self) -> bool:
+        return (hasattr(self.hf_config, "matryoshka_dimensions")
+                or getattr(self.hf_config, "is_matryoshka", False))
+
 
 class CacheConfig:
     """Configuration for the KV cache.

vllm/model_executor/models/bert.py

Lines changed: 49 additions & 17 deletions
@@ -18,6 +18,7 @@
 from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -38,19 +39,24 @@ def __init__(self, config: BertConfig):
         self.size = config.hidden_size
         self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
-        self.position_embeddings = VocabParallelEmbedding(
-            config.max_position_embeddings, config.hidden_size)
+
         self.token_type_embeddings = VocabParallelEmbedding(
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
 
         self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError("Only 'absolute' position_embedding_type" +
-                             " is supported")
+        if self.position_embedding_type == "absolute":
+            self.position_embeddings = VocabParallelEmbedding(
+                config.max_position_embeddings, config.hidden_size)
+            self.position_ids = nn.Parameter(
+                torch.empty((1, config.max_position_embeddings)), )
+        elif self.position_embedding_type == "rotary":
+            self.position_embeddings = None
+            self.position_ids = None
+        else:
+            raise ValueError("Only 'absolute' and 'rotary' " +
+                             "position_embedding_type is supported")
 
     def forward(
         self,

@@ -64,17 +70,19 @@ def forward(
         # Input embeddings.
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
-
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape,
                                          dtype=torch.long,
                                          device=inputs_embeds.device)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
-        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = inputs_embeds + token_type_embeddings
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+
         embeddings = self.LayerNorm(embeddings)
         return embeddings

@@ -98,7 +106,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 @support_torch_compile
 class BertEncoder(nn.Module):
 
-    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rotary_kwargs: Optional[dict] = None,
+                 prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config

@@ -107,16 +118,18 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
             BertLayer(config=config,
                       cache_config=cache_config,
                       quant_config=quant_config,
+                      rotary_kwargs=rotary_kwargs,
                       prefix=f"{prefix}.layer.{layer_idx}")
             for layer_idx in range(config.num_hidden_layers)
         ])
 
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         for layer in self.layer:
-            hidden_states = layer(hidden_states)
+            hidden_states = layer(positions, hidden_states)
         return hidden_states

@@ -126,6 +139,7 @@ def __init__(self,
                  config: BertConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
+                 rotary_kwargs: Optional[dict] = None,
                  prefix: str = ""):
         super().__init__()

@@ -135,6 +149,7 @@ def __init__(self,
             layer_norm_eps=config.layer_norm_eps,
             cache_config=cache_config,
             quant_config=quant_config,
+            rotary_kwargs=rotary_kwargs,
             prefix=f"{prefix}.attention")
 
         self.intermediate = BertIntermediate(

@@ -150,8 +165,8 @@ def __init__(self,
             quant_config=quant_config,
             prefix=f"{prefix}.output")
 
-    def forward(self, hidden_states: torch.Tensor):
-        attn_output = self.attention(hidden_states)
+    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
+        attn_output = self.attention(positions, hidden_states)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
         return output

@@ -166,6 +181,7 @@ def __init__(
         layer_norm_eps: float,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
         prefix: str = "",
     ):
         super().__init__()

@@ -174,6 +190,7 @@ def __init__(
             num_attention_heads=num_attention_heads,
             cache_config=cache_config,
             quant_config=quant_config,
+            rotary_kwargs=rotary_kwargs,
             prefix=f"{prefix}.output")
 
         self.output = BertSelfOutput(hidden_size=hidden_size,

@@ -183,9 +200,10 @@ def __init__(
 
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        self_output = self.self(hidden_states)
+        self_output = self.self(positions, hidden_states)
         return self.output(self_output, hidden_states)

@@ -197,6 +215,7 @@ def __init__(
         num_attention_heads: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
         prefix: str = "",
     ):
         super().__init__()

@@ -225,6 +244,11 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj")
 
+        if rotary_kwargs:
+            self.rotary_emb = get_rope(**rotary_kwargs)
+        else:
+            self.rotary_emb = None
+
         self.attn = Attention(num_heads=self.num_heads,
                               head_size=self.head_dim,
                               scale=self.scaling,

@@ -236,10 +260,15 @@ def __init__(
 
     def forward(
         self,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.rotary_emb:
+            q, k = self.rotary_emb(positions, q, k)
+
         output = self.attn(q, k, v)
         return output

@@ -321,11 +350,13 @@ def __init__(self,
                  vllm_config: VllmConfig,
                  prefix: str = "",
                  embedding_class: type = BertEmbedding,
+                 rotary_kwargs: Optional[dict] = None,
                  add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
+                                   rotary_kwargs=rotary_kwargs,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(config) if add_pooling_layer else None

@@ -347,7 +378,7 @@ def forward(
             seq_lens=attn_metadata.seq_lens_tensor,
             position_ids=position_ids,
             token_type_ids=token_type_ids)
-        return self.encoder(hidden_states)
+        return self.encoder(position_ids, hidden_states)
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:

@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         pooler_config = vllm_config.model_config.pooler_config
+        self.config = vllm_config.model_config.hf_config
         self.model = self._build_model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
         self._pooler = self._build_pooler(pooler_config)
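`get_rope` is vLLM's existing rotary-embedding factory; this diff only threads an opaque `rotary_kwargs` dict down to `BertSelfAttention`, leaving its construction to the model class that wires up jina-embeddings-v3. A plausible shape for that dict, inferred from `get_rope`'s usual parameters — the config values below are illustrative, not necessarily what the commit's model class passes:

from types import SimpleNamespace

# Hypothetical HF config values, for illustration only.
config = SimpleNamespace(hidden_size=1024,
                         num_attention_heads=16,
                         max_position_embeddings=8192,
                         rotary_emb_base=10000)

head_dim = config.hidden_size // config.num_attention_heads
rotary_kwargs = {
    "head_size": head_dim,  # per-head dimension
    "rotary_dim": head_dim,  # rotate the full head
    "max_position": config.max_position_embeddings,
    "base": config.rotary_emb_base,
    "is_neox_style": True,
}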
