Skip to content

feat: Support Sequence[float] as query_vector in FindNearest #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 23, 2025
7 changes: 4 additions & 3 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
Generator,
Generic,
Iterable,
Optional,
Sequence,
Tuple,
Union,
Optional,
)

from google.api_core import retry as retries
Expand Down Expand Up @@ -555,7 +556,7 @@ def avg(self, field_ref: str | FieldPath, alias=None):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand All @@ -568,7 +569,7 @@ def find_nearest(
Args:
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -1000,7 +1001,7 @@ def _to_protobuf(self) -> StructuredQuery:
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand Down
11 changes: 7 additions & 4 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@
import abc
from abc import ABC
from enum import Enum
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union

from google.api_core import gapic_v1
from google.api_core import retry as retries

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import query
from google.cloud.firestore_v1.vector import Vector

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.query_results import QueryResultsList
from google.cloud.firestore_v1.stream_generator import StreamGenerator
from google.cloud.firestore_v1.vector import Vector


class DistanceMeasure(Enum):
Expand Down Expand Up @@ -137,16 +137,19 @@ def get(
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
):
"""Finds the closest vector embeddings to the given query vector."""
if not isinstance(query_vector, Vector):
self._query_vector = Vector(query_vector)
else:
self._query_vector = query_vector
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
self._distance_measure = distance_measure
self._distance_result_field = distance_result_field
Expand Down
16 changes: 13 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,17 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
List,
Optional,
Sequence,
Type,
Union,
)

from google.api_core import exceptions, gapic_v1
from google.api_core import retry as retries
Expand Down Expand Up @@ -269,7 +279,7 @@ def _retry_query_after_exception(self, exc, retry, transaction):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand All @@ -282,7 +292,7 @@ def find_nearest(
Args:
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
query_vector(Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/v1/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud.firestore_v1.vector import Vector


def _make_commit_repsonse():
def _make_commit_response():
response = mock.create_autospec(firestore.CommitResponse)
response.write_results = [mock.sentinel.write_result]
response.commit_time = mock.sentinel.commit_time
Expand All @@ -35,7 +35,7 @@ def _make_commit_repsonse():
def _make_firestore_api():
firestore_api = mock.Mock()
firestore_api.commit.mock_add_spec(spec=["commit"])
firestore_api.commit.return_value = _make_commit_repsonse()
firestore_api.commit.return_value = _make_commit_response()
return firestore_api


Expand Down
62 changes: 62 additions & 0 deletions tests/unit/v1/test_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,68 @@ def test_vector_query_collection_group(distance_measure, expected_distance):
)


def test_vector_query_list_as_query_vector():
# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["run_query"])
client = make_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dee")
query = make_query(parent)
parent_path, expected_prefix = parent._parent_info()

data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
response_pb1 = _make_query_response(
name="{}/test_doc".format(expected_prefix), data=data
)
response_pb2 = _make_query_response(
name="{}/test_doc".format(expected_prefix), data=data
)

kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

# Execute the vector query and check the response.
firestore_api.run_query.return_value = iter([response_pb1, response_pb2])

vector_query = query.where("snooze", "==", 10).find_nearest(
vector_field="embedding",
query_vector=[1.0, 2.0, 3.0],
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=5,
)

returned = vector_query.get(transaction=_transaction(client), **kwargs)
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == data

expected_pb = _expected_pb(
parent=parent,
vector_field="embedding",
vector=Vector([1.0, 2.0, 3.0]),
distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
limit=5,
)
expected_pb.where = StructuredQuery.Filter(
field_filter=StructuredQuery.FieldFilter(
field=StructuredQuery.FieldReference(field_path="snooze"),
op=StructuredQuery.FieldFilter.Operator.EQUAL,
value=encode_value(10),
)
)

firestore_api.run_query.assert_called_once_with(
request={
"parent": parent_path,
"structured_query": expected_pb,
"transaction": _TXN_ID,
},
metadata=client._rpc_metadata,
**kwargs,
)


def test_query_stream_multiple_empty_response_in_stream():
from google.cloud.firestore_v1 import stream_generator

Expand Down
Loading