diff --git a/mcp_clickhouse/mcp_server.py b/mcp_clickhouse/mcp_server.py index 95875fa..8962cbf 100644 --- a/mcp_clickhouse/mcp_server.py +++ b/mcp_clickhouse/mcp_server.py @@ -1,5 +1,7 @@ import logging from typing import Sequence +import concurrent.futures +import atexit import clickhouse_connect from clickhouse_connect.driver.binding import quote_identifier, format_query_value @@ -16,6 +18,10 @@ ) logger = logging.getLogger(MCP_SERVER_NAME) +QUERY_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10) +atexit.register(lambda: QUERY_EXECUTOR.shutdown(wait=True)) +SELECT_QUERY_TIMEOUT_SECS = 30 + load_dotenv() deps = [ @@ -105,9 +111,7 @@ def get_table_info(table): return tables -@mcp.tool() -def run_select_query(query: str): - logger.info(f"Executing SELECT query: {query}") +def execute_query(query: str): client = create_clickhouse_client() try: res = client.query(query, settings={"readonly": 1}) @@ -125,6 +129,19 @@ def run_select_query(query: str): return f"error running query: {err}" +@mcp.tool() +def run_select_query(query: str): + logger.info(f"Executing SELECT query: {query}") + future = QUERY_EXECUTOR.submit(execute_query, query) + try: + result = future.result(timeout=SELECT_QUERY_TIMEOUT_SECS) + return result + except concurrent.futures.TimeoutError: + logger.warning(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds: {query}") + future.cancel() + return f"Queries taking longer than {SELECT_QUERY_TIMEOUT_SECS} seconds are currently not supported." + + def create_clickhouse_client(): client_config = config.get_client_config() logger.info(