[New Model]: jinaai/jina-embeddings-v3 #16120

Merged · 9 commits · Apr 8, 2025
50 changes: 50 additions & 0 deletions examples/offline_inference/embed_jina_embeddings_v3.py
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def main(args: Namespace):
# Sample prompts.
prompts = [
"Follow the white rabbit.", # English
"Sigue al conejo blanco.", # Spanish
"Suis le lapin blanc.", # French
"跟着白兔走。", # Chinese
"اتبع الأرنب الأبيض.", # Arabic
"Folge dem weißen Kaninchen.", # German
]

# Create an LLM.
# You should pass task="embed" for embedding models
model = LLM(**vars(args))

    # Generate embeddings. The output is a list of EmbeddingRequestOutputs.
    # Only the text-matching task is supported for now; see #16120.
outputs = model.embed(prompts)

# Print the outputs.
print("\nGenerated Outputs:")
print("Only text matching task is supported for now. See #16120")
print("-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} \n"
f"Embeddings for text matching: {embeds_trimmed} "
f"(size={len(embeds)})")
print("-" * 60)


if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="jinaai/jina-embeddings-v3",
task="embed",
trust_remote_code=True)
args = parser.parse_args()
main(args)
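
A quick way to sanity-check the embeddings this example produces is to compare translations of the same sentence with cosine similarity. This is a hedged sketch, not part of the PR; it assumes the `outputs` and `prompts` variables from the script above:

```python
import torch
import torch.nn.functional as F

# Stack the returned embeddings into an (N, D) tensor.
embeds = torch.tensor([o.outputs.embedding for o in outputs])

# Compare every prompt against the English reference (index 0); for a
# text-matching model, the translations should score close to 1.0.
sims = F.cosine_similarity(embeds[0].unsqueeze(0), embeds, dim=-1)
for prompt, sim in zip(prompts, sims.tolist()):
    print(f"{sim:.3f}  {prompt!r}")
```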
5 changes: 3 additions & 2 deletions tests/conftest.py
@@ -671,8 +671,9 @@ def generate_encoder_decoder_greedy_logprobs_limit(
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]

def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]:
return self.model.encode(prompts)
def encode(self, prompts: list[str], *args,
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)

def predict(self, prompts: list[list[str]]) -> torch.Tensor:
return self.model.predict(prompts, convert_to_tensor=True)
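
The widened `encode` signature simply forwards extra arguments to the underlying `SentenceTransformer.encode`; the new Jina test relies on this to select the v3 task-specific LoRA adapter. A usage sketch, matching the call in the test below:

```python
# Forward sentence-transformers kwargs through the test runner, e.g. the
# Jina v3 task selector:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
```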
tests/models/embedding/language/{test_jina_reranker_v2.py → test_jina.py}
@@ -2,13 +2,15 @@
# ruff: noqa: E501
"""Compare the scoring outputs of HF and vLLM models.

Run `pytest tests/models/embedding/language/test_jina_reranker_v2.py`.
Run `pytest tests/models/embedding/language/test_jina.py`.
"""
import math

import pytest

MODELS = [
from tests.models.embedding.utils import check_embeddings_close

SCORING_MODELS = [
"jinaai/jina-reranker-v2-base-multilingual", # Roberta
]

@@ -27,8 +29,21 @@
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
]

EMBEDDING_MODELS = [
"jinaai/jina-embeddings-v3",
]

EMBEDDING_PROMPTS = [
"Follow the white rabbit.", # English
"Sigue al conejo blanco.", # Spanish
"Suis le lapin blanc.", # French
"跟着白兔走。", # Chinese
"اتبع الأرنب الأبيض.", # Arabic
"Folge dem weißen Kaninchen.", # German
]


@pytest.fixture(scope="module", params=MODELS)
@pytest.fixture(scope="module", params=SCORING_MODELS)
def model_name(request):
yield request.param

@@ -68,3 +83,46 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):

assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)


@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
def emb_model_name(request):
yield request.param


def test_is_matryoshka(vllm_runner, emb_model_name):
with vllm_runner(emb_model_name, task="embed",
max_model_len=None) as vllm_model:
assert vllm_model.model.llm_engine.model_config.is_matryoshka


@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_embeddings(
hf_runner,
vllm_runner,
model,
dtype: str,
monkeypatch,
) -> None:

example_prompts = EMBEDDING_PROMPTS

with hf_runner(
model,
dtype=dtype,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")

with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
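
`check_embeddings_close` (imported from `tests.models.embedding.utils`) compares the two embedding lists pairwise. A minimal sketch of that kind of check, assuming a cosine-similarity criterion; the real helper's exact formula may differ:

```python
import torch
import torch.nn.functional as F

def assert_embeddings_close(a: list[list[float]],
                            b: list[list[float]],
                            tol: float = 1e-2) -> None:
    # Illustrative stand-in: each pair of vectors must have cosine
    # similarity within `tol` of 1.0.
    for x, y in zip(a, b):
        sim = F.cosine_similarity(torch.tensor(x), torch.tensor(y), dim=0)
        assert 1.0 - sim.item() <= tol
```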
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -1130,6 +1130,11 @@ def is_v1_compatible(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_v1_compatible(architectures)

@property
def is_matryoshka(self) -> bool:
return (hasattr(self.hf_config, "matryoshka_dimensions")
or getattr(self.hf_config, "is_matryoshka", False))


class CacheConfig:
"""Configuration for the KV cache.
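
The new `is_matryoshka` property flags checkpoints (such as jina-embeddings-v3) whose embeddings are trained so that a prefix of the vector is itself a usable lower-dimensional embedding. A hedged sketch of what a consumer can do with such a model; the truncate-then-renormalize step is the standard Matryoshka recipe, not code from this PR:

```python
import torch
import torch.nn.functional as F

def truncate_matryoshka(embedding: list[float], dim: int) -> torch.Tensor:
    # Keep the first `dim` components, then re-normalize so cosine
    # similarity comparisons still behave.
    vec = torch.tensor(embedding)[:dim]
    return F.normalize(vec, dim=0)
```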
66 changes: 49 additions & 17 deletions vllm/model_executor/models/bert.py
@@ -18,6 +18,7 @@
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -38,19 +39,24 @@ def __init__(self, config: BertConfig):
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)

self.token_type_embeddings = VocabParallelEmbedding(
config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )

self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")
if self.position_embedding_type == "absolute":
self.position_embeddings = VocabParallelEmbedding(
config.max_position_embeddings, config.hidden_size)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
elif self.position_embedding_type == "rotary":
self.position_embeddings = None
self.position_ids = None
else:
raise ValueError("Only 'absolute' and 'rotary' " +
"position_embedding_type is supported")

def forward(
self,
@@ -64,17 +70,19 @@ def forward(
# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)

token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = inputs_embeds + token_type_embeddings

if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings

embeddings = self.LayerNorm(embeddings)
return embeddings

@@ -98,7 +106,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@support_torch_compile
class BertEncoder(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
vllm_config: VllmConfig,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@@ -107,16 +118,18 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
BertLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
for layer in self.layer:
hidden_states = layer(hidden_states)
hidden_states = layer(positions, hidden_states)
return hidden_states


@@ -126,6 +139,7 @@ def __init__(self,
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = ""):
super().__init__()

@@ -135,6 +149,7 @@
layer_norm_eps=config.layer_norm_eps,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.attention")

self.intermediate = BertIntermediate(
@@ -150,8 +165,8 @@ def __init__(self,
quant_config=quant_config,
prefix=f"{prefix}.output")

def forward(self, hidden_states: torch.Tensor):
attn_output = self.attention(hidden_states)
def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
attn_output = self.attention(positions, hidden_states)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output
@@ -166,6 +181,7 @@ def __init__(
layer_norm_eps: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
@@ -174,6 +190,7 @@
num_attention_heads=num_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.output")

self.output = BertSelfOutput(hidden_size=hidden_size,
@@ -183,9 +200,10 @@

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
self_output = self.self(hidden_states)
self_output = self.self(positions, hidden_states)
return self.output(self_output, hidden_states)


@@ -197,6 +215,7 @@ def __init__(
num_attention_heads: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rotary_kwargs: Optional[dict] = None,
prefix: str = "",
):
super().__init__()
@@ -225,6 +244,11 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")

if rotary_kwargs:
self.rotary_emb = get_rope(**rotary_kwargs)
else:
self.rotary_emb = None

self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
@@ -236,10 +260,15 @@

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

if self.rotary_emb:
q, k = self.rotary_emb(positions, q, k)

output = self.attn(q, k, v)
return output

@@ -321,11 +350,13 @@ def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding,
rotary_kwargs: Optional[dict] = None,
add_pooling_layer: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(vllm_config=vllm_config,
rotary_kwargs=rotary_kwargs,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None

@@ -347,7 +378,7 @@ def forward(
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)
return self.encoder(position_ids, hidden_states)

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.config = vllm_config.model_config.hf_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)