# mcp_server.py (forked from ClickHouse/mcp-clickhouse)

import atexit
import concurrent.futures
import logging
from typing import Optional, Sequence

import clickhouse_connect
from clickhouse_connect.driver.binding import format_query_value, quote_identifier
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP

from mcp_clickhouse.mcp_env import config

MCP_SERVER_NAME = "mcp-clickhouse"

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(MCP_SERVER_NAME)

# Worker pool for SELECT queries, so a slow query can be timed out without
# blocking the server; drained cleanly at interpreter exit.
QUERY_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10)
atexit.register(lambda: QUERY_EXECUTOR.shutdown(wait=True))
SELECT_QUERY_TIMEOUT_SECS = 30

load_dotenv()

deps = [
    "clickhouse-connect",
    "python-dotenv",
    "uvicorn",
    "pip-system-certs",
]

mcp = FastMCP(MCP_SERVER_NAME, dependencies=deps)


@mcp.tool()
def list_databases():
    """List available ClickHouse databases"""
    logger.info("Listing all databases")
    client = create_clickhouse_client()
    result = client.command("SHOW DATABASES")
    logger.info(f"Found {len(result) if isinstance(result, list) else 1} databases")
    return result


@mcp.tool()
def list_tables(database: str, like: Optional[str] = None):
    """List available ClickHouse tables in a database"""
    logger.info(f"Listing tables in database '{database}'")
    client = create_clickhouse_client()
    query = f"SHOW TABLES FROM {quote_identifier(database)}"
    if like:
        query += f" LIKE {format_query_value(like)}"
    result = client.command(query)

    # Get all table comments in one query
    table_comments_query = (
        f"SELECT name, comment FROM system.tables "
        f"WHERE database = {format_query_value(database)}"
    )
    table_comments_result = client.query(table_comments_query)
    table_comments = {row[0]: row[1] for row in table_comments_result.result_rows}

    # Get all column comments in one query
    column_comments_query = (
        f"SELECT table, name, comment FROM system.columns "
        f"WHERE database = {format_query_value(database)}"
    )
    column_comments_result = client.query(column_comments_query)
    column_comments = {}
    for row in column_comments_result.result_rows:
        table, col_name, comment = row
        if table not in column_comments:
            column_comments[table] = {}
        column_comments[table][col_name] = comment

    def get_table_info(table):
        logger.info(f"Getting schema info for table {database}.{table}")
        schema_query = f"DESCRIBE TABLE {quote_identifier(database)}.{quote_identifier(table)}"
        schema_result = client.query(schema_query)

        columns = []
        column_names = schema_result.column_names
        for row in schema_result.result_rows:
            column_dict = {}
            for i, col_name in enumerate(column_names):
                column_dict[col_name] = row[i]
            # Add comment from our pre-fetched comments
            if table in column_comments and column_dict['name'] in column_comments[table]:
                column_dict['comment'] = column_comments[table][column_dict['name']]
            else:
                column_dict['comment'] = None
            columns.append(column_dict)

        # Quote both identifiers so names that need escaping are handled consistently
        create_table_query = (
            f"SHOW CREATE TABLE {quote_identifier(database)}.{quote_identifier(table)}"
        )
        create_table_result = client.command(create_table_query)

        return {
            "database": database,
            "name": table,
            "comment": table_comments.get(table),
            "columns": columns,
            "create_table_query": create_table_result,
        }

    tables = []
    if isinstance(result, str):
        # command() returned a single string, which may still hold several
        # whitespace-separated table names
        for table in (t.strip() for t in result.split()):
            if table:
                tables.append(get_table_info(table))
    elif isinstance(result, Sequence):
        # Multiple table results
        for table in result:
            tables.append(get_table_info(table))

    logger.info(f"Found {len(tables)} tables")
    return tables
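
# Illustrative shape of a single list_tables() entry (field values below are
# made-up examples, not real output):
# {
#     "database": "default",
#     "name": "events",
#     "comment": "click-stream events",
#     "columns": [
#         {"name": "id", "type": "UInt64", ..., "comment": None},
#     ],
#     "create_table_query": "CREATE TABLE default.events (...) ENGINE = MergeTree ...",
# }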


def execute_query(query: str):
    """Run a query with readonly=1 and return rows as column-name dicts.

    Errors are returned as a string rather than raised, so the MCP tool
    reports them to the caller instead of crashing the worker thread.
    """
    client = create_clickhouse_client()
    try:
        res = client.query(query, settings={"readonly": 1})
        column_names = res.column_names
        rows = []
        for row in res.result_rows:
            row_dict = {}
            for i, col_name in enumerate(column_names):
                row_dict[col_name] = row[i]
            rows.append(row_dict)
        logger.info(f"Query returned {len(rows)} rows")
        return rows
    except Exception as err:
        logger.error(f"Error executing query: {err}")
        return f"error running query: {err}"


@mcp.tool()
def run_select_query(query: str):
    """Run a SELECT query in a ClickHouse database"""
    logger.info(f"Executing SELECT query: {query}")
    future = QUERY_EXECUTOR.submit(execute_query, query)
    try:
        result = future.result(timeout=SELECT_QUERY_TIMEOUT_SECS)
        return result
    except concurrent.futures.TimeoutError:
        logger.warning(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds: {query}")
        # Best effort: cancel() only prevents a queued query from starting;
        # one already running on the ClickHouse server is not interrupted.
        future.cancel()
        return f"Queries taking longer than {SELECT_QUERY_TIMEOUT_SECS} seconds are currently not supported."


def create_clickhouse_client():
    client_config = config.get_client_config()
    logger.info(
        f"Creating ClickHouse client connection to {client_config['host']}:{client_config['port']} "
        f"as {client_config['username']} "
        f"(secure={client_config['secure']}, verify={client_config['verify']}, "
        f"connect_timeout={client_config['connect_timeout']}s, "
        f"send_receive_timeout={client_config['send_receive_timeout']}s)"
    )
    try:
        client = clickhouse_connect.get_client(**client_config)
        # Test the connection
        version = client.server_version
        logger.info(f"Successfully connected to ClickHouse server version {version}")
        return client
    except Exception as e:
        logger.error(f"Failed to connect to ClickHouse: {str(e)}")
        raise
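

# Minimal entry-point sketch, assuming FastMCP's default stdio transport; the
# upstream package may expose its own console script or main() instead, so
# treat this as illustrative rather than the project's official launcher.
if __name__ == "__main__":
    mcp.run()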