diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index b113da827..0e5ae6ed1 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -25,9 +25,10 @@ Generator, Generic, Iterable, - Optional, + Sequence, Tuple, Union, + Optional, ) from google.api_core import retry as retries @@ -555,7 +556,7 @@ def avg(self, field_ref: str | FieldPath, alias=None): def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, @@ -568,7 +569,7 @@ def find_nearest( Args: vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector (Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. 
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 5a9efaf78..2fb9bd895 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -32,6 +32,7 @@ Iterable, List, Optional, + Sequence, Tuple, Type, Union, @@ -1000,7 +1001,7 @@ def _to_protobuf(self) -> StructuredQuery: def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index f5a4403c8..88e40635f 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -19,13 +19,14 @@ import abc from abc import ABC from enum import Enum -from typing import TYPE_CHECKING, Any, Coroutine, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union from google.api_core import gapic_v1 from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import query +from google.cloud.firestore_v1.vector import Vector if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -33,7 +34,6 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator - from google.cloud.firestore_v1.vector import Vector class DistanceMeasure(Enum): @@ -137,7 +137,7 @@ def get( def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, @@ -145,8 +145,11 @@ def find_nearest( distance_threshold: Optional[float] = None, ): """Finds the closest vector embeddings to the given 
query vector.""" + if not isinstance(query_vector, Vector): + self._query_vector = Vector(query_vector) + else: + self._query_vector = query_vector self._vector_field = vector_field - self._query_vector = query_vector self._limit = limit self._distance_measure = distance_measure self._distance_result_field = distance_result_field diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 0b52afc83..a8b821bdc 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -20,7 +20,17 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + List, + Optional, + Sequence, + Type, + Union, +) from google.api_core import exceptions, gapic_v1 from google.api_core import retry as retries @@ -269,7 +279,7 @@ def _retry_query_after_exception(self, exc, retry, transaction): def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, @@ -282,7 +292,7 @@ def find_nearest( Args: vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector (Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. 
diff --git a/tests/unit/v1/test_vector.py b/tests/unit/v1/test_vector.py index a28a05525..d850fc1cf 100644 --- a/tests/unit/v1/test_vector.py +++ b/tests/unit/v1/test_vector.py @@ -25,7 +25,7 @@ from google.cloud.firestore_v1.vector import Vector -def _make_commit_repsonse(): +def _make_commit_response(): response = mock.create_autospec(firestore.CommitResponse) response.write_results = [mock.sentinel.write_result] response.commit_time = mock.sentinel.commit_time @@ -35,7 +35,7 @@ def _make_commit_repsonse(): def _make_firestore_api(): firestore_api = mock.Mock() firestore_api.commit.mock_add_spec(spec=["commit"]) - firestore_api.commit.return_value = _make_commit_repsonse() + firestore_api.commit.return_value = _make_commit_response() return firestore_api diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index eb5328ace..ad88478c8 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -533,6 +533,68 @@ def test_vector_query_collection_group(distance_measure, expected_distance): ) +def test_vector_query_list_as_query_vector(): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. 
+ firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + vector_query = query.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=[1.0, 2.0, 3.0], + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + + returned = vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + def test_query_stream_multiple_empty_response_in_stream(): from google.cloud.firestore_v1 import stream_generator