From 019f40949ea66fc4d1ca1969aaa57c8943e81036 Mon Sep 17 00:00:00 2001 From: Alex Andru Date: Mon, 9 Dec 2024 23:47:46 +0100 Subject: [PATCH] feat: add prompt capabilities --- src/core/MCPServer.ts | 70 ++++++++++++++++++--- src/core/promptLoader.ts | 128 ++++++++++++++++++++++++++++++++++++++ src/index.ts | 7 +++ src/prompts/BasePrompt.ts | 95 ++++++++++++++++++++++++++++ 4 files changed, 292 insertions(+), 8 deletions(-) create mode 100644 src/core/promptLoader.ts create mode 100644 src/prompts/BasePrompt.ts diff --git a/src/core/MCPServer.ts b/src/core/MCPServer.ts index 569cee4..3759ae2 100644 --- a/src/core/MCPServer.ts +++ b/src/core/MCPServer.ts @@ -3,9 +3,13 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js" import { CallToolRequestSchema, ListToolsRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema, } from "@modelcontextprotocol/sdk/types.js"; import { ToolLoader } from "./toolLoader.js"; +import { PromptLoader } from "./promptLoader.js"; import { ToolProtocol } from "../tools/BaseTool.js"; +import { PromptProtocol } from "../prompts/BasePrompt.js"; import { readFileSync } from "fs"; import { join, dirname } from "path"; import { logger } from "./Logger.js"; @@ -31,7 +35,9 @@ export type ServerCapabilities = { export class MCPServer { private server: Server; private toolsMap: Map = new Map(); + private promptsMap: Map = new Map(); private toolLoader: ToolLoader; + private promptLoader: PromptLoader; private serverName: string; private serverVersion: string; private basePath: string; @@ -46,6 +52,7 @@ export class MCPServer { ); this.toolLoader = new ToolLoader(this.basePath); + this.promptLoader = new PromptLoader(this.basePath); this.server = new Server( { @@ -55,6 +62,7 @@ export class MCPServer { { capabilities: { tools: { enabled: true }, + prompts: { enabled: false }, }, } ); @@ -129,13 +137,36 @@ export class MCPServer { return tool.toolCall(toolRequest); }); + + this.server.setRequestHandler(ListPromptsRequestSchema, async () => { + return { + prompts: Array.from(this.promptsMap.values()).map( + (prompt) => prompt.promptDefinition + ), + }; + }); + + this.server.setRequestHandler(GetPromptRequestSchema, async (request) => { + const prompt = this.promptsMap.get(request.params.name); + if (!prompt) { + throw new Error( + `Unknown prompt: ${ + request.params.name + }. Available prompts: ${Array.from(this.promptsMap.keys()).join( + ", " + )}` + ); + } + + return { + messages: await prompt.getMessages(request.params.arguments), + }; + }); } private async detectCapabilities(): Promise { const capabilities: ServerCapabilities = {}; - //IK this is unecessary but it'll guide future schema and prompt capability autodiscovery - if (await this.toolLoader.hasTools()) { capabilities.tools = { enabled: true }; logger.debug("Tools capability enabled"); @@ -143,6 +174,13 @@ export class MCPServer { logger.debug("No tools found, tools capability disabled"); } + if (await this.promptLoader.hasPrompts()) { + capabilities.prompts = { enabled: true }; + logger.debug("Prompts capability enabled"); + } else { + logger.debug("No prompts found, prompts capability disabled"); + } + return capabilities; } @@ -153,19 +191,35 @@ export class MCPServer { tools.map((tool: ToolProtocol) => [tool.name, tool]) ); + const prompts = await this.promptLoader.loadPrompts(); + this.promptsMap = new Map( + prompts.map((prompt: PromptProtocol) => [prompt.name, prompt]) + ); + + this.detectCapabilities(); + const transport = new StdioServerTransport(); await this.server.connect(transport); - if (tools.length > 0) { - logger.info( - `Started ${this.serverName}@${this.serverVersion} with ${tools.length} tools` - ); + if (tools.length > 0 || prompts.length > 0) { logger.info( - `Available tools: ${Array.from(this.toolsMap.keys()).join(", ")}` + `Started ${this.serverName}@${this.serverVersion} with ${tools.length} tools and ${prompts.length} prompts` ); + if (tools.length > 0) { + logger.info( + `Available tools: ${Array.from(this.toolsMap.keys()).join(", ")}` + ); + } + if (prompts.length > 0) { + logger.info( + `Available prompts: ${Array.from(this.promptsMap.keys()).join( + ", " + )}` + ); + } } else { logger.info( - `Started ${this.serverName}@${this.serverVersion} with no tools` + `Started ${this.serverName}@${this.serverVersion} with no tools or prompts` ); } } catch (error) { diff --git a/src/core/promptLoader.ts b/src/core/promptLoader.ts new file mode 100644 index 0000000..7047026 --- /dev/null +++ b/src/core/promptLoader.ts @@ -0,0 +1,128 @@ +import { PromptProtocol } from "../prompts/BasePrompt.js"; +import { join, dirname } from "path"; +import { promises as fs } from "fs"; +import { logger } from "./Logger.js"; + +export class PromptLoader { + private readonly PROMPTS_DIR: string; + private readonly EXCLUDED_FILES = ["BasePrompt.js", "*.test.js", "*.spec.js"]; + + constructor(basePath?: string) { + const mainModulePath = basePath || process.argv[1]; + this.PROMPTS_DIR = join(dirname(mainModulePath), "prompts"); + logger.debug( + `Initialized PromptLoader with directory: ${this.PROMPTS_DIR}` + ); + } + + async hasPrompts(): Promise { + try { + const stats = await fs.stat(this.PROMPTS_DIR); + if (!stats.isDirectory()) { + logger.debug("Prompts path exists but is not a directory"); + return false; + } + + const files = await fs.readdir(this.PROMPTS_DIR); + const hasValidFiles = files.some((file) => this.isPromptFile(file)); + logger.debug(`Prompts directory has valid files: ${hasValidFiles}`); + return hasValidFiles; + } catch (error) { + logger.debug("No prompts directory found"); + return false; + } + } + + private isPromptFile(file: string): boolean { + if (!file.endsWith(".js")) return false; + const isExcluded = this.EXCLUDED_FILES.some((pattern) => { + if (pattern.includes("*")) { + const regex = new RegExp(pattern.replace("*", ".*")); + return regex.test(file); + } + return file === pattern; + }); + + logger.debug( + `Checking file ${file}: ${isExcluded ? "excluded" : "included"}` + ); + return !isExcluded; + } + + private validatePrompt(prompt: any): prompt is PromptProtocol { + const isValid = Boolean( + prompt && + typeof prompt.name === "string" && + prompt.promptDefinition && + typeof prompt.getMessages === "function" + ); + + if (isValid) { + logger.debug(`Validated prompt: ${prompt.name}`); + } else { + logger.warn(`Invalid prompt found: missing required properties`); + } + + return isValid; + } + + async loadPrompts(): Promise { + try { + logger.debug(`Attempting to load prompts from: ${this.PROMPTS_DIR}`); + + let stats; + try { + stats = await fs.stat(this.PROMPTS_DIR); + } catch (error) { + logger.debug("No prompts directory found"); + return []; + } + + if (!stats.isDirectory()) { + logger.error(`Path is not a directory: ${this.PROMPTS_DIR}`); + return []; + } + + const files = await fs.readdir(this.PROMPTS_DIR); + logger.debug(`Found files in directory: ${files.join(", ")}`); + + const prompts: PromptProtocol[] = []; + + for (const file of files) { + if (!this.isPromptFile(file)) { + continue; + } + + try { + const fullPath = join(this.PROMPTS_DIR, file); + logger.debug(`Attempting to load prompt from: ${fullPath}`); + + const importPath = `file://${fullPath}`; + const { default: PromptClass } = await import(importPath); + + if (!PromptClass) { + logger.warn(`No default export found in ${file}`); + continue; + } + + const prompt = new PromptClass(); + if (this.validatePrompt(prompt)) { + prompts.push(prompt); + } + } catch (error) { + logger.error(`Error loading prompt ${file}: ${error}`); + } + } + + logger.debug( + `Successfully loaded ${prompts.length} prompts: ${prompts + .map((p) => p.name) + .join(", ")}` + ); + return prompts; + } catch (error) { + logger.error(`Failed to load prompts: ${error}`); + return []; + } + } +} diff --git a/src/index.ts b/src/index.ts index 7f99980..5a44f73 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,4 +5,11 @@ export { type ToolInputSchema, type ToolInput, } from "./tools/BaseTool.js"; +export { + MCPPrompt, + type PromptProtocol, + type PromptArgumentSchema, + type PromptArguments, +} from "./prompts/BasePrompt.js"; export { ToolLoader } from "./core/toolLoader.js"; +export { PromptLoader } from "./core/promptLoader.js"; diff --git a/src/prompts/BasePrompt.ts b/src/prompts/BasePrompt.ts new file mode 100644 index 0000000..cf10e6a --- /dev/null +++ b/src/prompts/BasePrompt.ts @@ -0,0 +1,95 @@ +import { z } from "zod"; + +export type PromptArgumentSchema = { + [K in keyof T]: { + type: z.ZodType; + description: string; + required?: boolean; + }; +}; + +export type PromptArguments> = { + [K in keyof T]: z.infer; +}; + +export interface PromptProtocol { + name: string; + description: string; + promptDefinition: { + name: string; + description: string; + arguments?: Array<{ + name: string; + description: string; + required?: boolean; + }>; + }; + getMessages(args?: Record): Promise< + Array<{ + role: string; + content: { + type: string; + text: string; + resource?: { + uri: string; + text: string; + mimeType: string; + }; + }; + }> + >; +} + +export abstract class MCPPrompt = {}> + implements PromptProtocol +{ + abstract name: string; + abstract description: string; + protected abstract schema: PromptArgumentSchema; + + get promptDefinition() { + return { + name: this.name, + description: this.description, + arguments: Object.entries(this.schema).map(([name, schema]) => ({ + name, + description: schema.description, + required: schema.required ?? false, + })), + }; + } + + protected abstract generateMessages(args: TArgs): Promise< + Array<{ + role: string; + content: { + type: string; + text: string; + resource?: { + uri: string; + text: string; + mimeType: string; + }; + }; + }> + >; + + async getMessages(args: Record = {}) { + const zodSchema = z.object( + Object.fromEntries( + Object.entries(this.schema).map(([key, schema]) => [key, schema.type]) + ) + ); + + const validatedArgs = (await zodSchema.parse(args)) as TArgs; + return this.generateMessages(validatedArgs); + } + + protected async fetch(url: string, init?: RequestInit): Promise { + const response = await fetch(url, init); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + return response.json(); + } +}