|
1 | 1 | """Unit tests for DocsSummarizer class."""
|
2 | 2 |
|
| 3 | +import logging |
3 | 4 | from unittest.mock import ANY, patch
|
4 | 5 |
|
5 | 6 | import pytest
|
|
13 | 14 | config.ols_config.authentication_config.module = "k8s"
|
14 | 15 |
|
15 | 16 |
|
| 17 | +from ols.app.models.config import LoggingConfig # noqa:E402 |
16 | 18 | from ols.src.query_helpers.docs_summarizer import ( # noqa:E402
|
17 | 19 | DocsSummarizer,
|
18 | 20 | QueryHelper,
|
19 | 21 | )
|
20 | 22 | from ols.utils import suid # noqa:E402
|
| 23 | +from ols.utils.logging_configurator import configure_logging # noqa:E402 |
21 | 24 | from tests import constants # noqa:E402
|
22 | 25 | from tests.mock_classes.mock_langchain_interface import ( # noqa:E402
|
23 | 26 | mock_langchain_interface,
|
@@ -145,6 +148,29 @@ def test_summarize_no_reference_content():
|
145 | 148 | assert not summary.history_truncated
|
146 | 149 |
|
147 | 150 |
|
| 151 | +def test_summarize_reranker(caplog): |
| 152 | + """Basic test to make sure the reranker is called as expected.""" |
| 153 | + logging_config = LoggingConfig(app_log_level="debug") |
| 154 | + |
| 155 | + configure_logging(logging_config) |
| 156 | + logger = logging.getLogger("ols") |
| 157 | + logger.handlers = [caplog.handler] # add caplog handler to logger |
| 158 | + |
| 159 | + with ( |
| 160 | + patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4), |
| 161 | + patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3), |
| 162 | + ): |
| 163 | + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) |
| 164 | + question = "What's the ultimate question with answer 42?" |
| 165 | + rag_index = MockLlamaIndex() |
| 166 | + # no history is passed into create_response() method |
| 167 | + summary = summarizer.create_response(question, rag_index) |
| 168 | + check_summary_result(summary, question) |
| 169 | + |
| 170 | + # Check captured log text to see if reranker was called. |
| 171 | + assert "reranker.rerank() is called with 1 result(s)." in caplog.text |
| 172 | + |
| 173 | + |
148 | 174 | @pytest.mark.asyncio
|
149 | 175 | async def test_response_generator():
|
150 | 176 | """Test response generator method."""
|
|
0 commit comments