Skip to content

Commit 1f735e8

Browse files
Merge pull request #2402 from tisnik/reranker-as-customized-part
OLS-1615: Reranker as customized part
2 parents 1130ea4 + 90421e4 commit 1f735e8

File tree

4 files changed

+45
-1
lines changed

4 files changed

+45
-1
lines changed

ols/customize/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
project = os.getenv("PROJECT", "ols")
77
prompts = importlib.import_module(f"ols.customize.{project}.prompts")
88
keywords = importlib.import_module(f"ols.customize.{project}.keywords")
9+
reranker = importlib.import_module(f"ols.customize.{project}.reranker")

ols/customize/ols/reranker.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Reranker for post-processing the Vector DB search results."""
2+
3+
import logging
4+
5+
from llama_index.core.schema import NodeWithScore
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def rerank(retrieved_nodes: list[NodeWithScore]) -> list[NodeWithScore]:
11+
"""Rerank Vector DB search results."""
12+
message = f"reranker.rerank() is called with {len(retrieved_nodes)} result(s)."
13+
logger.debug(message)
14+
return retrieved_nodes

ols/src/query_helpers/docs_summarizer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ols.app.metrics import TokenMetricUpdater
1414
from ols.app.models.models import RagChunk, SummarizerResponse, TokenCounter, ToolCall
1515
from ols.constants import MAX_ITERATIONS, RAG_CONTENT_LIMIT, GenericLLMParameters
16+
from ols.customize import reranker
1617
from ols.src.prompts.prompt_generator import GeneratePrompt
1718
from ols.src.query_helpers.query_helper import QueryHelper
1819
from ols.src.tools.oc_cli import token_works_for_oc
@@ -97,8 +98,10 @@ def _prepare_prompt(
9798
# Retrieve RAG content
9899
if vector_index:
99100
retriever = vector_index.as_retriever(similarity_top_k=RAG_CONTENT_LIMIT)
101+
retrieved_nodes = retriever.retrieve(query)
102+
retrieved_nodes = reranker.rerank(retrieved_nodes)
100103
rag_chunks, available_tokens = token_handler.truncate_rag_context(
101-
retriever.retrieve(query), available_tokens
104+
retrieved_nodes, available_tokens
102105
)
103106
else:
104107
logger.warning("Proceeding without RAG content. Check start up messages.")

tests/unit/query_helpers/test_docs_summarizer.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for DocsSummarizer class."""
22

3+
import logging
34
from unittest.mock import ANY, patch
45

56
import pytest
@@ -13,11 +14,13 @@
1314
config.ols_config.authentication_config.module = "k8s"
1415

1516

17+
from ols.app.models.config import LoggingConfig # noqa:E402
1618
from ols.src.query_helpers.docs_summarizer import ( # noqa:E402
1719
DocsSummarizer,
1820
QueryHelper,
1921
)
2022
from ols.utils import suid # noqa:E402
23+
from ols.utils.logging_configurator import configure_logging # noqa:E402
2124
from tests import constants # noqa:E402
2225
from tests.mock_classes.mock_langchain_interface import ( # noqa:E402
2326
mock_langchain_interface,
@@ -145,6 +148,29 @@ def test_summarize_no_reference_content():
145148
assert not summary.history_truncated
146149

147150

151+
def test_summarize_reranker(caplog):
152+
"""Basic test to make sure the reranker is called as expected."""
153+
logging_config = LoggingConfig(app_log_level="debug")
154+
155+
configure_logging(logging_config)
156+
logger = logging.getLogger("ols")
157+
logger.handlers = [caplog.handler] # add caplog handler to logger
158+
159+
with (
160+
patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4),
161+
patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3),
162+
):
163+
summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
164+
question = "What's the ultimate question with answer 42?"
165+
rag_index = MockLlamaIndex()
166+
# no history is passed into create_response() method
167+
summary = summarizer.create_response(question, rag_index)
168+
check_summary_result(summary, question)
169+
170+
# Check captured log text to see if reranker was called.
171+
assert "reranker.rerank() is called with 1 result(s)." in caplog.text
172+
173+
148174
@pytest.mark.asyncio
149175
async def test_response_generator():
150176
"""Test response generator method."""

0 commit comments

Comments
 (0)