|
| 1 | +"""Vectorize LangChain retrievers.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING, Any, Optional |
| 6 | + |
| 7 | +import vectorize_client |
| 8 | +from langchain_core.documents import Document |
| 9 | +from langchain_core.retrievers import BaseRetriever |
| 10 | +from typing_extensions import override |
| 11 | +from vectorize_client import ( |
| 12 | + ApiClient, |
| 13 | + Configuration, |
| 14 | + PipelinesApi, |
| 15 | + RetrieveDocumentsRequest, |
| 16 | +) |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from langchain_core.callbacks import CallbackManagerForRetrieverRun |
| 20 | + from langchain_core.runnables import RunnableConfig |
| 21 | + |
| 22 | +_METADATA_FIELDS = { |
| 23 | + "relevancy", |
| 24 | + "chunk_id", |
| 25 | + "total_chunks", |
| 26 | + "origin", |
| 27 | + "origin_id", |
| 28 | + "similarity", |
| 29 | + "source", |
| 30 | + "unique_source", |
| 31 | + "source_display_name", |
| 32 | + "pipeline_id", |
| 33 | + "org_id", |
| 34 | +} |
| 35 | +_NOT_SET = object() |
| 36 | + |
| 37 | + |
| 38 | +class VectorizeRetriever(BaseRetriever): |
| 39 | + """Vectorize retriever.""" |
| 40 | + |
| 41 | + api_token: str |
| 42 | + """The Vectorize API token.""" |
| 43 | + organization: Optional[str] = None # noqa: UP007 |
| 44 | + """The Vectorize organization ID.""" |
| 45 | + pipeline_id: Optional[str] = None # noqa: UP007 |
| 46 | + """The Vectorize pipeline ID.""" |
| 47 | + num_results: int = 5 |
| 48 | + """The number of documents to return.""" |
| 49 | + rerank: bool = False |
| 50 | + """Whether to rerank the results.""" |
| 51 | + metadata_filters: list[dict[str, Any]] = [] |
| 52 | + """The metadata filters to apply when retrieving the documents.""" |
| 53 | + |
| 54 | + _pipelines: PipelinesApi | None = None |
| 55 | + |
| 56 | + @override |
| 57 | + def model_post_init(self, /, context: Any) -> None: |
| 58 | + api = ApiClient(Configuration(access_token=self.api_token)) |
| 59 | + self._pipelines = PipelinesApi(api) |
| 60 | + |
| 61 | + @staticmethod |
| 62 | + def _convert_document(document: vectorize_client.models.Document) -> Document: |
| 63 | + metadata = {field: getattr(document, field) for field in _METADATA_FIELDS} |
| 64 | + return Document(id=document.id, page_content=document.text, metadata=metadata) |
| 65 | + |
| 66 | + @override |
| 67 | + def _get_relevant_documents( |
| 68 | + self, |
| 69 | + query: str, |
| 70 | + *, |
| 71 | + run_manager: CallbackManagerForRetrieverRun, |
| 72 | + organization: str | None = None, |
| 73 | + pipeline_id: str | None = None, |
| 74 | + num_results: int | None = None, |
| 75 | + rerank: bool | None = None, |
| 76 | + metadata_filters: list[dict[str, Any]] | None = None, |
| 77 | + ) -> list[Document]: |
| 78 | + request = RetrieveDocumentsRequest( |
| 79 | + question=query, |
| 80 | + num_results=num_results or self.num_results, |
| 81 | + rerank=rerank or self.rerank, |
| 82 | + metadata_filters=metadata_filters or self.metadata_filters, |
| 83 | + ) |
| 84 | + response = self._pipelines.retrieve_documents( |
| 85 | + organization or self.organization, pipeline_id or self.pipeline_id, request |
| 86 | + ) |
| 87 | + return [self._convert_document(doc) for doc in response.documents] |
| 88 | + |
| 89 | + @override |
| 90 | + def invoke( |
| 91 | + self, |
| 92 | + input: str, |
| 93 | + config: RunnableConfig | None = None, |
| 94 | + *, |
| 95 | + organization: str = "", |
| 96 | + pipeline_id: str = "", |
| 97 | + num_results: int = _NOT_SET, |
| 98 | + rerank: bool = _NOT_SET, |
| 99 | + metadata_filters: list[dict[str, Any]] = _NOT_SET, |
| 100 | + ) -> list[Document]: |
| 101 | + """Invoke the retriever to get relevant documents. |
| 102 | +
|
| 103 | + Main entry point for retriever invocations. |
| 104 | +
|
| 105 | + Args: |
| 106 | + input: The query string. |
| 107 | + config: Configuration for the retriever. Defaults to None. |
| 108 | + organization: The organization to retrieve documents from. |
| 109 | + If set, overrides the organization set at the initialization of the |
| 110 | + retriever. |
| 111 | + pipeline_id: The pipeline ID to retrieve documents from. |
| 112 | + If set, overrides the pipeline ID set at the initialization of the |
| 113 | + retriever. |
| 114 | + num_results: The number of results to retrieve. |
| 115 | + If set, overrides the number of results set at the initialization of |
| 116 | + the retriever. |
| 117 | + rerank: Whether to rerank the retrieved documents. |
| 118 | + If set, overrides the reranking set at the initialization of the |
| 119 | + retriever. |
| 120 | + metadata_filters: The metadata filters to apply when retrieving documents. |
| 121 | + If set, overrides the metadata filters set at the initialization of the |
| 122 | + retriever. |
| 123 | +
|
| 124 | + Returns: |
| 125 | + List of relevant documents. |
| 126 | +
|
| 127 | + Examples: |
| 128 | +
|
| 129 | + .. code-block:: python |
| 130 | +
|
| 131 | + retriever.invoke("query") |
| 132 | + """ |
| 133 | + kwargs = {} |
| 134 | + if organization: |
| 135 | + kwargs["organization"] = organization |
| 136 | + if pipeline_id: |
| 137 | + kwargs["pipeline_id"] = pipeline_id |
| 138 | + if num_results is not _NOT_SET: |
| 139 | + kwargs["num_results"] = num_results |
| 140 | + if rerank is not _NOT_SET: |
| 141 | + kwargs["rerank"] = rerank |
| 142 | + if metadata_filters is not _NOT_SET: |
| 143 | + kwargs["metadata_filters"] = metadata_filters |
| 144 | + |
| 145 | + return super().invoke(input, config, **kwargs) |
0 commit comments