diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..037f15c --- /dev/null +++ b/.dockerignore @@ -0,0 +1,50 @@ +# Git +.git +.gitignore + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Environment variables +.env +.env.* + +# Docker +.docker/ + +# Logs +*.log + +# Local development +.DS_Store +Thumbs.db \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ace5e3a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,33 @@ +# Use an official Python runtime as a parent image +FROM python:3.10-slim + +# Set the working directory in the container +WORKDIR /app + +# Install curl for DBFS API calls +RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/* + +# Copy requirements first to leverage Docker cache +COPY requirements.txt . + +# Install any needed packages specified in requirements.txt +# Use --no-cache-dir to reduce image size +RUN pip install --no-cache-dir -r requirements.txt + +# Optional: Clean up build dependencies to reduce image size +# RUN apt-get purge -y --auto-remove build-essential + +# Copy the rest of the application code into the container at /app +COPY . . + +# Make port 8000 available to the world outside this container +# (MCP servers typically run on port 8000 by default) +EXPOSE 8000 + +# Define environment variables (these will be overridden by docker run -e flags) +ENV DATABRICKS_HOST="" +ENV DATABRICKS_TOKEN="" +ENV DATABRICKS_HTTP_PATH="" + +# Run main.py when the container launches +CMD ["python", "main.py"] \ No newline at end of file diff --git a/databricks_mcp_tools.md b/databricks_mcp_tools.md new file mode 100644 index 0000000..1ecf57f --- /dev/null +++ b/databricks_mcp_tools.md @@ -0,0 +1,157 @@ +# Databricks MCP Server Tools Guide + +This guide outlines the available tools and resources provided by the Databricks MCP server. + +## Server Configuration + +The server is configured in the MCP settings file with: +```json +{ + "mcpServers": { + "databricks-server": { + "command": "python", + "args": ["main.py"], + "disabled": false, + "alwaysAllow": ["list_jobs"], + "env": {}, + "cwd": "/Users/maheidem/Documents/dev/mcp-databricks-server" + } + } +} +``` + +## Available Tools + +### 1. run_sql_query +Execute SQL queries on Databricks SQL warehouse. + +**Parameters:** +- sql (string): SQL query to execute + +**Example:** +```python + +databricks-server +run_sql_query + +{ + "sql": "SELECT * FROM my_database.my_table LIMIT 10" +} + + +``` + +**Returns:** Results in markdown table format + +### 2. list_jobs +List all Databricks jobs. This tool is in alwaysAllow list. + +**Parameters:** None + +**Example:** +```python + +databricks-server +list_jobs + +{} + + +``` + +**Returns:** Job list in markdown table format with columns: +- Job ID +- Job Name +- Created By + +### 3. get_job_status +Get the status of a specific Databricks job. + +**Parameters:** +- job_id (integer): ID of the job to get status for + +**Example:** +```python + +databricks-server +get_job_status + +{ + "job_id": 123 +} + + +``` + +**Returns:** Job runs in markdown table format with columns: +- Run ID +- State +- Start Time +- End Time +- Duration + +### 4. get_job_details +Get detailed information about a specific Databricks job. 
+ +**Parameters:** +- job_id (integer): ID of the job to get details for + +**Example:** +```python + +databricks-server +get_job_details + +{ + "job_id": 123 +} + + +``` + +**Returns:** Detailed job information in markdown format including: +- Job Name +- Job ID +- Created Time +- Creator +- Tasks (if any) + +## Available Resources + +### schema://tables +Lists available tables in the Databricks SQL warehouse. + +**Example:** +```python + +databricks-server +schema://tables + +``` + +**Returns:** List of tables with their database and schema information. + +## Usage Flow + +```mermaid +graph TD + A[Start Server] --> B[MCP System Auto-connects] + B --> C[Tools Available] + C --> D[Use Tools] + + subgraph "Tool Usage Flow" + D --> E[List Jobs] + E --> F[Get Job Details] + F --> G[Get Job Status] + G --> H[Run SQL Query] + end +``` + +## Requirements + +The server requires the following environment variables to be set: +- DATABRICKS_TOKEN +- DATABRICKS_HOST +- DATABRICKS_HTTP_PATH + +These are already configured in the .env file. \ No newline at end of file diff --git a/main.py b/main.py index b82d0fb..c673d3c 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,13 @@ import os -from typing import List, Dict, Any, Optional +from typing import Dict from dotenv import load_dotenv from databricks.sql import connect from databricks.sql.client import Connection from mcp.server.fastmcp import FastMCP import requests import json +import urllib.parse +import subprocess # Load environment variables load_dotenv() @@ -45,7 +47,7 @@ def databricks_api_request(endpoint: str, method: str = "GET", data: Dict = None url = f"https://{DATABRICKS_HOST}/api/2.0/{endpoint}" if method.upper() == "GET": - response = requests.get(url, headers=headers) + response = requests.get(url, headers=headers, params=data) elif method.upper() == "POST": response = requests.post(url, headers=headers, json=data) else: @@ -76,9 +78,7 @@ def get_schema() -> str: @mcp.tool() def run_sql_query(sql: str) -> str: """Execute SQL queries on Databricks SQL warehouse""" - print(sql) conn = get_databricks_connection() - print("connected") try: cursor = conn.cursor() @@ -135,6 +135,85 @@ def list_jobs() -> str: except Exception as e: return f"Error listing jobs: {str(e)}" +@mcp.tool() +def create_job( + name: str, + cluster_id: str = None, + new_cluster: Dict = None, + notebook_path: str = None, + spark_jar_task: Dict = None, + spark_python_task: Dict = None, + schedule: Dict = None, + timeout_seconds: int = None, + max_concurrent_runs: int = None, + max_retries: int = None +) -> str: + """Create a new Databricks job + + Args: + name: The name of the job + cluster_id: ID of existing cluster to run the job on (specify either cluster_id or new_cluster) + new_cluster: Specification for a new cluster (specify either cluster_id or new_cluster) + notebook_path: Path to the notebook to run (specify a task type) + spark_jar_task: Specification for a JAR task (specify a task type) + spark_python_task: Specification for a Python task (specify a task type) + schedule: Job schedule configuration + timeout_seconds: Timeout in seconds for each run + max_concurrent_runs: Maximum number of concurrent runs + max_retries: Maximum number of retries per run + """ + try: + # Prepare the job settings + settings = { + "name": name + } + + # Set either existing cluster or new cluster + if cluster_id: + settings["existing_cluster_id"] = cluster_id + elif new_cluster: + settings["new_cluster"] = new_cluster + else: + return "Error: Either cluster_id or new_cluster must be 
specified." + + # Set task definition - only one type of task can be set + task_set = False + if notebook_path: + settings["notebook_task"] = {"notebook_path": notebook_path} + task_set = True + if spark_jar_task and not task_set: + settings["spark_jar_task"] = spark_jar_task + task_set = True + if spark_python_task and not task_set: + settings["spark_python_task"] = spark_python_task + task_set = True + + if not task_set: + return "Error: You must specify a task type (notebook_path, spark_jar_task, or spark_python_task)." + + # Add optional parameters + if schedule: + settings["schedule"] = schedule + if timeout_seconds: + settings["timeout_seconds"] = timeout_seconds + if max_concurrent_runs: + settings["max_concurrent_runs"] = max_concurrent_runs + if max_retries: + settings["max_retries"] = max_retries + + # Make the API request to create the job + response = databricks_api_request("jobs/create", method="POST", data={"name": name, "settings": settings}) + + # The response contains the job_id of the newly created job + job_id = response.get("job_id") + + if job_id: + return f"Job created successfully with job ID: {job_id}\n\nYou can view job details with: get_job_details({job_id})" + else: + return "Job creation failed. No job ID returned." + except Exception as e: + return f"Error creating job: {str(e)}" + @mcp.tool() def get_job_status(job_id: int) -> str: """Get the status of a specific Databricks job""" @@ -212,10 +291,653 @@ def get_job_details(job_id: int) -> str: except Exception as e: return f"Error getting job details: {str(e)}" +@mcp.tool() +def list_job_runs( + job_id: int = None, + active_only: bool = None, + completed_only: bool = None, + limit: int = None, + run_type: str = None, + expand_tasks: bool = None, + start_time_from: int = None, + start_time_to: int = None, + page_token: str = None +) -> str: + """List job runs with optional filtering parameters + + Args: + job_id: Optional job ID to filter runs for a specific job + active_only: If true, only active runs are included + completed_only: If true, only completed runs are included + limit: Number of runs to return (1-25, default 20) + run_type: Type of runs (JOB_RUN, WORKFLOW_RUN, SUBMIT_RUN) + expand_tasks: Whether to include task details + start_time_from: Show runs that started at or after this UTC timestamp in milliseconds + start_time_to: Show runs that started at or before this UTC timestamp in milliseconds + page_token: Token for pagination + """ + try: + # Build query parameters + params = {} + + if job_id is not None: + params["job_id"] = job_id + if active_only is not None: + params["active_only"] = active_only + if completed_only is not None: + params["completed_only"] = completed_only + if limit is not None: + params["limit"] = limit + if run_type is not None: + params["run_type"] = run_type + if expand_tasks is not None: + params["expand_tasks"] = expand_tasks + if start_time_from is not None: + params["start_time_from"] = start_time_from + if start_time_to is not None: + params["start_time_to"] = start_time_to + if page_token is not None: + params["page_token"] = page_token + + # Make API request + response = databricks_api_request("jobs/runs/list", method="GET", data=params) + + if not response.get("runs"): + if job_id: + return f"No runs found for job ID {job_id}." + else: + return "No job runs found." 
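
For reference, a minimal sketch of what the underlying create call can look like when issued directly with `requests`, assuming the Jobs 2.0 `jobs/create` endpoint with the job settings passed as top-level fields of the request body; the host, token, cluster ID and notebook path below are placeholders, not values from this repository:

```python
import os
import requests

# Placeholders: substitute real workspace values.
host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]

payload = {
    "name": "nightly-report",                      # job name
    "existing_cluster_id": "0101-120000-abcd123",  # or a "new_cluster" spec instead
    "notebook_task": {"notebook_path": "/Shared/nightly_report"},
    "timeout_seconds": 3600,
    "max_retries": 1,
}

resp = requests.post(
    f"https://{host}/api/2.0/jobs/create",
    headers={"Authorization": f"Bearer {token}"},
    json=payload,
    timeout=30,
)
resp.raise_for_status()
print("Created job:", resp.json().get("job_id"))
```
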
+ + runs = response.get("runs", []) + + # Format as markdown table + table = "| Run ID | Job ID | State | Creator | Start Time | End Time | Duration |\n" + table += "| ------ | ------ | ----- | ------- | ---------- | -------- | -------- |\n" + + import datetime + + for run in runs: + run_id = run.get("run_id", "N/A") + run_job_id = run.get("job_id", "N/A") + run_creator = run.get("creator_user_name", "N/A") + + # Get state information + life_cycle_state = run.get("state", {}).get("life_cycle_state", "N/A") + result_state = run.get("state", {}).get("result_state", "") + state = f"{life_cycle_state}" if not result_state else f"{life_cycle_state} ({result_state})" + + # Process timestamps + start_time = run.get("start_time", 0) + end_time = run.get("end_time", 0) + + start_time_str = datetime.datetime.fromtimestamp(start_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if start_time else "N/A" + end_time_str = datetime.datetime.fromtimestamp(end_time / 1000).strftime('%Y-%m-%d %H:%M:%S') if end_time else "N/A" + + # Calculate duration + if start_time and end_time: + duration = f"{(end_time - start_time) / 1000:.2f}s" + else: + duration = "N/A" + + table += f"| {run_id} | {run_job_id} | {state} | {run_creator} | {start_time_str} | {end_time_str} | {duration} |\n" + + # Add pagination information + result = table + + next_page = response.get("next_page_token") + prev_page = response.get("prev_page_token") + + if next_page or prev_page: + result += "\n### Pagination\n" + if next_page: + result += f"More results available. Use page_token='{next_page}' to view the next page.\n" + if prev_page: + result += f"Use page_token='{prev_page}' to view the previous page.\n" + + return result + + except Exception as e: + return f"Error listing job runs: {str(e)}" + +# Clusters API tools +@mcp.tool() +def list_clusters() -> str: + """List all Databricks clusters""" + try: + response = databricks_api_request("clusters/list") + + if not response.get("clusters"): + return "No clusters found." 
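
Because `jobs/runs/list` is paginated, a caller that wants every run can keep feeding `next_page_token` back into the request. A rough sketch against the raw endpoint, assuming an API version that returns `next_page_token` (the test file later in this diff exercises 2.2); host, token and job ID are placeholders:

```python
import os
import requests

host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]
headers = {"Authorization": f"Bearer {token}"}

all_runs = []
params = {"job_id": 123, "limit": 25}  # hypothetical job ID
while True:
    resp = requests.get(
        f"https://{host}/api/2.2/jobs/runs/list",
        headers=headers, params=params, timeout=30,
    )
    resp.raise_for_status()
    body = resp.json()
    all_runs.extend(body.get("runs", []))
    next_token = body.get("next_page_token")
    if not next_token:
        break
    params["page_token"] = next_token  # request the next page

print(f"Collected {len(all_runs)} runs")
```
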
+ + clusters = response.get("clusters", []) + + # Format as markdown table + table = "| Cluster ID | Cluster Name | State | Autoscale | Workers |\n" + table += "| ---------- | ------------ | ----- | --------- | ------- |\n" + + for cluster in clusters: + cluster_id = cluster.get("cluster_id", "N/A") + cluster_name = cluster.get("cluster_name", "N/A") + state = cluster.get("state", "N/A") + + autoscale = "Yes" if cluster.get("autoscale", {}).get("min_workers") else "No" + workers = cluster.get("num_workers", "N/A") + + table += f"| {cluster_id} | {cluster_name} | {state} | {autoscale} | {workers} |\n" + + return table + except Exception as e: + return f"Error listing clusters: {str(e)}" + +@mcp.tool() +def get_cluster_details(cluster_id: str) -> str: + """Get detailed information about a specific Databricks cluster""" + try: + response = databricks_api_request(f"clusters/get?cluster_id={cluster_id}") + + result = f"## Cluster Details\n\n" + result += f"- **Cluster ID:** {cluster_id}\n" + result += f"- **Cluster Name:** {response.get('cluster_name', 'N/A')}\n" + result += f"- **State:** {response.get('state', 'N/A')}\n" + result += f"- **Spark Version:** {response.get('spark_version', 'N/A')}\n" + + # Add node type info + node_type_id = response.get('node_type_id', 'N/A') + result += f"- **Node Type ID:** {node_type_id}\n" + + # Add autoscaling info if applicable + if 'autoscale' in response: + min_workers = response['autoscale'].get('min_workers', 'N/A') + max_workers = response['autoscale'].get('max_workers', 'N/A') + result += f"- **Autoscale:** {min_workers} to {max_workers} workers\n" + else: + result += f"- **Num Workers:** {response.get('num_workers', 'N/A')}\n" + + return result + except Exception as e: + return f"Error getting cluster details: {str(e)}" + +# Instance Pools API tools +@mcp.tool() +def list_instance_pools() -> str: + """List all instance pools in the workspace""" + try: + response = databricks_api_request("instance-pools/list") + + if not response.get("instance_pools"): + return "No instance pools found." 
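
Since `clusters/get` reports the lifecycle state, a caller that needs a warm cluster can poll until it reaches `RUNNING`. A sketch under that assumption, with the cluster ID as a placeholder and a simple deadline:

```python
import os
import time
import requests

host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]
headers = {"Authorization": f"Bearer {token}"}

cluster_id = "0101-120000-abcd123"  # placeholder
deadline = time.time() + 15 * 60    # give up after 15 minutes

while time.time() < deadline:
    resp = requests.get(
        f"https://{host}/api/2.0/clusters/get",
        headers=headers, params={"cluster_id": cluster_id}, timeout=30,
    )
    resp.raise_for_status()
    state = resp.json().get("state")
    if state == "RUNNING":
        print("Cluster is ready")
        break
    if state in ("TERMINATED", "ERROR"):
        raise RuntimeError(f"Cluster ended up in state {state}")
    time.sleep(30)  # poll every 30 seconds
else:
    raise TimeoutError("Cluster did not reach RUNNING before the deadline")
```
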
+ + pools = response.get("instance_pools", []) + + # Format as markdown table + table = "| Pool ID | Pool Name | Instance Type | Min Idle | Max Capacity | State |\n" + table += "| ------- | --------- | ------------- | -------- | ------------ | ----- |\n" + + for pool in pools: + pool_id = pool.get("instance_pool_id", "N/A") + pool_name = pool.get("instance_pool_name", "N/A") + instance_type = pool.get("node_type_id", "N/A") + min_idle = pool.get("min_idle_instances", 0) + max_capacity = pool.get("max_capacity", 0) + state = pool.get("state", "N/A") + + table += f"| {pool_id} | {pool_name} | {instance_type} | {min_idle} | {max_capacity} | {state} |\n" + + return table + except Exception as e: + return f"Error listing instance pools: {str(e)}" + +@mcp.tool() +def create_instance_pool( + name: str, + node_type_id: str, + min_idle_instances: int = 0, + max_capacity: int = 10, + idle_instance_autotermination_minutes: int = 60, + preload_spark_versions: List[str] = None +) -> str: + """Create a new instance pool + + Args: + name: The name of the instance pool + node_type_id: The node type for the instances in the pool (e.g., "Standard_DS3_v2") + min_idle_instances: Minimum number of idle instances to keep in the pool + max_capacity: Maximum number of instances the pool can contain + idle_instance_autotermination_minutes: Number of minutes that idle instances are maintained + preload_spark_versions: Spark versions to preload on the instances (optional) + """ + try: + # Build the request data + data = { + "instance_pool_name": name, + "node_type_id": node_type_id, + "min_idle_instances": min_idle_instances, + "max_capacity": max_capacity, + "idle_instance_autotermination_minutes": idle_instance_autotermination_minutes + } + + # Add preloaded Spark versions if specified + if preload_spark_versions: + data["preloaded_spark_versions"] = preload_spark_versions + + # Make the API request + response = databricks_api_request("instance-pools/create", method="POST", data=data) + + # The response contains the instance_pool_id of the newly created pool + pool_id = response.get("instance_pool_id") + + if pool_id: + return f"Instance pool created successfully with ID: {pool_id}" + else: + return "Instance pool creation failed. No pool ID returned." + except Exception as e: + return f"Error creating instance pool: {str(e)}" + +# DBFS API tools +@mcp.tool() +def list_dbfs_files(path: str = "/") -> str: + """List files in a DBFS directory""" + try: + if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]): + raise ValueError("Missing required Databricks API credentials in .env file") + + # Direct API call to match the exact curl command that works + import subprocess + import json + + # Execute the curl command that we know works + cmd = [ + "curl", "-s", + "-H", f"Authorization: Bearer {DATABRICKS_TOKEN}", + "-H", "Content-Type: application/json", + f"https://{DATABRICKS_HOST}/api/2.0/dbfs/list?path={path}" + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise ValueError(f"Command failed with exit code {result.returncode}: {result.stderr}") + + try: + data = json.loads(result.stdout) + except json.JSONDecodeError: + raise ValueError(f"Failed to parse response as JSON: {result.stdout}") + + if not data.get("files"): + return f"No files found in {path}." 
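
The `list_dbfs_files` tool shells out to `curl`; the same request can also be expressed with `requests`, letting the library handle URL encoding of the path. A minimal sketch (the listed path is illustrative):

```python
import os
import requests

host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]

resp = requests.get(
    f"https://{host}/api/2.0/dbfs/list",
    headers={"Authorization": f"Bearer {token}"},
    params={"path": "/databricks-datasets"},  # requests URL-encodes the value
    timeout=30,
)
resp.raise_for_status()
for entry in resp.json().get("files", []):
    marker = "<dir>" if entry.get("is_dir") else entry.get("file_size", 0)
    print(entry.get("path"), marker)
```
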
+ + files = data.get("files", []) + + # Format as markdown table + table = "| File Path | Size | Is Directory |\n" + table += "| --------- | ---- | ------------ |\n" + + for file in files: + file_path = file.get("path", "N/A") + size = file.get("file_size", 0) + is_dir = "Yes" if file.get("is_dir", False) else "No" + + table += f"| {file_path} | {size} | {is_dir} |\n" + + return table + except Exception as e: + print(f"Error details: {str(e)}") + return f"Error listing DBFS files: {str(e)}" + +@mcp.tool() +def read_dbfs_file(path: str, length: int = 1000) -> str: + """Read the contents of a file from DBFS (limited to 1MB)""" + try: + if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]): + raise ValueError("Missing required Databricks API credentials in .env file") + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Build URL with query parameters directly + encoded_path = urllib.parse.quote(path) + url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/read?path={encoded_path}&length={length}" + + print(f"Requesting URL: {url}") + response = requests.get(url, headers=headers) + print(f"Status code: {response.status_code}") + print(f"Response: {response.text[:100]}...") # Print just the beginning to avoid flooding logs + + response.raise_for_status() + data = response.json() + + if "data" not in data: + return f"No data found in file {path}." + + import base64 + file_content = base64.b64decode(data["data"]).decode("utf-8") + + return f"## File Contents: {path}\n\n```\n{file_content}\n```" + except Exception as e: + print(f"Error details: {str(e)}") + return f"Error reading DBFS file: {str(e)}" + +@mcp.tool() +def upload_to_dbfs(file_content: str, dbfs_path: str, overwrite: bool = False) -> str: + """Upload content to a file in DBFS + + Args: + file_content: The text content to upload to DBFS + dbfs_path: The destination path in DBFS + overwrite: Whether to overwrite an existing file + """ + try: + if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]): + raise ValueError("Missing required Databricks API credentials in .env file") + + import base64 + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # DBFS file upload requires three steps: + # 1. Create a handle + create_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/create" + create_data = { + "path": dbfs_path, + "overwrite": overwrite + } + + create_response = requests.post(create_url, headers=headers, json=create_data) + create_response.raise_for_status() + + handle = create_response.json().get("handle") + if not handle: + return f"Error: Failed to create a handle for the file upload." + + # 2. Add blocks of data + # Convert string content to bytes and then encode as base64 + content_bytes = file_content.encode('utf-8') + encoded_content = base64.b64encode(content_bytes).decode('utf-8') + + add_block_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/add-block" + add_block_data = { + "handle": handle, + "data": encoded_content + } + + add_block_response = requests.post(add_block_url, headers=headers, json=add_block_data) + add_block_response.raise_for_status() + + # 3. 
Close the handle + close_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/close" + close_data = { + "handle": handle + } + + close_response = requests.post(close_url, headers=headers, json=close_data) + close_response.raise_for_status() + + return f"Successfully uploaded content to {dbfs_path}.\n\nYou can view the file with: read_dbfs_file('{dbfs_path}')" + except Exception as e: + print(f"Error details: {str(e)}") + return f"Error uploading to DBFS: {str(e)}" + +# Workspace API tools +@mcp.tool() +def list_workspace(path: str = "/") -> str: + """List notebooks and directories in a workspace""" + try: + if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]): + raise ValueError("Missing required Databricks API credentials in .env file") + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Build URL with query parameters directly + encoded_path = urllib.parse.quote(path) + url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/list?path={encoded_path}" + + print(f"Requesting URL: {url}") + response = requests.get(url, headers=headers) + print(f"Status code: {response.status_code}") + print(f"Response: {response.text}") + + response.raise_for_status() + data = response.json() + + if not data.get("objects"): + return f"No objects found in {path}." + + objects = data.get("objects", []) + + # Format as markdown table + table = "| Path | Object Type | Language |\n" + table += "| ---- | ----------- | -------- |\n" + + for obj in objects: + obj_path = obj.get("path", "N/A") + obj_type = obj.get("object_type", "N/A") + language = obj.get("language", "N/A") if obj_type == "NOTEBOOK" else "N/A" + + table += f"| {obj_path} | {obj_type} | {language} |\n" + + return table + except Exception as e: + print(f"Error details: {str(e)}") + return f"Error listing workspace: {str(e)}" + +@mcp.tool() +def export_notebook(path: str, format: str = "SOURCE") -> str: + """Export a notebook from the workspace + + Args: + path: Path to the notebook in the workspace + format: Export format (SOURCE, HTML, JUPYTER, DBC, R_MARKDOWN) + """ + try: + if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]): + raise ValueError("Missing required Databricks API credentials in .env file") + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Build request + export_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/export" + export_data = { + "path": path, + "format": format + } + + response = requests.get(export_url, headers=headers, params=export_data) + response.raise_for_status() + data = response.json() + + if "content" not in data: + return f"No content found for notebook at {path}." 
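
`upload_to_dbfs` sends everything as a single block; for larger payloads the same create / add-block / close handshake can simply be repeated per chunk, since the streaming API accepts multiple base64-encoded blocks per handle (the commonly documented per-block limit is 1 MB, assumed here). A sketch with placeholder file paths:

```python
import base64
import os
import requests

host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]
headers = {"Authorization": f"Bearer {token}"}
api = f"https://{host}/api/2.0/dbfs"

def upload_large(local_path: str, dbfs_path: str,
                 chunk_size: int = 512 * 1024) -> None:
    """Stream a local file to DBFS one base64-encoded block at a time.

    chunk_size stays comfortably under the assumed 1 MB block limit.
    """
    resp = requests.post(f"{api}/create", headers=headers,
                         json={"path": dbfs_path, "overwrite": True}, timeout=30)
    resp.raise_for_status()
    handle = resp.json()["handle"]

    with open(local_path, "rb") as f:
        while chunk := f.read(chunk_size):
            requests.post(
                f"{api}/add-block", headers=headers,
                json={"handle": handle,
                      "data": base64.b64encode(chunk).decode("ascii")},
                timeout=60,
            ).raise_for_status()

    requests.post(f"{api}/close", headers=headers,
                  json={"handle": handle}, timeout=30).raise_for_status()

upload_large("report.csv", "/FileStore/report.csv")  # hypothetical paths
```
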
+ + import base64 + notebook_content = base64.b64decode(data["content"]).decode("utf-8") + + # Determine file extension based on format + extension = "" + if format == "SOURCE": + # Try to determine language + language_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/get-status" + language_response = requests.get(language_url, headers=headers, params={"path": path}) + + if language_response.status_code == 200: + language_data = language_response.json() + language = language_data.get("language", "").lower() + + if language == "python": + extension = ".py" + elif language == "scala": + extension = ".scala" + elif language == "r": + extension = ".r" + elif language == "sql": + extension = ".sql" + + elif format == "HTML": + extension = ".html" + elif format == "JUPYTER": + extension = ".ipynb" + elif format == "DBC": + extension = ".dbc" + elif format == "R_MARKDOWN": + extension = ".Rmd" + + # Return the notebook content + filename = path.split("/")[-1] + extension + return f"## Exported Notebook: {filename}\n\n```\n{notebook_content}\n```" + except Exception as e: + print(f"Error details: {str(e)}") + return f"Error exporting notebook: {str(e)}" + +@mcp.tool() +def import_notebook(content: str, path: str, language: str = "PYTHON", format: str = "SOURCE", overwrite: bool = False) -> str: + """Import a notebook into the workspace + + Args: + content: The notebook content + path: Destination path in the workspace + language: Notebook language (PYTHON, SCALA, SQL, R) + format: Import format (SOURCE, HTML, JUPYTER, DBC, R_MARKDOWN) + overwrite: Whether to overwrite existing notebook + """ + try: + if not all([DATABRICKS_HOST, DATABRICKS_TOKEN]): + raise ValueError("Missing required Databricks API credentials in .env file") + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Encode the content as base64 + import base64 + content_bytes = content.encode("utf-8") + encoded_content = base64.b64encode(content_bytes).decode("utf-8") + + # Build request + import_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/import" + import_data = { + "path": path, + "content": encoded_content, + "language": language, + "format": format, + "overwrite": overwrite + } + + response = requests.post(import_url, headers=headers, json=import_data) + response.raise_for_status() + + return f"Notebook successfully imported to {path}" + except Exception as e: + print(f"Error details: {str(e)}") + return f"Error importing notebook: {str(e)}" + +# Unity Catalog API tools +@mcp.tool() +def list_catalogs() -> str: + """List all catalogs in Unity Catalog""" + try: + response = databricks_api_request("unity-catalog/catalogs") + + if not response.get("catalogs"): + return "No catalogs found." 
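
Taken together, the two workspace tools give a simple import/export round trip. An illustrative call sequence, assuming the FastMCP tool decorator leaves the underlying functions directly callable and that `.env` is configured; the workspace path and notebook content are invented for the example:

```python
# Illustrative usage of the tool functions defined in main.py (run from the repo root).
from main import import_notebook, export_notebook

source = "# Databricks notebook source\nprint('hello from MCP')\n"

print(import_notebook(
    content=source,
    path="/Shared/mcp_demo",   # hypothetical workspace path
    language="PYTHON",
    format="SOURCE",
    overwrite=True,
))

# Export it back out; with format SOURCE the tool returns the raw script
# wrapped in a Markdown code fence.
print(export_notebook(path="/Shared/mcp_demo", format="SOURCE"))
```
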
+ + catalogs = response.get("catalogs", []) + + # Format as markdown table + table = "| Catalog Name | Comment | Provider | Properties |\n" + table += "| ------------ | ------- | -------- | ---------- |\n" + + for catalog in catalogs: + name = catalog.get("name", "N/A") + comment = catalog.get("comment", "") + provider = catalog.get("provider", {}).get("name", "N/A") + properties = ", ".join([f"{k}: {v}" for k, v in catalog.get("properties", {}).items()]) if catalog.get("properties") else "None" + + table += f"| {name} | {comment} | {provider} | {properties} |\n" + + return table + except Exception as e: + return f"Error listing catalogs: {str(e)}" + +@mcp.tool() +def create_catalog(name: str, comment: str = None, properties: Dict = None) -> str: + """Create a new catalog in Unity Catalog + + Args: + name: Name of the catalog + comment: Description of the catalog (optional) + properties: Key-value properties for the catalog (optional) + """ + try: + # Prepare the request data + data = { + "name": name + } + + if comment: + data["comment"] = comment + + if properties: + data["properties"] = properties + + # Make the API request + response = databricks_api_request("unity-catalog/catalogs", method="POST", data=data) + + # Check if the catalog was created successfully + if response.get("name") == name: + return f"Catalog '{name}' was created successfully." + else: + return f"Catalog creation might have failed. Please check with list_catalogs()." + except Exception as e: + return f"Error creating catalog: {str(e)}" + +@mcp.tool() +def create_schema(catalog_name: str, schema_name: str, comment: str = None, properties: Dict = None) -> str: + """Create a new schema in Unity Catalog + + Args: + catalog_name: Name of the parent catalog + schema_name: Name of the schema + comment: Description of the schema (optional) + properties: Key-value properties for the schema (optional) + """ + try: + # Prepare the request data + data = { + "name": schema_name, + "catalog_name": catalog_name + } + + if comment: + data["comment"] = comment + + if properties: + data["properties"] = properties + + # Make the API request + response = databricks_api_request("unity-catalog/schemas", method="POST", data=data) + + # Check if the schema was created successfully + if response.get("name") == schema_name and response.get("catalog_name") == catalog_name: + return f"Schema '{catalog_name}.{schema_name}' was created successfully." + else: + return f"Schema creation might have failed. Please check using another Unity Catalog API call." 
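
An illustrative end-to-end sequence for the Unity Catalog tools: create a catalog, add a schema, then confirm it is queryable. All names are invented for the example, and the usual caveat applies that the workspace must have Unity Catalog enabled:

```python
# Illustrative usage of the tool functions defined in main.py (run from the repo root).
from main import create_catalog, create_schema, run_sql_query

print(create_catalog("analytics_sandbox", comment="Scratch area for ad-hoc analysis"))
print(create_schema("analytics_sandbox", "bronze", comment="Raw ingested data"))

# Verify the new schema is visible from the SQL warehouse.
print(run_sql_query("SHOW SCHEMAS IN analytics_sandbox"))
```
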
+ except Exception as e: + return f"Error creating schema: {str(e)}" + # if __name__ == "__main__": # import uvicorn # uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) if __name__ == "__main__": - #run_sql_query("SELECT * FROM dev.dev_test.income_survey_dataset LIMIT 10;") mcp.run() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6d3033d..e3c9001 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ -fastapi>=0.95.0 -uvicorn>=0.22.0 -databricks-sql-connector>=2.4.0 -python-dotenv>=1.0.0 -pydantic>=2.0.0 +fastapi>=0.68.0 +uvicorn>=0.15.0 +python-dotenv>=0.19.0 +requests>=2.26.0 +databricks-sql-connector>=2.0.0 mcp>=0.1.0 -pyarrow>=14.0.1 -requests>=2.31.0 \ No newline at end of file +python-multipart>=0.0.5 +pydantic>=1.8.0 +packaging>=23.0 \ No newline at end of file diff --git a/test_connection.py b/test_connection.py index 5a971ab..347da2d 100644 --- a/test_connection.py +++ b/test_connection.py @@ -8,6 +8,8 @@ import os from dotenv import load_dotenv import sys +import datetime +import base64 # Load environment variables load_dotenv() @@ -62,6 +64,512 @@ def test_databricks_api(): print(f"❌ Error connecting to Databricks API: {str(e)}") return False +def test_clusters_api(): + """Test Clusters API functionality""" + import requests + + print("\nTesting Clusters API...") + + try: + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Test clusters/list endpoint + url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/list" + response = requests.get(url, headers=headers) + + if response.status_code == 200: + data = response.json() + if "clusters" in data: + print("✅ Successfully accessed Clusters API (list)") + + # Also test get_cluster_details if clusters exist + if data.get("clusters"): + cluster_id = data["clusters"][0]["cluster_id"] + detail_url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/get?cluster_id={cluster_id}" + detail_response = requests.get(detail_url, headers=headers) + + if detail_response.status_code == 200: + print("✅ Successfully accessed Clusters API (details)") + else: + print(f"❌ Failed to access Clusters API (details): {detail_response.status_code}") + + return True + else: + print("❌ Response missing expected 'clusters' field") + return False + else: + print(f"❌ Failed to access Clusters API: {response.status_code} - {response.text}") + return False + except Exception as e: + print(f"❌ Error testing Clusters API: {str(e)}") + return False + +def test_instance_pools_api(): + """Test Instance Pools API functionality""" + import requests + + print("\nTesting Instance Pools API...") + + try: + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Test instance-pools/list endpoint + url = f"https://{DATABRICKS_HOST}/api/2.0/instance-pools/list" + response = requests.get(url, headers=headers) + + if response.status_code == 200: + data = response.json() + if "instance_pools" in data: + print("✅ Successfully accessed Instance Pools API (list)") + + # Get available node types to use for creating a pool + node_types_url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/list-node-types" + node_types_response = requests.get(node_types_url, headers=headers) + + if node_types_response.status_code == 200 and "node_types" in node_types_response.json(): + node_types = node_types_response.json()["node_types"] + if node_types: + # Get the first node type ID + node_type_id = node_types[0]["node_type_id"] + + # Test 
create instance pool + pool_name = f"Test Pool {datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + create_data = { + "instance_pool_name": pool_name, + "node_type_id": node_type_id, + "min_idle_instances": 0, + "max_capacity": 2, + "idle_instance_autotermination_minutes": 10 + } + + create_url = f"https://{DATABRICKS_HOST}/api/2.0/instance-pools/create" + create_response = requests.post(create_url, headers=headers, json=create_data) + + if create_response.status_code == 200 and "instance_pool_id" in create_response.json(): + print("✅ Successfully accessed Instance Pools API (create)") + + # Cleanup: Delete the test pool we just created + pool_id = create_response.json()["instance_pool_id"] + delete_url = f"https://{DATABRICKS_HOST}/api/2.0/instance-pools/delete" + delete_data = {"instance_pool_id": pool_id} + delete_response = requests.post(delete_url, headers=headers, json=delete_data) + + if delete_response.status_code == 200: + print("✅ Successfully cleaned up test instance pool") + else: + print(f"⚠️ Could not delete test instance pool: {delete_response.status_code}") + else: + print(f"❌ Failed to access Instance Pools API (create): {create_response.status_code} - {create_response.text}") + else: + print("⚠️ No node types available for testing instance pool creation") + else: + print(f"⚠️ Could not get node types for testing: {node_types_response.status_code}") + + return True + else: + print("❌ Response missing expected 'instance_pools' field") + return False + else: + print(f"❌ Failed to access Instance Pools API: {response.status_code} - {response.text}") + return False + except Exception as e: + print(f"❌ Error testing Instance Pools API: {str(e)}") + return False + +def test_unity_catalog_api(): + """Test Unity Catalog API functionality""" + import requests + + print("\nTesting Unity Catalog API...") + + try: + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Test catalogs list endpoint + url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/catalogs" + response = requests.get(url, headers=headers) + + # Check if workspace supports Unity Catalog + if response.status_code == 404 or (response.status_code == 400 and "Unity Catalog is not enabled" in response.text): + print("⚠️ Unity Catalog is not enabled in this workspace - skipping tests") + return True # Skip test but don't mark as failure + + if response.status_code == 200: + # Note: This may return no catalogs, but it should have a "catalogs" key + data = response.json() + if "catalogs" in data: + print("✅ Successfully accessed Unity Catalog API (list catalogs)") + + # Test catalog creation (only if user has permission) + catalog_name = f"test_catalog_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + create_data = { + "name": catalog_name, + "comment": "Test catalog for API testing - will be deleted" + } + + create_url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/catalogs" + create_response = requests.post(create_url, headers=headers, json=create_data) + + if create_response.status_code == 200: + print("✅ Successfully accessed Unity Catalog API (create catalog)") + + # Test schema creation in this catalog + schema_name = f"test_schema_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + schema_data = { + "name": schema_name, + "catalog_name": catalog_name, + "comment": "Test schema for API testing - will be deleted" + } + + schema_url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/schemas" + schema_response = requests.post(schema_url, 
headers=headers, json=schema_data) + + if schema_response.status_code == 200: + print("✅ Successfully accessed Unity Catalog API (create schema)") + else: + # Not all users can create schemas, so this is not a critical failure + print(f"⚠️ Could not create schema (permissions?): {schema_response.status_code}") + + # Cleanup: Delete the catalog (automatically deletes schemas too) + delete_url = f"https://{DATABRICKS_HOST}/api/2.0/unity-catalog/catalogs/{catalog_name}" + delete_response = requests.delete(delete_url, headers=headers) + + if delete_response.status_code in [200, 204]: + print("✅ Successfully cleaned up test catalog") + else: + print(f"⚠️ Could not delete test catalog: {delete_response.status_code}") + elif create_response.status_code in [403, 401]: + # User likely doesn't have permission to create catalogs + print("⚠️ Insufficient permissions to create catalogs - skipping create test") + else: + print(f"❌ Failed to access Unity Catalog API (create): {create_response.status_code} - {create_response.text}") + + return True + else: + print("❌ Response missing expected 'catalogs' field") + return False + else: + print(f"❌ Failed to access Unity Catalog API: {response.status_code} - {response.text}") + return False + except Exception as e: + print(f"❌ Error testing Unity Catalog API: {str(e)}") + return False + +def test_jobs_api(): + """Test Jobs API functionality""" + import requests + + print("\nTesting Jobs API...") + + try: + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # Test jobs/list endpoint + url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/list" + response = requests.get(url, headers=headers) + + if response.status_code == 200: + data = response.json() + if "jobs" in data: + print("✅ Successfully accessed Jobs API (list)") + + # Also test job details if jobs exist + if data.get("jobs"): + job_id = data["jobs"][0]["job_id"] + + # Test get_job_details + detail_url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/get?job_id={job_id}" + detail_response = requests.get(detail_url, headers=headers) + + if detail_response.status_code == 200: + print("✅ Successfully accessed Jobs API (details)") + else: + print(f"❌ Failed to access Jobs API (details): {detail_response.status_code}") + + # Test get_job_status + status_url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/runs/list" + status_response = requests.get(status_url, headers=headers, params={"job_id": job_id}) + + if status_response.status_code == 200 and "runs" in status_response.json(): + print("✅ Successfully accessed Jobs API (status)") + else: + print(f"❌ Failed to access Jobs API (status): {status_response.status_code}") + + # Test list_job_runs + runs_url = f"https://{DATABRICKS_HOST}/api/2.2/jobs/runs/list" + runs_response = requests.get(runs_url, headers=headers) + + if runs_response.status_code == 200 and "runs" in runs_response.json(): + print("✅ Successfully accessed Jobs API (runs list)") + else: + print(f"❌ Failed to access Jobs API (runs list): {runs_response.status_code}") + + # Test job creation API + # Find an available cluster to use + cluster_url = f"https://{DATABRICKS_HOST}/api/2.0/clusters/list" + cluster_response = requests.get(cluster_url, headers=headers) + + if cluster_response.status_code == 200 and cluster_response.json().get("clusters"): + # Find the first RUNNING cluster to use + clusters = cluster_response.json()["clusters"] + running_clusters = [c for c in clusters if c.get("state") == "RUNNING"] + + if running_clusters: + test_cluster_id = 
running_clusters[0]["cluster_id"] + + # Create a test job + create_url = f"https://{DATABRICKS_HOST}/api/2.0/jobs/create" + job_name = f"Test Job {datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + + job_data = { + "name": job_name, + "settings": { + "name": job_name, + "existing_cluster_id": test_cluster_id, + "notebook_task": { + "notebook_path": "/Shared/test-notebook" if os.path.exists("/Shared/test-notebook") else "/Users/test-notebook" + }, + "max_retries": 0, + "timeout_seconds": 300 + } + } + + create_response = requests.post(create_url, headers=headers, json=job_data) + + if create_response.status_code == 200 and "job_id" in create_response.json(): + print("✅ Successfully accessed Jobs API (create)") + else: + print(f"❌ Failed to access Jobs API (create): {create_response.status_code} - {create_response.text}") + else: + print("⚠️ Skipping job creation test: No running clusters found") + else: + print("⚠️ Skipping job creation test: Couldn't list clusters") + + return True + else: + print("❌ Response missing expected 'jobs' field") + return False + else: + print(f"❌ Failed to access Jobs API: {response.status_code} - {response.text}") + return False + except Exception as e: + print(f"❌ Error testing Jobs API: {str(e)}") + return False + +def test_dbfs_api(): + """Test DBFS API functionality""" + import requests + import json + import subprocess + + print("\nTesting DBFS API...") + + try: + # Method 1: Direct API call with query parameters (matches our implementation) + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # The DBFS list endpoint requires a path parameter + url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/list?path=/" + response = requests.get(url, headers=headers) + + if response.status_code == 200: + data = response.json() + if "files" in data: + print("✅ Successfully accessed DBFS API (list)") + + # Also test read_dbfs_file if README.md exists + if any(file.get("path") == "/databricks-datasets/README.md" for file in data.get("files", [])): + read_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/read?path=/databricks-datasets/README.md&length=1000" + read_response = requests.get(read_url, headers=headers) + + if read_response.status_code == 200 and "data" in read_response.json(): + print("✅ Successfully accessed DBFS API (read)") + else: + print(f"❌ Failed to access DBFS API (read): {read_response.status_code}") + + # Test file upload to DBFS + test_content = "This is a test file uploaded via the DBFS API." 
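
The DBFS upload test that follows writes `/FileStore/test-upload.txt` but never removes it. If the workspace should stay clean, a deletion step along these lines (assuming the standard `dbfs/delete` endpoint) can be folded in after the verification succeeds:

```python
import os
import requests

host = os.environ["DATABRICKS_HOST"]
token = os.environ["DATABRICKS_TOKEN"]

# Remove the file created by the upload test; path matches the test above.
resp = requests.post(
    f"https://{host}/api/2.0/dbfs/delete",
    headers={"Authorization": f"Bearer {token}"},
    json={"path": "/FileStore/test-upload.txt", "recursive": False},
    timeout=30,
)
print("deleted" if resp.status_code == 200 else f"delete failed: {resp.status_code}")
```
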
+ test_path = "/FileStore/test-upload.txt" + + # Step 1: Create handle + create_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/create" + create_data = { + "path": test_path, + "overwrite": True + } + + create_response = requests.post(create_url, headers=headers, json=create_data) + + if create_response.status_code == 200 and "handle" in create_response.json(): + handle = create_response.json()["handle"] + + # Step 2: Add data block + add_block_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/add-block" + encoded_content = base64.b64encode(test_content.encode('utf-8')).decode('utf-8') + + add_block_data = { + "handle": handle, + "data": encoded_content + } + + add_block_response = requests.post(add_block_url, headers=headers, json=add_block_data) + + # Step 3: Close handle + if add_block_response.status_code == 200: + close_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/close" + close_data = { + "handle": handle + } + + close_response = requests.post(close_url, headers=headers, json=close_data) + + if close_response.status_code == 200: + print("✅ Successfully accessed DBFS API (upload)") + + # Verify the file was uploaded by reading it back + verify_url = f"https://{DATABRICKS_HOST}/api/2.0/dbfs/read?path={test_path}&length=1000" + verify_response = requests.get(verify_url, headers=headers) + + if (verify_response.status_code == 200 and + "data" in verify_response.json() and + test_content == base64.b64decode(verify_response.json()["data"]).decode('utf-8')): + print("✅ Successfully verified uploaded content") + else: + print(f"❌ Failed to verify uploaded content: {verify_response.status_code}") + else: + print(f"❌ Failed to close handle: {close_response.status_code}") + else: + print(f"❌ Failed to add data block: {add_block_response.status_code}") + else: + print(f"❌ Failed to create upload handle: {create_response.status_code}") + + return True + else: + print("❌ Response missing expected 'files' field") + return False + else: + # Method 2: Try with curl as a fallback (this is our current implementation) + print(f"Direct requests failed, trying curl: {response.status_code}") + + cmd = [ + "curl", "-s", + "-H", f"Authorization: Bearer {DATABRICKS_TOKEN}", + "-H", "Content-Type: application/json", + f"https://{DATABRICKS_HOST}/api/2.0/dbfs/list?path=/" + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + try: + data = json.loads(result.stdout) + if "files" in data: + print("✅ Successfully accessed DBFS API using curl") + return True + else: + print("❌ Response missing expected 'files' field (curl)") + return False + except json.JSONDecodeError: + print(f"❌ Failed to parse curl response as JSON") + return False + else: + print(f"❌ Curl command failed: {result.stderr}") + return False + except Exception as e: + print(f"❌ Error testing DBFS API: {str(e)}") + return False + +def test_workspace_api(): + """Test Workspace API functionality""" + import requests + import base64 + + print("\nTesting Workspace API...") + + try: + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + # The workspace list endpoint requires a path parameter + url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/list" + data = {"path": "/"} + response = requests.get(url, headers=headers, params=data) + + if response.status_code == 200: + data = response.json() + if "objects" in data: + print("✅ Successfully accessed Workspace API (list)") + + # Test notebook import + test_notebook_path = "/Users/test-import-notebook" + 
test_notebook_content = "# Test Notebook\n\nprint('Hello, Databricks!')" + + # Encode content + encoded_content = base64.b64encode(test_notebook_content.encode('utf-8')).decode('utf-8') + + import_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/import" + import_data = { + "path": test_notebook_path, + "content": encoded_content, + "language": "PYTHON", + "format": "SOURCE", + "overwrite": True + } + + import_response = requests.post(import_url, headers=headers, json=import_data) + + if import_response.status_code == 200: + print("✅ Successfully accessed Workspace API (import)") + + # Test notebook export + export_url = f"https://{DATABRICKS_HOST}/api/2.0/workspace/export" + export_data = { + "path": test_notebook_path, + "format": "SOURCE" + } + + export_response = requests.get(export_url, headers=headers, params=export_data) + + if export_response.status_code == 200 and "content" in export_response.json(): + exported_content = base64.b64decode(export_response.json()["content"]).decode('utf-8') + + if test_notebook_content in exported_content: + print("✅ Successfully accessed Workspace API (export)") + else: + print("❌ Exported content does not match imported content") + else: + print(f"❌ Failed to access Workspace API (export): {export_response.status_code}") + else: + print(f"❌ Failed to access Workspace API (import): {import_response.status_code}") + + return True + else: + print("❌ Response missing expected 'objects' field") + return False + else: + print(f"❌ Failed to access Workspace API: {response.status_code} - {response.text}") + return False + except Exception as e: + print(f"❌ Error testing Workspace API: {str(e)}") + return False + def test_sql_connection(): """Test connection to Databricks SQL warehouse""" print("\nTesting Databricks SQL warehouse connection...") @@ -112,6 +620,12 @@ def test_sql_connection(): api_ok = test_databricks_api() sql_ok = test_sql_connection() + clusters_ok = test_clusters_api() + jobs_ok = test_jobs_api() + dbfs_ok = test_dbfs_api() + workspace_ok = test_workspace_api() + instance_pools_ok = test_instance_pools_api() + unity_catalog_ok = test_unity_catalog_api() # Summary print("\nTest Summary") @@ -119,8 +633,14 @@ def test_sql_connection(): print(f"Environment Variables: {'✅ OK' if env_ok else '❌ Failed'}") print(f"Databricks API: {'✅ OK' if api_ok else '❌ Failed'}") print(f"Databricks SQL: {'✅ OK' if sql_ok else '❌ Failed'}") + print(f"Clusters API: {'✅ OK' if clusters_ok else '❌ Failed'}") + print(f"Jobs API: {'✅ OK' if jobs_ok else '❌ Failed'}") + print(f"DBFS API: {'✅ OK' if dbfs_ok else '❌ Failed'}") + print(f"Workspace API: {'✅ OK' if workspace_ok else '❌ Failed'}") + print(f"Instance Pools API: {'✅ OK' if instance_pools_ok else '❌ Failed'}") + print(f"Unity Catalog API: {'✅ OK' if unity_catalog_ok else '❌ Failed'}") - if env_ok and api_ok and sql_ok: + if env_ok and api_ok and sql_ok and clusters_ok and jobs_ok and dbfs_ok and workspace_ok and instance_pools_ok and unity_catalog_ok: print("\n✅ All tests passed! Your Databricks MCP server should work correctly.") sys.exit(0) else: