Skip to content

Commit 540be49

Browse files
new artifact output scheme
1 parent 9ecf30c commit 540be49

File tree

1 file changed

+74
-36
lines changed

1 file changed

+74
-36
lines changed

src/mcp_snowflake_server/server.py

+74-36
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
import importlib.metadata
2+
import json
13
import logging
2-
from functools import wraps
3-
from typing import Any, Callable
44
import os
5-
import uuid
6-
import yaml
7-
import importlib.metadata
85
import time
6+
import uuid
7+
from functools import wraps
8+
from typing import Any, Callable
99

10-
from mcp.server.models import InitializationOptions
10+
import mcp.server.stdio
1111
import mcp.types as types
12+
import yaml
1213
from mcp.server import NotificationOptions, Server
13-
import mcp.server.stdio
14+
from mcp.server.models import InitializationOptions
1415
from pydantic import AnyUrl, BaseModel
1516
from snowflake.snowpark import Session
1617

@@ -113,8 +114,22 @@ async def handle_list_tables(arguments, db, *_):
113114
FROM {db.connection_config['database']}.information_schema.tables
114115
WHERE table_schema = '{db.connection_config['schema'].upper()}'
115116
"""
116-
results, data_id = db.execute_query(query)
117-
return [types.TextContent(type="text", text=data_to_yaml(results), artifacts=[{"type": "dataframe", "data": results}])]
117+
data, data_id = db.execute_query(query)
118+
119+
output = {
120+
"type": "data",
121+
"data_id": data_id,
122+
"data": data,
123+
}
124+
yaml_output = data_to_yaml(output)
125+
json_output = json.dumps(output)
126+
return [
127+
types.TextContent(type="text", text=yaml_output),
128+
types.EmbeddedResource(
129+
type="resource",
130+
resource=types.TextResourceContents(uri=f"data://{data_id}", text=json_output, mimeType="application/json"),
131+
),
132+
]
118133

119134

120135
async def handle_describe_table(arguments, db, *_):
@@ -131,30 +146,41 @@ async def handle_describe_table(arguments, db, *_):
131146
FROM {database_name}.information_schema.columns
132147
WHERE table_schema = '{schema_name}' AND table_name = '{table_name}'
133148
"""
134-
results, data_id = db.execute_query(query)
149+
data, data_id = db.execute_query(query)
150+
151+
output = {
152+
"type": "data",
153+
"data_id": data_id,
154+
"data": data,
155+
}
156+
yaml_output = data_to_yaml(output)
157+
json_output = json.dumps(output)
135158
return [
136-
types.TextContent(
137-
type="text", text=data_to_yaml(results), artifacts=[{"type": "dataframe", "data": results, "data_id": data_id}]
138-
)
159+
types.TextContent(type="text", text=yaml_output),
160+
types.EmbeddedResource(
161+
type="resource",
162+
resource=types.TextResourceContents(uri=f"data://{data_id}", text=json_output, mimeType="application/json"),
163+
),
139164
]
140165

141166

142167
async def handle_read_query(arguments, db, write_detector, *_):
143-
MAX_RESULTS = 50
144168
if write_detector.analyze_query(arguments["query"])["contains_write"]:
145169
raise ValueError("Calls to read_query should not contain write operations")
146-
147-
results, data_id = db.execute_query(arguments["query"])
148-
truncate = len(results) > MAX_RESULTS
149-
results_text = data_to_yaml(results[:MAX_RESULTS])
150-
if truncate:
151-
results_text += f"\nResults of query have been truncated. There are {len(results) - MAX_RESULTS} more rows."
152-
results_text += f"\ndata_id = {data_id}"
153-
170+
data, data_id = db.execute_query(arguments["query"])
171+
output = {
172+
"type": "data",
173+
"data_id": data_id,
174+
"data": data,
175+
}
176+
yaml_output = data_to_yaml(output)
177+
json_output = json.dumps(output)
154178
return [
155-
types.TextContent(
156-
type="text", text=results_text, artifacts=[{"type": "dataframe", "data": results, "data_id": data_id}]
157-
)
179+
types.TextContent(type="text", text=yaml_output),
180+
types.EmbeddedResource(
181+
type="resource",
182+
resource=types.TextResourceContents(uri=f"data://{data_id}", text=json_output, mimeType="application/json"),
183+
),
158184
]
159185

160186

@@ -208,9 +234,11 @@ async def prefetch_tables(db: SnowflakeDB, credentials: dict) -> str:
208234
tables_brief[row["TABLE_NAME"]] = {**row, "COLUMNS": {}}
209235

210236
for row in column_results:
211-
tables_brief[row["TABLE_NAME"]]["COLUMNS"][row["COLUMN_NAME"]] = row
237+
row_without_table_name = row.copy()
238+
del row_without_table_name["TABLE_NAME"]
239+
tables_brief[row["TABLE_NAME"]]["COLUMNS"][row["COLUMN_NAME"]] = row_without_table_name
212240

213-
return data_to_yaml(tables_brief)
241+
return tables_brief
214242

215243
except Exception as e:
216244
logger.error(f"Error prefetching table descriptions: {e}")
@@ -238,7 +266,8 @@ async def main(
238266
server = Server("snowflake-manager")
239267
write_detector = SQLWriteDetector()
240268

241-
tables_brief = await prefetch_tables(db, credentials) if prefetch else ""
269+
tables_info = await prefetch_tables(db, credentials)
270+
tables_brief = data_to_yaml(tables_info) if prefetch else ""
242271

243272
all_tools = [
244273
Tool(
@@ -319,27 +348,36 @@ async def main(
319348
# Register handlers
320349
@server.list_resources()
321350
async def handle_list_resources() -> list[types.Resource]:
322-
return [
351+
resources = [
323352
types.Resource(
324353
uri=AnyUrl("memo://insights"),
325354
name="Data Insights Memo",
326355
description="A living document of discovered data insights",
327356
mimeType="text/plain",
328-
),
357+
)
358+
]
359+
table_brief_resources = [
329360
types.Resource(
330-
uri=AnyUrl("context://tables"),
331-
name="Tables",
332-
description="Description of tables and columns in the database",
361+
uri=AnyUrl(f"context://table/{table_name}"),
362+
name=f"{table_name} table",
363+
description=f"Description of the {table_name} table",
333364
mimeType="text/plain",
334-
),
365+
)
366+
for table_name in tables_info.keys()
335367
]
368+
resources += table_brief_resources
369+
return resources
336370

337371
@server.read_resource()
338372
async def handle_read_resource(uri: AnyUrl) -> str:
339373
if str(uri) == "memo://insights":
340374
return db.get_memo()
341-
elif str(uri) == "context://tables":
342-
return tables_brief
375+
elif str(uri).startswith("context://table"):
376+
table_name = str(uri).split("/")[-1]
377+
if table_name in tables_info:
378+
return data_to_yaml(tables_info[table_name])
379+
else:
380+
raise ValueError(f"Unknown table: {table_name}")
343381
else:
344382
raise ValueError(f"Unknown resource: {uri}")
345383

0 commit comments

Comments
 (0)