diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..037f15c
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,50 @@
+# Git
+.git
+.gitignore
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Environment variables
+.env
+.env.*
+
+# Docker
+.docker/
+
+# Logs
+*.log
+
+# Local development
+.DS_Store
+Thumbs.db
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..ace5e3a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,33 @@
+# Use an official Python runtime as a parent image
+FROM python:3.10-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Install curl for DBFS API calls
+RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first to leverage Docker cache
+COPY requirements.txt .
+
+# Install any needed packages specified in requirements.txt
+# Use --no-cache-dir to reduce image size
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Optional: if build dependencies (e.g. build-essential) are ever added above, purge them here to reduce image size
+# RUN apt-get purge -y --auto-remove build-essential
+
+# Copy the rest of the application code into the container at /app
+COPY . .
+
+# Expose port 8000 in case the server is run with an HTTP/SSE transport
+# (the default stdio transport used by mcp.run() does not listen on a network port)
+EXPOSE 8000
+
+# Define environment variables (these will be overridden by docker run -e flags)
+ENV DATABRICKS_HOST=""
+ENV DATABRICKS_TOKEN=""
+ENV DATABRICKS_HTTP_PATH=""
+
+# Run main.py when the container launches
+CMD ["python", "main.py"]
\ No newline at end of file
diff --git a/databricks_mcp_tools.md b/databricks_mcp_tools.md
new file mode 100644
index 0000000..1ecf57f
--- /dev/null
+++ b/databricks_mcp_tools.md
@@ -0,0 +1,157 @@
+# Databricks MCP Server Tools Guide
+
+This guide outlines the available tools and resources provided by the Databricks MCP server.
+
+## Server Configuration
+
+The server is configured in the MCP settings file with:
+```json
+{
+ "mcpServers": {
+ "databricks-server": {
+ "command": "python",
+ "args": ["main.py"],
+ "disabled": false,
+ "alwaysAllow": ["list_jobs"],
+ "env": {},
+ "cwd": "/Users/maheidem/Documents/dev/mcp-databricks-server"
+ }
+ }
+}
+```
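+
+If you run the server from the Docker image added in this PR instead of a local Python checkout, a sketch of an equivalent entry might look like the following (image name, flags, and credential values are illustrative placeholders):
+
+```json
+{
+  "mcpServers": {
+    "databricks-server": {
+      "command": "docker",
+      "args": ["run", "-i", "--rm",
+               "-e", "DATABRICKS_HOST", "-e", "DATABRICKS_TOKEN", "-e", "DATABRICKS_HTTP_PATH",
+               "mcp-databricks-server"],
+      "disabled": false,
+      "alwaysAllow": ["list_jobs"],
+      "env": {
+        "DATABRICKS_HOST": "your-workspace.cloud.databricks.com",
+        "DATABRICKS_TOKEN": "dapiXXXXXXXX",
+        "DATABRICKS_HTTP_PATH": "/sql/1.0/warehouses/your-warehouse-id"
+      }
+    }
+  }
+}
+```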
+
+## Available Tools
+
+### 1. run_sql_query
+Execute SQL queries on Databricks SQL warehouse.
+
+**Parameters:**
+- sql (string): SQL query to execute
+
+**Example:**
+```
+Server: databricks-server
+Tool: run_sql_query
+Arguments:
+{
+  "sql": "SELECT * FROM my_database.my_table LIMIT 10"
+}
+```
+
+**Returns:** Results in markdown table format
+
+### 2. list_jobs
+List all Databricks jobs. This tool is listed under `alwaysAllow` in the server configuration.
+
+**Parameters:** None
+
+**Example:**
+```
+Server: databricks-server
+Tool: list_jobs
+Arguments: {}
+```
+
+**Returns:** Job list in markdown table format with columns:
+- Job ID
+- Job Name
+- Created By
+
+### 3. get_job_status
+Get the status of a specific Databricks job.
+
+**Parameters:**
+- job_id (integer): ID of the job to get status for
+
+**Example:**
+```
+Server: databricks-server
+Tool: get_job_status
+Arguments:
+{
+  "job_id": 123
+}
+```
+
+**Returns:** Job runs in markdown table format with columns:
+- Run ID
+- State
+- Start Time
+- End Time
+- Duration
+
+### 4. get_job_details
+Get detailed information about a specific Databricks job.
+
+**Parameters:**
+- job_id (integer): ID of the job to get details for
+
+**Example:**
+```
+Server: databricks-server
+Tool: get_job_details
+Arguments:
+{
+  "job_id": 123
+}
+```
+
+**Returns:** Detailed job information in markdown format including:
+- Job Name
+- Job ID
+- Created Time
+- Creator
+- Tasks (if any)
+
+## Available Resources
+
+### schema://tables
+Lists available tables in the Databricks SQL warehouse.
+
+**Example:**
+```
+Server: databricks-server
+Resource URI: schema://tables
+```
+
+**Returns:** List of tables with their database and schema information.
+
+## Usage Flow
+
+```mermaid
+graph TD
+ A[Start Server] --> B[MCP System Auto-connects]
+ B --> C[Tools Available]
+ C --> D[Use Tools]
+
+ subgraph "Tool Usage Flow"
+ D --> E[List Jobs]
+ E --> F[Get Job Details]
+ F --> G[Get Job Status]
+ G --> H[Run SQL Query]
+ end
+```
+
+## Requirements
+
+The server requires the following environment variables to be set:
+- DATABRICKS_TOKEN
+- DATABRICKS_HOST
+- DATABRICKS_HTTP_PATH
+
+These are already configured in the .env file.
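+
+A minimal `.env` sketch (placeholder values; substitute your own workspace host, personal access token, and SQL warehouse HTTP path):
+
+```
+DATABRICKS_HOST=your-workspace.cloud.databricks.com
+DATABRICKS_TOKEN=dapiXXXXXXXXXXXXXXXX
+DATABRICKS_HTTP_PATH=/sql/1.0/warehouses/your-warehouse-id
+```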
\ No newline at end of file
diff --git a/main.py b/main.py
index b82d0fb..c673d3c 100644
--- a/main.py
+++ b/main.py
@@ -1,11 +1,13 @@
import os
-from typing import List, Dict, Any, Optional
+from typing import Dict, List
from dotenv import load_dotenv
from databricks.sql import connect
from databricks.sql.client import Connection
from mcp.server.fastmcp import FastMCP
import requests
import json
+import urllib.parse
+import subprocess
# Load environment variables
load_dotenv()
@@ -45,7 +47,7 @@ def databricks_api_request(endpoint: str, method: str = "GET", data: Dict = None
url = f"https://{DATABRICKS_HOST}/api/2.0/{endpoint}"
if method.upper() == "GET":
- response = requests.get(url, headers=headers)
+ response = requests.get(url, headers=headers, params=data)
elif method.upper() == "POST":
response = requests.post(url, headers=headers, json=data)
else:
@@ -76,9 +78,7 @@ def get_schema() -> str:
@mcp.tool()
def run_sql_query(sql: str) -> str:
"""Execute SQL queries on Databricks SQL warehouse"""
- print(sql)
conn = get_databricks_connection()
- print("connected")
try:
cursor = conn.cursor()
@@ -135,6 +135,85 @@ def list_jobs() -> str:
except Exception as e:
return f"Error listing jobs: {str(e)}"
+@mcp.tool()
+def create_job(
+ name: str,
+ cluster_id: str = None,
+ new_cluster: Dict = None,
+ notebook_path: str = None,
+ spark_jar_task: Dict = None,
+ spark_python_task: Dict = None,
+ schedule: Dict = None,
+ timeout_seconds: int = None,
+ max_concurrent_runs: int = None,
+ max_retries: int = None
+) -> str:
+ """Create a new Databricks job
+
+ Args:
+ name: The name of the job
+ cluster_id: ID of existing cluster to run the job on (specify either cluster_id or new_cluster)
+ new_cluster: Specification for a new cluster (specify either cluster_id or new_cluster)
+ notebook_path: Path to the notebook to run (specify a task type)
+ spark_jar_task: Specification for a JAR task (specify a task type)
+ spark_python_task: Specification for a Python task (specify a task type)
+ schedule: Job schedule configuration
+ timeout_seconds: Timeout in seconds for each run
+ max_concurrent_runs: Maximum number of concurrent runs
+ max_retries: Maximum number of retries per run
+ """
+ try:
+ # Prepare the job settings
+ settings = {
+ "name": name
+ }
+
+ # Set either existing cluster or new cluster
+ if cluster_id:
+ settings["existing_cluster_id"] = cluster_id
+ elif new_cluster:
+ settings["new_cluster"] = new_cluster
+ else:
+ return "Error: Either cluster_id or new_cluster must be specified."
+
+ # Set task definition - only one type of task can be set
+ task_set = False
+ if notebook_path:
+ settings["notebook_task"] = {"notebook_path": notebook_path}
+ task_set = True
+ if spark_jar_task and not task_set:
+ settings["spark_jar_task"] = spark_jar_task
+ task_set = True
+ if spark_python_task and not task_set:
+ settings["spark_python_task"] = spark_python_task
+ task_set = True
+
+ if not task_set:
+ return "Error: You must specify a task type (notebook_path, spark_jar_task, or spark_python_task)."
+
+ # Add optional parameters
+ if schedule:
+ settings["schedule"] = schedule
+ if timeout_seconds:
+ settings["timeout_seconds"] = timeout_seconds
+ if max_concurrent_runs:
+ settings["max_concurrent_runs"] = max_concurrent_runs
+ if max_retries:
+ settings["max_retries"] = max_retries
+
+        # Make the API request to create the job (Jobs API 2.0 expects the settings as the request body)
+        response = databricks_api_request("jobs/create", method="POST", data=settings)
+
+ # The response contains the job_id of the newly created job
+ job_id = response.get("job_id")
+
+ if job_id:
+ return f"Job created successfully with job ID: {job_id}\n\nYou can view job details with: get_job_details({job_id})"
+ else:
+ return "Job creation failed. No job ID returned."
+ except Exception as e:
+ return f"Error creating job: {str(e)}"
+
@mcp.tool()
def get_job_status(job_id: int) -> str:
"""Get the status of a specific Databricks job"""
@@ -212,10 +291,653 @@ def get_job_details(job_id: int) -> str:
except Exception as e:
return f"Error getting job details: {str(e)}"
+@mcp.tool()
+def list_job_runs(
+ job_id: int = None,
+ active_only: bool = None,
+ completed_only: bool = None,
+ limit: int = None,
+ run_type: str = None,
+ expand_tasks: bool = None,
+ start_time_from: int = None,
+ start_time_to: int = None,
+ page_token: str = None
+) -> str:
+ """List job runs with optional filtering parameters
+
+ Args:
+ job_id: Optional job ID to filter runs for a specific job
+ active_only: If true, only active runs are included
+ completed_only: If true, only completed runs are included
+        limit: Maximum number of runs to return (default 20)
+ run_type: Type of runs (JOB_RUN, WORKFLOW_RUN, SUBMIT_RUN)
+ expand_tasks: Whether to include task details
+ start_time_from: Show runs that started at or after this UTC timestamp in milliseconds
+ start_time_to: Show runs that started at or before this UTC timestamp in milliseconds
+ page_token: Token for pagination
+ """
+ try:
+ # Build query parameters
+ params = {}
+
+ if job_id is not None:
+ params["job_id"] = job_id
+ if active_only is not None:
+ params["active_only"] = active_only
+ if completed_only is not None:
+ params["completed_only"] = completed_only
+ if limit is not None:
+ params["limit"] = limit
+ if run_type is not None:
+ params["run_type"] = run_type
+ if expand_tasks is not None:
+ params["expand_tasks"] = expand_tasks
+ if start_time_from is not None:
+ params["start_time_from"] = start_time_from
+ if start_time_to is not None:
+ params["start_time_to"] = start_time_to
+ if page_token is not None:
+ params["page_token"] = page_token
+
+ # Make API request
+ response = databricks_api_request("jobs/runs/list", method="GET", data=params)
+
+ if not response.get("runs"):
+ if job_id:
+ return f"No runs found for job ID {job_id}."
+ else:
+ return "No job runs found."
+
+ runs = response.get("runs", [])
+
+ # Format as markdown table
+ table = "| Run ID | Job ID | State | Creator | Start Time | End Time | Duration |\n"
+ table += "| ------ | ------ | ----- | ------- | ---------- | -------- | -------- |\n"
+
+ import datetime
+
+ for run in runs:
+ run_id = run.get("run_id", "N/A")
+ run_job_id = run.get("job_id", "N/A")
+ run_creator = run.get("creator_user_name", "N/A")
+
+ # Get state information
+ life_cycle_state = run.get("state", {}).get("life_cycle_state", "N/A")
+ result_state = run.get("state", {}).get("result_state", "")
+ state = f"{life_cycle_state}" if not result_state else f"{life_cycle_state} ({result_state})"
+
+ # Process timestamps
+ start_time = run.get("start_time", 0)
+ end_time = run.get("end_time", 0)
+
+ start_time_str = datetime.datetime.fromtimestamp(start_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if start_time else "N/A"
+ end_time_str = datetime.datetime.fromtimestamp(end_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if end_time else "N/A"
+
+ # Calculate duration
+ if start_time and end_time:
+ duration = f"{(end_time - start_time) / 1000:.2f}s"
+ else:
+ duration = "N/A"
+
+ table += f"| {run_id} | {run_job_id} | {state} | {run_creator} | {start_time_str} | {end_time_str} | {duration} |\n"
+
+ # Add pagination information
+ result = table
+
+ next_page = response.get("next_page_token")
+ prev_page = response.get("prev_page_token")
+
+ if next_page or prev_page:
+ result += "\n### Pagination\n"
+ if next_page:
+ result += f"More results available. Use page_token='{next_page}' to view the next page.\n"
+ if prev_page:
+ result += f"Use page_token='{prev_page}' to view the previous page.\n"
+
+ return result
+
+ except Exception as e:
+ return f"Error listing job runs: {str(e)}"
+
+# Clusters API tools
+@mcp.tool()
+def list_clusters() -> str:
+ """List all Databricks clusters"""
+ try:
+ response = databricks_api_request("clusters/list")
+
+ if not response.get("clusters"):
+ return "No clusters found."
+
+ clusters = response.get("clusters", [])
+
+ # Format as markdown table
+ table = "| Cluster ID | Cluster Name | State | Autoscale | Workers |\n"
+ table += "| ---------- | ------------ | ----- | --------- | ------- |\n"
+
+ for cluster in clusters:
+ cluster_id = cluster.get("cluster_id", "N/A")
+ cluster_name = cluster.get("cluster_name", "N/A")
+ state = cluster.get("state", "N/A")
+
+ autoscale = "Yes" if cluster.get("autoscale", {}).get("min_workers") else "No"
+ workers = cluster.get("num_workers", "N/A")
+
+ table += f"| {cluster_id} | {cluster_name} | {state} | {autoscale} | {workers} |\n"
+
+ return table
+ except Exception as e:
+ return f"Error listing clusters: {str(e)}"
+
+@mcp.tool()
+def get_cluster_details(cluster_id: str) -> str:
+ """Get detailed information about a specific Databricks cluster"""
+ try:
+ response = databricks_api_request(f"clusters/get?cluster_id={cluster_id}")
+
+ result = f"## Cluster Details\n\n"
+ result += f"- **Cluster ID:** {cluster_id}\n"
+ result += f"- **Cluster Name:** {response.get('cluster_name', 'N/A')}\n"
+ result += f"- **State:** {response.get('state', 'N/A')}\n"
+ result += f"- **Spark Version:** {response.get('spark_version', 'N/A')}\n"
+
+ # Add node type info
+ node_type_id = response.get('node_type_id', 'N/A')
+ result += f"- **Node Type ID:** {node_type_id}\n"
+
+ # Add autoscaling info if applicable
+ if 'autoscale' in response:
+ min_workers = response['autoscale'].get('min_workers', 'N/A')
+ max_workers = response['autoscale'].get('max_workers', 'N/A')
+ result += f"- **Autoscale:** {min_workers} to {max_workers} workers\n"
+ else:
+ result += f"- **Num Workers:** {response.get('num_workers', 'N/A')}\n"
+
+ return result
+ except Exception as e:
+ return f"Error getting cluster details: {str(e)}"
+
+# Instance Pools API tools
+@mcp.tool()
+def list_instance_pools() -> str:
+ """List all instance pools in the workspace"""
+ try:
+ response = databricks_api_request("instance-pools/list")
+
+ if not response.get("instance_pools"):
+ return "No instance pools found."
+
+ pools = response.get("instance_pools", [])
+
+ # Format as markdown table
+ table = "| Pool ID | Pool Name | Instance Type | Min Idle | Max Capacity | State |\n"
+ table += "| ------- | --------- | ------------- | -------- | ------------ | ----- |\n"
+
+ for pool in pools:
+ pool_id = pool.get("instance_pool_id", "N/A")
+ pool_name = pool.get("instance_pool_name", "N/A")
+ instance_type = pool.get("node_type_id", "N/A")
+ min_idle = pool.get("min_idle_instances", 0)
+ max_capacity = pool.get("max_capacity", 0)
+ state = pool.get("state", "N/A")
+
+ table += f"| {pool_id} | {pool_name} | {instance_type} | {min_idle} | {max_capacity} | {state} |\n"
+
+ return table
+ except Exception as e:
+ return f"Error listing instance pools: {str(e)}"
+
+@mcp.tool()
+def create_instance_pool(
+ name: str,
+ node_type_id: str,
+ min_idle_instances: int = 0,
+ max_capacity: int = 10,
+ idle_instance_autotermination_minutes: int = 60,
+ preload_spark_versions: List[str] = None
+) -> str:
+ """Create a new instance pool
+
+ Args:
+ name: The name of the instance pool
+ node_type_id: The node type for the instances in the pool (e.g., "Standard_DS3_v2")
+ min_idle_instances: Minimum number of idle instances to keep in the pool
+ max_capacity: Maximum number of instances the pool can contain
+ idle_instance_autotermination_minutes: Number of minutes that idle instances are maintained
+ preload_spark_versions: Spark versions to preload on the instances (optional)
+ """
+ try:
+ # Build the request data
+ data = {
+ "instance_pool_name": name,
+ "node_type_id": node_type_id,
+ "min_idle_instances": min_idle_instances,
+ "max_capacity": max_capacity,
+ "idle_instance_autotermination_minutes": idle_instance_autotermination_minutes
+ }
+
+ # Add preloaded Spark versions if specified
+ if preload_spark_versions:
+ data["preloaded_spark_versions"] = preload_spark_versions
+
+ # Make the API request
+ response = databricks_api_request("instance-pools/create", method="POST", data=data)
+
+ # The response contains the instance_pool_id of the newly created pool
+ pool_id = response.get("instance_pool_id")
+
+ if pool_id:
+ return f"Instance pool created successfully with ID: {pool_id}"
+ else:
+ return "Instance pool creation failed. No pool ID returned."
+ except Exception as e:
+ return f"Error creating instance pool: {str(e)}"
+
+# DBFS API tools
+@mcp.tool()
+def list_dbfs_files(path: str = "/") -> str:
+ """List files in a DBFS directory"""
+ try:
+ if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+ raise ValueError("Missing required Databricks API credentials in .env file")
+
+ # Direct API call to match the exact curl command that works
+
+ # Execute the curl command that we know works
+ cmd = [
+ "curl", "-s",
+ "-H", f"Authorization: Bearer {DATABRICKS_TOKEN}",
+ "-H", "Content-Type: application/json",
+ f"https://{DATABRICKS_HOST}/api/2.0/dbfs/list?path={path}"
+ ]
+
+ result = subprocess.run(cmd, capture_output=True, text=True)
+
+ if result.returncode != 0:
+ raise ValueError(f"Command failed with exit code {result.returncode}: {result.stderr}")
+
+ try:
+ data = json.loads(result.stdout)
+ except json.JSONDecodeError:
+ raise ValueError(f"Failed to parse response as JSON: {result.stdout}")
+
+ if not data.get("files"):
+ return f"No files found in {path}."
+
+ files = data.get("files", [])
+
+ # Format as markdown table
+ table = "| File Path | Size | Is Directory |\n"
+ table += "| --------- | ---- | ------------ |\n"
+
+ for file in files:
+ file_path = file.get("path", "N/A")
+ size = file.get("file_size", 0)
+ is_dir = "Yes" if file.get("is_dir", False) else "No"
+
+ table += f"| {file_path} | {size} | {is_dir} |\n"
+
+ return table
+ except Exception as e:
+ print(f"Error details: {str(e)}")
+ return f"Error listing DBFS files: {str(e)}"
+
+@mcp.tool()
+def read_dbfs_file(path: str, length: int = 1000) -> str:
+ """Read the contents of a file from DBFS (limited to 1MB)"""
+ try:
+ if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+ raise ValueError("Missing required Databricks API credentials in .env file")
+
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Build URL with query parameters directly
+ encoded_path = urllib.parse.quote(path)
+ url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/read?path={encoded_path}&length={length}"
+
+ print(f"Requesting URL: {url}")
+ response = requests.get(url, headers=headers)
+ print(f"Status code: {response.status_code}")
+ print(f"Response: {response.text[:100]}...") # Print just the beginning to avoid flooding logs
+
+ response.raise_for_status()
+ data = response.json()
+
+ if "data" not in data:
+ return f"No data found in file {path}."
+
+ import base64
+ file_content = base64.b64decode(data["data"]).decode("utf-8")
+
+ return f"## File Contents: {path}\n\n```\n{file_content}\n```"
+ except Exception as e:
+ print(f"Error details: {str(e)}")
+ return f"Error reading DBFS file: {str(e)}"
+
+@mcp.tool()
+def upload_to_dbfs(file_content: str, dbfs_path: str, overwrite: bool = False) -> str:
+ """Upload content to a file in DBFS
+
+ Args:
+ file_content: The text content to upload to DBFS
+ dbfs_path: The destination path in DBFS
+ overwrite: Whether to overwrite an existing file
+ """
+ try:
+ if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+ raise ValueError("Missing required Databricks API credentials in .env file")
+
+ import base64
+
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # DBFS file upload requires three steps:
+ # 1. Create a handle
+ create_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/create"
+ create_data = {
+ "path": dbfs_path,
+ "overwrite": overwrite
+ }
+
+ create_response = requests.post(create_url, headers=headers, json=create_data)
+ create_response.raise_for_status()
+
+ handle = create_response.json().get("handle")
+ if not handle:
+ return f"Error: Failed to create a handle for the file upload."
+
+ # 2. Add blocks of data
+ # Convert string content to bytes and then encode as base64
+ content_bytes = file_content.encode('utf-8')
+ encoded_content = base64.b64encode(content_bytes).decode('utf-8')
+
+ add_block_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/add-block"
+ add_block_data = {
+ "handle": handle,
+ "data": encoded_content
+ }
+
+ add_block_response = requests.post(add_block_url, headers=headers, json=add_block_data)
+ add_block_response.raise_for_status()
+
+ # 3. Close the handle
+ close_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/close"
+ close_data = {
+ "handle": handle
+ }
+
+ close_response = requests.post(close_url, headers=headers, json=close_data)
+ close_response.raise_for_status()
+
+ return f"Successfully uploaded content to {dbfs_path}.\n\nYou can view the file with: read_dbfs_file('{dbfs_path}')"
+ except Exception as e:
+ print(f"Error details: {str(e)}")
+ return f"Error uploading to DBFS: {str(e)}"
+
+# Workspace API tools
+@mcp.tool()
+def list_workspace(path: str = "/") -> str:
+ """List notebooks and directories in a workspace"""
+ try:
+ if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+ raise ValueError("Missing required Databricks API credentials in .env file")
+
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Build URL with query parameters directly
+ encoded_path = urllib.parse.quote(path)
+ url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/list?path={encoded_path}"
+
+ print(f"Requesting URL: {url}")
+ response = requests.get(url, headers=headers)
+ print(f"Status code: {response.status_code}")
+ print(f"Response: {response.text}")
+
+ response.raise_for_status()
+ data = response.json()
+
+ if not data.get("objects"):
+ return f"No objects found in {path}."
+
+ objects = data.get("objects", [])
+
+ # Format as markdown table
+ table = "| Path | Object Type | Language |\n"
+ table += "| ---- | ----------- | -------- |\n"
+
+ for obj in objects:
+ obj_path = obj.get("path", "N/A")
+ obj_type = obj.get("object_type", "N/A")
+ language = obj.get("language", "N/A") if obj_type == "NOTEBOOK" else "N/A"
+
+ table += f"| {obj_path} | {obj_type} | {language} |\n"
+
+ return table
+ except Exception as e:
+ print(f"Error details: {str(e)}")
+ return f"Error listing workspace: {str(e)}"
+
+@mcp.tool()
+def export_notebook(path: str, format: str = "SOURCE") -> str:
+ """Export a notebook from the workspace
+
+ Args:
+ path: Path to the notebook in the workspace
+ format: Export format (SOURCE, HTML, JUPYTER, DBC, R_MARKDOWN)
+ """
+ try:
+ if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+ raise ValueError("Missing required Databricks API credentials in .env file")
+
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Build request
+ export_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/export"
+ export_data = {
+ "path": path,
+ "format": format
+ }
+
+ response = requests.get(export_url, headers=headers, params=export_data)
+ response.raise_for_status()
+ data = response.json()
+
+ if "content" not in data:
+ return f"No content found for notebook at {path}."
+
+ import base64
+ notebook_content = base64.b64decode(data["content"]).decode("utf-8")
+
+ # Determine file extension based on format
+ extension = ""
+ if format == "SOURCE":
+ # Try to determine language
+ language_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/get-status"
+ language_response = requests.get(language_url, headers=headers, params={"path": path})
+
+ if language_response.status_code == 200:
+ language_data = language_response.json()
+ language = language_data.get("language", "").lower()
+
+ if language == "python":
+ extension = ".py"
+ elif language == "scala":
+ extension = ".scala"
+ elif language == "r":
+ extension = ".r"
+ elif language == "sql":
+ extension = ".sql"
+
+ elif format == "HTML":
+ extension = ".html"
+ elif format == "JUPYTER":
+ extension = ".ipynb"
+ elif format == "DBC":
+ extension = ".dbc"
+ elif format == "R_MARKDOWN":
+ extension = ".Rmd"
+
+ # Return the notebook content
+ filename = path.split("/")[-1] + extension
+ return f"## Exported Notebook: {filename}\n\n```\n{notebook_content}\n```"
+ except Exception as e:
+ print(f"Error details: {str(e)}")
+ return f"Error exporting notebook: {str(e)}"
+
+@mcp.tool()
+def import_notebook(content: str, path: str, language: str = "PYTHON", format: str = "SOURCE", overwrite: bool = False) -> str:
+ """Import a notebook into the workspace
+
+ Args:
+ content: The notebook content
+ path: Destination path in the workspace
+ language: Notebook language (PYTHON, SCALA, SQL, R)
+ format: Import format (SOURCE, HTML, JUPYTER, DBC, R_MARKDOWN)
+ overwrite: Whether to overwrite existing notebook
+ """
+ try:
+ if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]):
+ raise ValueError("Missing required Databricks API credentials in .env file")
+
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Encode the content as base64
+ import base64
+ content_bytes = content.encode("utf-8")
+ encoded_content = base64.b64encode(content_bytes).decode("utf-8")
+
+ # Build request
+ import_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/import"
+ import_data = {
+ "path": path,
+ "content": encoded_content,
+ "language": language,
+ "format": format,
+ "overwrite": overwrite
+ }
+
+ response = requests.post(import_url, headers=headers, json=import_data)
+ response.raise_for_status()
+
+ return f"Notebook successfully imported to {path}"
+ except Exception as e:
+ print(f"Error details: {str(e)}")
+ return f"Error importing notebook: {str(e)}"
+
+# Unity Catalog API tools
+@mcp.tool()
+def list_catalogs() -> str:
+ """List all catalogs in Unity Catalog"""
+ try:
+ response = databricks_api_request("unity-catalog/catalogs")
+
+ if not response.get("catalogs"):
+ return "No catalogs found."
+
+ catalogs = response.get("catalogs", [])
+
+ # Format as markdown table
+ table = "| Catalog Name | Comment | Provider | Properties |\n"
+ table += "| ------------ | ------- | -------- | ---------- |\n"
+
+ for catalog in catalogs:
+ name = catalog.get("name", "N/A")
+ comment = catalog.get("comment", "")
+ provider = catalog.get("provider", {}).get("name", "N/A")
+ properties = ", ".join([f"{k}: {v}" for k, v in catalog.get("properties", {}).items()]) if catalog.get("properties") else "None"
+
+ table += f"| {name} | {comment} | {provider} | {properties} |\n"
+
+ return table
+ except Exception as e:
+ return f"Error listing catalogs: {str(e)}"
+
+@mcp.tool()
+def create_catalog(name: str, comment: str = None, properties: Dict = None) -> str:
+ """Create a new catalog in Unity Catalog
+
+ Args:
+ name: Name of the catalog
+ comment: Description of the catalog (optional)
+ properties: Key-value properties for the catalog (optional)
+ """
+ try:
+ # Prepare the request data
+ data = {
+ "name": name
+ }
+
+ if comment:
+ data["comment"] = comment
+
+ if properties:
+ data["properties"] = properties
+
+ # Make the API request
+ response = databricks_api_request("unity-catalog/catalogs", method="POST", data=data)
+
+ # Check if the catalog was created successfully
+ if response.get("name") == name:
+ return f"Catalog '{name}' was created successfully."
+ else:
+ return f"Catalog creation might have failed. Please check with list_catalogs()."
+ except Exception as e:
+ return f"Error creating catalog: {str(e)}"
+
+@mcp.tool()
+def create_schema(catalog_name: str, schema_name: str, comment: str = None, properties: Dict = None) -> str:
+ """Create a new schema in Unity Catalog
+
+ Args:
+ catalog_name: Name of the parent catalog
+ schema_name: Name of the schema
+ comment: Description of the schema (optional)
+ properties: Key-value properties for the schema (optional)
+ """
+ try:
+ # Prepare the request data
+ data = {
+ "name": schema_name,
+ "catalog_name": catalog_name
+ }
+
+ if comment:
+ data["comment"] = comment
+
+ if properties:
+ data["properties"] = properties
+
+ # Make the API request
+ response = databricks_api_request("unity-catalog/schemas", method="POST", data=data)
+
+ # Check if the schema was created successfully
+ if response.get("name") == schema_name and response.get("catalog_name") == catalog_name:
+ return f"Schema '{catalog_name}.{schema_name}' was created successfully."
+ else:
+ return f"Schema creation might have failed. Please check using another Unity Catalog API call."
+ except Exception as e:
+ return f"Error creating schema: {str(e)}"
+
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
if __name__ == "__main__":
- #run_sql_query("SELECT * FROM dev.dev_test.income_survey_dataset LIMIT 10;")
mcp.run()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 6d3033d..e3c9001 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,9 @@
-fastapi>=0.95.0
-uvicorn>=0.22.0
-databricks-sql-connector>=2.4.0
-python-dotenv>=1.0.0
-pydantic>=2.0.0
+fastapi>=0.68.0
+uvicorn>=0.15.0
+python-dotenv>=0.19.0
+requests>=2.26.0
+databricks-sql-connector>=2.0.0
mcp>=0.1.0
-pyarrow>=14.0.1
-requests>=2.31.0
\ No newline at end of file
+python-multipart>=0.0.5
+pydantic>=1.8.0
+packaging>=23.0
\ No newline at end of file
diff --git a/test_connection.py b/test_connection.py
index 5a971ab..347da2d 100644
--- a/test_connection.py
+++ b/test_connection.py
@@ -8,6 +8,8 @@
import os
from dotenv import load_dotenv
import sys
+import datetime
+import base64
# Load environment variables
load_dotenv()
@@ -62,6 +64,512 @@ def test_databricks_api():
print(f"❌ Error connecting to Databricks API: {str(e)}")
return False
+def test_clusters_api():
+ """Test Clusters API functionality"""
+ import requests
+
+ print("\nTesting Clusters API...")
+
+ try:
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Test clusters/list endpoint
+ url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/list"
+ response = requests.get(url, headers=headers)
+
+ if response.status_code == 200:
+ data = response.json()
+ if "clusters" in data:
+ print("✅ Successfully accessed Clusters API (list)")
+
+ # Also test get_cluster_details if clusters exist
+ if data.get("clusters"):
+ cluster_id = data["clusters"][0]["cluster_id"]
+ detail_url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/get?cluster_id={cluster_id}"
+ detail_response = requests.get(detail_url, headers=headers)
+
+ if detail_response.status_code == 200:
+ print("✅ Successfully accessed Clusters API (details)")
+ else:
+ print(f"❌ Failed to access Clusters API (details): {detail_response.status_code}")
+
+ return True
+ else:
+ print("❌ Response missing expected 'clusters' field")
+ return False
+ else:
+ print(f"❌ Failed to access Clusters API: {response.status_code} - {response.text}")
+ return False
+ except Exception as e:
+ print(f"❌ Error testing Clusters API: {str(e)}")
+ return False
+
+def test_instance_pools_api():
+ """Test Instance Pools API functionality"""
+ import requests
+
+ print("\nTesting Instance Pools API...")
+
+ try:
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Test instance-pools/list endpoint
+ url = f"https://{DATABRICKS_HOST}/api/2.0/instance-pools/list"
+ response = requests.get(url, headers=headers)
+
+ if response.status_code == 200:
+ data = response.json()
+ if "instance_pools" in data:
+ print("✅ Successfully accessed Instance Pools API (list)")
+
+ # Get available node types to use for creating a pool
+ node_types_url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/list-node-types"
+ node_types_response = requests.get(node_types_url, headers=headers)
+
+ if node_types_response.status_code == 200 and "node_types" in node_types_response.json():
+ node_types = node_types_response.json()["node_types"]
+ if node_types:
+ # Get the first node type ID
+ node_type_id = node_types[0]["node_type_id"]
+
+ # Test create instance pool
+ pool_name = f"Test Pool {datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
+ create_data = {
+ "instance_pool_name": pool_name,
+ "node_type_id": node_type_id,
+ "min_idle_instances": 0,
+ "max_capacity": 2,
+ "idle_instance_autotermination_minutes": 10
+ }
+
+ create_url = f"https://{DATABRICKS_HOST}/api/2.0/instance-pools/create"
+ create_response = requests.post(create_url, headers=headers, json=create_data)
+
+ if create_response.status_code == 200 and "instance_pool_id" in create_response.json():
+ print("✅ Successfully accessed Instance Pools API (create)")
+
+ # Cleanup: Delete the test pool we just created
+ pool_id = create_response.json()["instance_pool_id"]
+ delete_url = f"https://{DATABRICKS_HOST}/api/2.0/instance-pools/delete"
+ delete_data = {"instance_pool_id": pool_id}
+ delete_response = requests.post(delete_url, headers=headers, json=delete_data)
+
+ if delete_response.status_code == 200:
+ print("✅ Successfully cleaned up test instance pool")
+ else:
+ print(f"⚠️ Could not delete test instance pool: {delete_response.status_code}")
+ else:
+ print(f"❌ Failed to access Instance Pools API (create): {create_response.status_code} - {create_response.text}")
+ else:
+ print("⚠️ No node types available for testing instance pool creation")
+ else:
+ print(f"⚠️ Could not get node types for testing: {node_types_response.status_code}")
+
+ return True
+ else:
+ print("❌ Response missing expected 'instance_pools' field")
+ return False
+ else:
+ print(f"❌ Failed to access Instance Pools API: {response.status_code} - {response.text}")
+ return False
+ except Exception as e:
+ print(f"❌ Error testing Instance Pools API: {str(e)}")
+ return False
+
+def test_unity_catalog_api():
+ """Test Unity Catalog API functionality"""
+ import requests
+
+ print("\nTesting Unity Catalog API...")
+
+ try:
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Test catalogs list endpoint
+ url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/catalogs"
+ response = requests.get(url, headers=headers)
+
+ # Check if workspace supports Unity Catalog
+ if response.status_code == 404 or (response.status_code == 400 and "Unity Catalog is not enabled" in response.text):
+ print("⚠️ Unity Catalog is not enabled in this workspace - skipping tests")
+ return True # Skip test but don't mark as failure
+
+ if response.status_code == 200:
+ # Note: This may return no catalogs, but it should have a "catalogs" key
+ data = response.json()
+ if "catalogs" in data:
+ print("✅ Successfully accessed Unity Catalog API (list catalogs)")
+
+ # Test catalog creation (only if user has permission)
+ catalog_name = f"test_catalog_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
+ create_data = {
+ "name": catalog_name,
+ "comment": "Test catalog for API testing - will be deleted"
+ }
+
+ create_url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/catalogs"
+ create_response = requests.post(create_url, headers=headers, json=create_data)
+
+ if create_response.status_code == 200:
+ print("✅ Successfully accessed Unity Catalog API (create catalog)")
+
+ # Test schema creation in this catalog
+ schema_name = f"test_schema_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
+ schema_data = {
+ "name": schema_name,
+ "catalog_name": catalog_name,
+ "comment": "Test schema for API testing - will be deleted"
+ }
+
+ schema_url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/schemas"
+ schema_response = requests.post(schema_url, headers=headers, json=schema_data)
+
+ if schema_response.status_code == 200:
+ print("✅ Successfully accessed Unity Catalog API (create schema)")
+ else:
+ # Not all users can create schemas, so this is not a critical failure
+ print(f"⚠️ Could not create schema (permissions?): {schema_response.status_code}")
+
+ # Cleanup: Delete the catalog (automatically deletes schemas too)
+ delete_url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/catalogs/{catalog_name}"
+ delete_response = requests.delete(delete_url, headers=headers)
+
+ if delete_response.status_code in [200, 204]:
+ print("✅ Successfully cleaned up test catalog")
+ else:
+ print(f"⚠️ Could not delete test catalog: {delete_response.status_code}")
+ elif create_response.status_code in [403, 401]:
+ # User likely doesn't have permission to create catalogs
+ print("⚠️ Insufficient permissions to create catalogs - skipping create test")
+ else:
+ print(f"❌ Failed to access Unity Catalog API (create): {create_response.status_code} - {create_response.text}")
+
+ return True
+ else:
+ print("❌ Response missing expected 'catalogs' field")
+ return False
+ else:
+ print(f"❌ Failed to access Unity Catalog API: {response.status_code} - {response.text}")
+ return False
+ except Exception as e:
+ print(f"❌ Error testing Unity Catalog API: {str(e)}")
+ return False
+
+def test_jobs_api():
+ """Test Jobs API functionality"""
+ import requests
+
+ print("\nTesting Jobs API...")
+
+ try:
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # Test jobs/list endpoint
+ url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/list"
+ response = requests.get(url, headers=headers)
+
+ if response.status_code == 200:
+ data = response.json()
+ if "jobs" in data:
+ print("✅ Successfully accessed Jobs API (list)")
+
+ # Also test job details if jobs exist
+ if data.get("jobs"):
+ job_id = data["jobs"][0]["job_id"]
+
+ # Test get_job_details
+ detail_url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/get?job_id={job_id}"
+ detail_response = requests.get(detail_url, headers=headers)
+
+ if detail_response.status_code == 200:
+ print("✅ Successfully accessed Jobs API (details)")
+ else:
+ print(f"❌ Failed to access Jobs API (details): {detail_response.status_code}")
+
+ # Test get_job_status
+ status_url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/runs/list"
+ status_response = requests.get(status_url, headers=headers, params={"job_id": job_id})
+
+ if status_response.status_code == 200 and "runs" in status_response.json():
+ print("✅ Successfully accessed Jobs API (status)")
+ else:
+ print(f"❌ Failed to access Jobs API (status): {status_response.status_code}")
+
+ # Test list_job_runs
+ runs_url = f"https://{DATABRICKS_HOST}/api/2.2/jobs/runs/list"
+ runs_response = requests.get(runs_url, headers=headers)
+
+ if runs_response.status_code == 200 and "runs" in runs_response.json():
+ print("✅ Successfully accessed Jobs API (runs list)")
+ else:
+ print(f"❌ Failed to access Jobs API (runs list): {runs_response.status_code}")
+
+ # Test job creation API
+ # Find an available cluster to use
+ cluster_url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/list"
+ cluster_response = requests.get(cluster_url, headers=headers)
+
+ if cluster_response.status_code == 200 and cluster_response.json().get("clusters"):
+ # Find the first RUNNING cluster to use
+ clusters = cluster_response.json()["clusters"]
+ running_clusters = [c for c in clusters if c.get("state") == "RUNNING"]
+
+ if running_clusters:
+ test_cluster_id = running_clusters[0]["cluster_id"]
+
+ # Create a test job
+ create_url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/create"
+ job_name = f"Test Job {datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
+
+                        # Jobs API 2.0 expects the job settings as the top-level request body
+                        # (the notebook path below is just a placeholder for the create call)
+                        job_data = {
+                            "name": job_name,
+                            "existing_cluster_id": test_cluster_id,
+                            "notebook_task": {
+                                "notebook_path": "/Shared/test-notebook"
+                            },
+                            "max_retries": 0,
+                            "timeout_seconds": 300
+                        }
+
+ create_response = requests.post(create_url, headers=headers, json=job_data)
+
+ if create_response.status_code == 200 and "job_id" in create_response.json():
+ print("✅ Successfully accessed Jobs API (create)")
+ else:
+ print(f"❌ Failed to access Jobs API (create): {create_response.status_code} - {create_response.text}")
+ else:
+ print("⚠️ Skipping job creation test: No running clusters found")
+ else:
+ print("⚠️ Skipping job creation test: Couldn't list clusters")
+
+ return True
+ else:
+ print("❌ Response missing expected 'jobs' field")
+ return False
+ else:
+ print(f"❌ Failed to access Jobs API: {response.status_code} - {response.text}")
+ return False
+ except Exception as e:
+ print(f"❌ Error testing Jobs API: {str(e)}")
+ return False
+
+def test_dbfs_api():
+ """Test DBFS API functionality"""
+ import requests
+ import json
+ import subprocess
+
+ print("\nTesting DBFS API...")
+
+ try:
+ # Method 1: Direct API call with query parameters (matches our implementation)
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # The DBFS list endpoint requires a path parameter
+ url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/list?path=/"
+ response = requests.get(url, headers=headers)
+
+ if response.status_code == 200:
+ data = response.json()
+ if "files" in data:
+ print("✅ Successfully accessed DBFS API (list)")
+
+                # Also test read_dbfs_file against /databricks-datasets/README.md if the datasets mount exists
+                # (the root listing only returns top-level directories, so check for the mount rather than the file)
+                if any(file.get("path") == "/databricks-datasets" for file in data.get("files", [])):
+ read_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/read?path=/databricks-datasets/README.md&length=1000"
+ read_response = requests.get(read_url, headers=headers)
+
+ if read_response.status_code == 200 and "data" in read_response.json():
+ print("✅ Successfully accessed DBFS API (read)")
+ else:
+ print(f"❌ Failed to access DBFS API (read): {read_response.status_code}")
+
+ # Test file upload to DBFS
+ test_content = "This is a test file uploaded via the DBFS API."
+ test_path = "/FileStore/test-upload.txt"
+
+ # Step 1: Create handle
+ create_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/create"
+ create_data = {
+ "path": test_path,
+ "overwrite": True
+ }
+
+ create_response = requests.post(create_url, headers=headers, json=create_data)
+
+ if create_response.status_code == 200 and "handle" in create_response.json():
+ handle = create_response.json()["handle"]
+
+ # Step 2: Add data block
+ add_block_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/add-block"
+ encoded_content = base64.b64encode(test_content.encode('utf-8')).decode('utf-8')
+
+ add_block_data = {
+ "handle": handle,
+ "data": encoded_content
+ }
+
+ add_block_response = requests.post(add_block_url, headers=headers, json=add_block_data)
+
+ # Step 3: Close handle
+ if add_block_response.status_code == 200:
+ close_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/close"
+ close_data = {
+ "handle": handle
+ }
+
+ close_response = requests.post(close_url, headers=headers, json=close_data)
+
+ if close_response.status_code == 200:
+ print("✅ Successfully accessed DBFS API (upload)")
+
+ # Verify the file was uploaded by reading it back
+ verify_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/read?path={test_path}&length=1000"
+ verify_response = requests.get(verify_url, headers=headers)
+
+ if (verify_response.status_code == 200 and
+ "data" in verify_response.json() and
+ test_content == base64.b64decode(verify_response.json()["data"]).decode('utf-8')):
+ print("✅ Successfully verified uploaded content")
+ else:
+ print(f"❌ Failed to verify uploaded content: {verify_response.status_code}")
+ else:
+ print(f"❌ Failed to close handle: {close_response.status_code}")
+ else:
+ print(f"❌ Failed to add data block: {add_block_response.status_code}")
+ else:
+ print(f"❌ Failed to create upload handle: {create_response.status_code}")
+
+ return True
+ else:
+ print("❌ Response missing expected 'files' field")
+ return False
+ else:
+ # Method 2: Try with curl as a fallback (this is our current implementation)
+ print(f"Direct requests failed, trying curl: {response.status_code}")
+
+ cmd = [
+ "curl", "-s",
+ "-H", f"Authorization: Bearer {DATABRICKS_TOKEN}",
+ "-H", "Content-Type: application/json",
+ f"https://{DATABRICKS_HOST}/api/2.0/dbfs/list?path=/"
+ ]
+
+ result = subprocess.run(cmd, capture_output=True, text=True)
+
+ if result.returncode == 0:
+ try:
+ data = json.loads(result.stdout)
+ if "files" in data:
+ print("✅ Successfully accessed DBFS API using curl")
+ return True
+ else:
+ print("❌ Response missing expected 'files' field (curl)")
+ return False
+ except json.JSONDecodeError:
+ print(f"❌ Failed to parse curl response as JSON")
+ return False
+ else:
+ print(f"❌ Curl command failed: {result.stderr}")
+ return False
+ except Exception as e:
+ print(f"❌ Error testing DBFS API: {str(e)}")
+ return False
+
+def test_workspace_api():
+ """Test Workspace API functionality"""
+ import requests
+ import base64
+
+ print("\nTesting Workspace API...")
+
+ try:
+ headers = {
+ "Authorization": f"Bearer {DATABRICKS_TOKEN}",
+ "Content-Type": "application/json"
+ }
+
+ # The workspace list endpoint requires a path parameter
+ url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/list"
+ data = {"path": "/"}
+ response = requests.get(url, headers=headers, params=data)
+
+ if response.status_code == 200:
+ data = response.json()
+ if "objects" in data:
+ print("✅ Successfully accessed Workspace API (list)")
+
+ # Test notebook import
+ test_notebook_path = "/Users/test-import-notebook"
+ test_notebook_content = "# Test Notebook\n\nprint('Hello, Databricks!')"
+
+ # Encode content
+ encoded_content = base64.b64encode(test_notebook_content.encode('utf-8')).decode('utf-8')
+
+ import_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/import"
+ import_data = {
+ "path": test_notebook_path,
+ "content": encoded_content,
+ "language": "PYTHON",
+ "format": "SOURCE",
+ "overwrite": True
+ }
+
+ import_response = requests.post(import_url, headers=headers, json=import_data)
+
+ if import_response.status_code == 200:
+ print("✅ Successfully accessed Workspace API (import)")
+
+ # Test notebook export
+ export_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/export"
+ export_data = {
+ "path": test_notebook_path,
+ "format": "SOURCE"
+ }
+
+ export_response = requests.get(export_url, headers=headers, params=export_data)
+
+ if export_response.status_code == 200 and "content" in export_response.json():
+ exported_content = base64.b64decode(export_response.json()["content"]).decode('utf-8')
+
+ if test_notebook_content in exported_content:
+ print("✅ Successfully accessed Workspace API (export)")
+ else:
+ print("❌ Exported content does not match imported content")
+ else:
+ print(f"❌ Failed to access Workspace API (export): {export_response.status_code}")
+ else:
+ print(f"❌ Failed to access Workspace API (import): {import_response.status_code}")
+
+ return True
+ else:
+ print("❌ Response missing expected 'objects' field")
+ return False
+ else:
+ print(f"❌ Failed to access Workspace API: {response.status_code} - {response.text}")
+ return False
+ except Exception as e:
+ print(f"❌ Error testing Workspace API: {str(e)}")
+ return False
+
def test_sql_connection():
"""Test connection to Databricks SQL warehouse"""
print("\nTesting Databricks SQL warehouse connection...")
@@ -112,6 +620,12 @@ def test_sql_connection():
api_ok = test_databricks_api()
sql_ok = test_sql_connection()
+ clusters_ok = test_clusters_api()
+ jobs_ok = test_jobs_api()
+ dbfs_ok = test_dbfs_api()
+ workspace_ok = test_workspace_api()
+ instance_pools_ok = test_instance_pools_api()
+ unity_catalog_ok = test_unity_catalog_api()
# Summary
print("\nTest Summary")
@@ -119,8 +633,14 @@ def test_sql_connection():
print(f"Environment Variables: {'✅ OK' if env_ok else '❌ Failed'}")
print(f"Databricks API: {'✅ OK' if api_ok else '❌ Failed'}")
print(f"Databricks SQL: {'✅ OK' if sql_ok else '❌ Failed'}")
+ print(f"Clusters API: {'✅ OK' if clusters_ok else '❌ Failed'}")
+ print(f"Jobs API: {'✅ OK' if jobs_ok else '❌ Failed'}")
+ print(f"DBFS API: {'✅ OK' if dbfs_ok else '❌ Failed'}")
+ print(f"Workspace API: {'✅ OK' if workspace_ok else '❌ Failed'}")
+ print(f"Instance Pools API: {'✅ OK' if instance_pools_ok else '❌ Failed'}")
+ print(f"Unity Catalog API: {'✅ OK' if unity_catalog_ok else '❌ Failed'}")
- if env_ok and api_ok and sql_ok:
+ if env_ok and api_ok and sql_ok and clusters_ok and jobs_ok and dbfs_ok and workspace_ok and instance_pools_ok and unity_catalog_ok:
print("\n✅ All tests passed! Your Databricks MCP server should work correctly.")
sys.exit(0)
else: