-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
221 lines (172 loc) · 7.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import datetime
import json
import os
import sys
from typing import List, Dict, Any, Optional

import requests
from databricks.sql import connect
from databricks.sql.client import Connection
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
# Populate the process environment (python-dotenv) before reading settings
load_dotenv()

# Databricks credentials; each is None when the variable is not set
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
DATABRICKS_HTTP_PATH = os.environ.get("DATABRICKS_HTTP_PATH")

# MCP server instance that the resource/tool decorators in this file register with
mcp = FastMCP("Databricks API Explorer")
# Helper that opens a Databricks SQL connection from the module-level credentials
def get_databricks_connection() -> Connection:
    """Open and return a Databricks SQL connection.

    Returns:
        The connection produced by ``databricks.sql.connect`` for the
        configured host, HTTP path and access token.

    Raises:
        ValueError: If DATABRICKS_HOST, DATABRICKS_TOKEN or
            DATABRICKS_HTTP_PATH is unset or empty.
    """
    required_settings = (DATABRICKS_HOST, DATABRICKS_TOKEN, DATABRICKS_HTTP_PATH)
    if any(not setting for setting in required_settings):
        raise ValueError("Missing required Databricks connection details in .env file")

    connection_kwargs = {
        "server_hostname": DATABRICKS_HOST,
        "http_path": DATABRICKS_HTTP_PATH,
        "access_token": DATABRICKS_TOKEN,
    }
    return connect(**connection_kwargs)
# Helper function for Databricks REST API requests
def databricks_api_request(
    endpoint: str,
    method: str = "GET",
    data: Optional[Dict] = None,
    timeout: float = 30.0,
) -> Dict:
    """Make a request to the Databricks REST API (``/api/2.0/``) and return its JSON body.

    Args:
        endpoint: Path below ``/api/2.0/``, e.g. ``"jobs/list"``. It may already
            carry a query string.
        method: HTTP method, case-insensitive. Only ``GET`` and ``POST`` are supported.
        data: Optional payload. Sent as query-string parameters for ``GET`` and
            as the JSON body for ``POST``.
        timeout: Seconds to wait for the server before giving up, so a stalled
            workspace cannot hang the caller indefinitely.

    Returns:
        The decoded JSON response.

    Raises:
        ValueError: If the host/token are not configured or ``method`` is unsupported.
        requests.RequestException: On network errors, timeouts, or a non-2xx
            status (via ``raise_for_status``).
    """
    if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
        raise ValueError("Missing required Databricks API credentials in .env file")

    verb = method.upper()
    if verb not in ("GET", "POST"):
        raise ValueError(f"Unsupported HTTP method: {method}")

    headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json",
    }
    url = f"https://{DATABRICKS_HOST}/api/2.0/{endpoint}"

    if verb == "GET":
        # A GET has no body: forward `data` as query parameters instead of
        # silently discarding it (params=None leaves the URL untouched).
        response = requests.get(url, headers=headers, params=data, timeout=timeout)
    else:
        response = requests.post(url, headers=headers, json=data, timeout=timeout)

    response.raise_for_status()
    return response.json()
@mcp.resource("schema://tables")
def get_schema() -> str:
    """Provide the list of tables in the Databricks SQL warehouse as a resource.

    Returns:
        One line per table in the form
        ``Database: <catalog>, Schema: <schema>, Table: <name>``, joined with
        newlines. If anything fails (including opening the connection) the
        return value is an ``Error retrieving tables: ...`` message instead.
    """
    conn = None
    try:
        # Connect inside the try so a connection failure is reported through
        # the same error message as any other failure.
        conn = get_databricks_connection()
        cursor = conn.cursor()
        try:
            tables = cursor.tables().fetchall()
        finally:
            cursor.close()
        return "\n".join(
            f"Database: {table.TABLE_CAT}, Schema: {table.TABLE_SCHEM}, Table: {table.TABLE_NAME}"
            for table in tables
        )
    except Exception as e:
        return f"Error retrieving tables: {str(e)}"
    finally:
        # conn stays None when the connection attempt itself failed
        if conn is not None:
            conn.close()
@mcp.tool()
def run_sql_query(sql: str) -> str:
    """Execute a SQL statement on the Databricks SQL warehouse.

    Args:
        sql: The SQL text to execute. It is passed to the cursor unchanged.

    Returns:
        A markdown table (header row, separator row, one row per result row)
        when the statement produces rows; otherwise the message
        ``Query executed successfully. No results returned.``. If anything
        fails (including opening the connection) the return value is an
        ``Error executing query: ...`` message.
    """
    # Diagnostics go to stderr: when the server is run over the stdio
    # transport, stdout carries protocol messages and must stay clean.
    print(sql, file=sys.stderr)
    conn = None
    try:
        conn = get_databricks_connection()
        print("connected", file=sys.stderr)
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            # description is empty/None for statements that produce no result set
            if not cursor.description:
                return "Query executed successfully. No results returned."
            columns = [col[0] for col in cursor.description]
            rows = cursor.fetchall()
        finally:
            cursor.close()

        if not rows:
            return "Query executed successfully. No results returned."

        # Build the markdown table as a list of lines and join once
        lines = [
            "| " + " | ".join(columns) + " |",
            "| " + " | ".join("---" for _ in columns) + " |",
        ]
        lines.extend(
            "| " + " | ".join(str(cell) for cell in row) + " |" for row in rows
        )
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error executing query: {str(e)}"
    finally:
        # conn stays None when the connection attempt itself failed
        if conn is not None:
            conn.close()
@mcp.tool()
def list_jobs() -> str:
    """List Databricks jobs as a markdown table.

    Returns:
        A markdown table with the columns Job ID, Job Name and Created By,
        ``No jobs found.`` when the response contains no jobs, or an
        ``Error listing jobs: ...`` message if the request fails.
    """
    try:
        response = databricks_api_request("jobs/list")
        jobs = response.get("jobs") or []
        if not jobs:
            return "No jobs found."

        lines = [
            "| Job ID | Job Name | Created By |",
            "| ------ | -------- | ---------- |",
        ]
        for job in jobs:
            job_id = job.get("job_id", "N/A")
            job_name = job.get("settings", {}).get("name", "N/A")
            # The creator is read from `creator_user_name`, the same key
            # get_job_details uses; `created_by` is kept as a fallback.
            created_by = job.get("creator_user_name", job.get("created_by", "N/A"))
            lines.append(f"| {job_id} | {job_name} | {created_by} |")
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error listing jobs: {str(e)}"
def _format_epoch_ms(epoch_ms) -> str:
    """Format a millisecond epoch timestamp as local 'YYYY-MM-DD HH:MM:SS', or 'N/A' if falsy."""
    if not epoch_ms:
        return "N/A"
    return datetime.datetime.fromtimestamp(epoch_ms / 1000).strftime('%Y-%m-%d %H:%M:%S')


@mcp.tool()
def get_job_status(job_id: int) -> str:
    """Get the run history and status of a specific Databricks job.

    Args:
        job_id: Identifier of the job whose runs should be listed.

    Returns:
        A markdown table with the columns Run ID, State, Start Time, End Time
        and Duration, ``No runs found for job ID <id>.`` when there are no
        runs, or an ``Error getting job status: ...`` message on failure.
    """
    try:
        # Put the filter in the query string (as get_job_details does) so the
        # job_id actually reaches the API on this GET request.
        response = databricks_api_request(f"jobs/runs/list?job_id={job_id}")
        runs = response.get("runs") or []
        if not runs:
            return f"No runs found for job ID {job_id}."

        lines = [
            "| Run ID | State | Start Time | End Time | Duration |",
            "| ------ | ----- | ---------- | -------- | -------- |",
        ]
        for run in runs:
            run_id = run.get("run_id", "N/A")
            run_state = run.get("state", {})
            # Fall back to life_cycle_state when no result_state is present,
            # so such a run shows its current state instead of "N/A".
            state = run_state.get("result_state") or run_state.get("life_cycle_state", "N/A")

            start_time = run.get("start_time", 0)
            end_time = run.get("end_time", 0)
            # Timestamps are treated as milliseconds; duration needs both ends
            if start_time and end_time:
                duration = f"{(end_time - start_time) / 1000:.2f}s"
            else:
                duration = "N/A"

            lines.append(
                f"| {run_id} | {state} | {_format_epoch_ms(start_time)} | "
                f"{_format_epoch_ms(end_time)} | {duration} |"
            )
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error getting job status: {str(e)}"
@mcp.tool()
def get_job_details(job_id: int) -> str:
    """Describe one Databricks job as markdown.

    Args:
        job_id: Identifier of the job to look up.

    Returns:
        A markdown summary (name, id, creation time, creator) followed by a
        task table when the job defines tasks, or an
        ``Error getting job details: ...`` message on failure.
    """
    try:
        job = databricks_api_request(f"jobs/get?job_id={job_id}", method="GET")
        settings = job.get("settings", {})
        name = settings.get("name", "N/A")

        import datetime

        # created_time is treated as a millisecond epoch; 0/missing renders as N/A
        created_ms = job.get("created_time", 0)
        if created_ms:
            created_label = datetime.datetime.fromtimestamp(created_ms / 1000).strftime('%Y-%m-%d %H:%M:%S')
        else:
            created_label = "N/A"

        task_list = job.get("settings", {}).get("tasks", [])

        parts = [
            f"## Job Details: {name}\n\n",
            f"- **Job ID:** {job_id}\n",
            f"- **Created:** {created_label}\n",
            f"- **Creator:** {job.get('creator_user_name', 'N/A')}\n\n",
        ]

        if task_list:
            parts.append("### Tasks:\n\n")
            parts.append("| Task Key | Task Type | Description |\n")
            parts.append("| -------- | --------- | ----------- |\n")
            for entry in task_list:
                key = entry.get("task_key", "N/A")
                # The task type is the first key ending in "_task", if any
                type_keys = [field for field in entry.keys() if field.endswith("_task")]
                kind = type_keys[0] if type_keys else "N/A"
                summary = entry.get("description", "N/A")
                parts.append(f"| {key} | {kind} | {summary} |\n")

        return "".join(parts)
    except Exception as e:
        return f"Error getting job details: {str(e)}"
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
def _main() -> None:
    """Start the MCP server; invoked only when this file is run as a script."""
    #run_sql_query("SELECT * FROM dev.dev_test.income_survey_dataset LIMIT 10;")
    mcp.run()


if __name__ == "__main__":
    _main()