Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add prompt capabilities #4

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions src/core/MCPServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"
import {
CallToolRequestSchema,
ListToolsRequestSchema,
ListPromptsRequestSchema,
GetPromptRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { ToolLoader } from "./toolLoader.js";
import { PromptLoader } from "./promptLoader.js";
import { ToolProtocol } from "../tools/BaseTool.js";
import { PromptProtocol } from "../prompts/BasePrompt.js";
import { readFileSync } from "fs";
import { join, dirname } from "path";
import { logger } from "./Logger.js";
Expand All @@ -31,7 +35,9 @@ export type ServerCapabilities = {
export class MCPServer {
private server: Server;
private toolsMap: Map<string, ToolProtocol> = new Map();
private promptsMap: Map<string, PromptProtocol> = new Map();
private toolLoader: ToolLoader;
private promptLoader: PromptLoader;
private serverName: string;
private serverVersion: string;
private basePath: string;
Expand All @@ -46,6 +52,7 @@ export class MCPServer {
);

this.toolLoader = new ToolLoader(this.basePath);
this.promptLoader = new PromptLoader(this.basePath);

this.server = new Server(
{
Expand All @@ -55,6 +62,7 @@ export class MCPServer {
{
capabilities: {
tools: { enabled: true },
prompts: { enabled: false },
},
}
);
Expand Down Expand Up @@ -129,20 +137,50 @@ export class MCPServer {

return tool.toolCall(toolRequest);
});

this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
return {
prompts: Array.from(this.promptsMap.values()).map(
(prompt) => prompt.promptDefinition
),
};
});

this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const prompt = this.promptsMap.get(request.params.name);
if (!prompt) {
throw new Error(
`Unknown prompt: ${
request.params.name
}. Available prompts: ${Array.from(this.promptsMap.keys()).join(
", "
)}`
);
}

return {
messages: await prompt.getMessages(request.params.arguments),
};
});
}

private async detectCapabilities(): Promise<ServerCapabilities> {
const capabilities: ServerCapabilities = {};

//IK this is unecessary but it'll guide future schema and prompt capability autodiscovery

if (await this.toolLoader.hasTools()) {
capabilities.tools = { enabled: true };
logger.debug("Tools capability enabled");
} else {
logger.debug("No tools found, tools capability disabled");
}

if (await this.promptLoader.hasPrompts()) {
capabilities.prompts = { enabled: true };
logger.debug("Prompts capability enabled");
} else {
logger.debug("No prompts found, prompts capability disabled");
}

return capabilities;
}

Expand All @@ -153,19 +191,35 @@ export class MCPServer {
tools.map((tool: ToolProtocol) => [tool.name, tool])
);

const prompts = await this.promptLoader.loadPrompts();
this.promptsMap = new Map(
prompts.map((prompt: PromptProtocol) => [prompt.name, prompt])
);

this.detectCapabilities();

const transport = new StdioServerTransport();
await this.server.connect(transport);

if (tools.length > 0) {
logger.info(
`Started ${this.serverName}@${this.serverVersion} with ${tools.length} tools`
);
if (tools.length > 0 || prompts.length > 0) {
logger.info(
`Available tools: ${Array.from(this.toolsMap.keys()).join(", ")}`
`Started ${this.serverName}@${this.serverVersion} with ${tools.length} tools and ${prompts.length} prompts`
);
if (tools.length > 0) {
logger.info(
`Available tools: ${Array.from(this.toolsMap.keys()).join(", ")}`
);
}
if (prompts.length > 0) {
logger.info(
`Available prompts: ${Array.from(this.promptsMap.keys()).join(
", "
)}`
);
}
} else {
logger.info(
`Started ${this.serverName}@${this.serverVersion} with no tools`
`Started ${this.serverName}@${this.serverVersion} with no tools or prompts`
);
}
} catch (error) {
Expand Down
128 changes: 128 additions & 0 deletions src/core/promptLoader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { PromptProtocol } from "../prompts/BasePrompt.js";
import { join, dirname } from "path";
import { promises as fs } from "fs";
import { logger } from "./Logger.js";

export class PromptLoader {
private readonly PROMPTS_DIR: string;
private readonly EXCLUDED_FILES = ["BasePrompt.js", "*.test.js", "*.spec.js"];

constructor(basePath?: string) {
const mainModulePath = basePath || process.argv[1];
this.PROMPTS_DIR = join(dirname(mainModulePath), "prompts");
logger.debug(
`Initialized PromptLoader with directory: ${this.PROMPTS_DIR}`
);
}

async hasPrompts(): Promise<boolean> {
try {
const stats = await fs.stat(this.PROMPTS_DIR);
if (!stats.isDirectory()) {
logger.debug("Prompts path exists but is not a directory");
return false;
}

const files = await fs.readdir(this.PROMPTS_DIR);
const hasValidFiles = files.some((file) => this.isPromptFile(file));
logger.debug(`Prompts directory has valid files: ${hasValidFiles}`);
return hasValidFiles;
} catch (error) {
logger.debug("No prompts directory found");
return false;
}
}

private isPromptFile(file: string): boolean {
if (!file.endsWith(".js")) return false;
const isExcluded = this.EXCLUDED_FILES.some((pattern) => {
if (pattern.includes("*")) {
const regex = new RegExp(pattern.replace("*", ".*"));
return regex.test(file);
}
return file === pattern;
});

logger.debug(
`Checking file ${file}: ${isExcluded ? "excluded" : "included"}`
);
return !isExcluded;
}

private validatePrompt(prompt: any): prompt is PromptProtocol {
const isValid = Boolean(
prompt &&
typeof prompt.name === "string" &&
prompt.promptDefinition &&
typeof prompt.getMessages === "function"
);

if (isValid) {
logger.debug(`Validated prompt: ${prompt.name}`);
} else {
logger.warn(`Invalid prompt found: missing required properties`);
}

return isValid;
}

async loadPrompts(): Promise<PromptProtocol[]> {
try {
logger.debug(`Attempting to load prompts from: ${this.PROMPTS_DIR}`);

let stats;
try {
stats = await fs.stat(this.PROMPTS_DIR);
} catch (error) {
logger.debug("No prompts directory found");
return [];
}

if (!stats.isDirectory()) {
logger.error(`Path is not a directory: ${this.PROMPTS_DIR}`);
return [];
}

const files = await fs.readdir(this.PROMPTS_DIR);
logger.debug(`Found files in directory: ${files.join(", ")}`);

const prompts: PromptProtocol[] = [];

for (const file of files) {
if (!this.isPromptFile(file)) {
continue;
}

try {
const fullPath = join(this.PROMPTS_DIR, file);
logger.debug(`Attempting to load prompt from: ${fullPath}`);

const importPath = `file://${fullPath}`;
const { default: PromptClass } = await import(importPath);

if (!PromptClass) {
logger.warn(`No default export found in ${file}`);
continue;
}

const prompt = new PromptClass();
if (this.validatePrompt(prompt)) {
prompts.push(prompt);
}
} catch (error) {
logger.error(`Error loading prompt ${file}: ${error}`);
}
}

logger.debug(
`Successfully loaded ${prompts.length} prompts: ${prompts
.map((p) => p.name)
.join(", ")}`
);
return prompts;
} catch (error) {
logger.error(`Failed to load prompts: ${error}`);
return [];
}
}
}
7 changes: 7 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,11 @@ export {
type ToolInputSchema,
type ToolInput,
} from "./tools/BaseTool.js";
export {
MCPPrompt,
type PromptProtocol,
type PromptArgumentSchema,
type PromptArguments,
} from "./prompts/BasePrompt.js";
export { ToolLoader } from "./core/toolLoader.js";
export { PromptLoader } from "./core/promptLoader.js";
95 changes: 95 additions & 0 deletions src/prompts/BasePrompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import { z } from "zod";

export type PromptArgumentSchema<T> = {
[K in keyof T]: {
type: z.ZodType<T[K]>;
description: string;
required?: boolean;
};
};

export type PromptArguments<T extends PromptArgumentSchema<any>> = {
[K in keyof T]: z.infer<T[K]["type"]>;
};

export interface PromptProtocol {
name: string;
description: string;
promptDefinition: {
name: string;
description: string;
arguments?: Array<{
name: string;
description: string;
required?: boolean;
}>;
};
getMessages(args?: Record<string, unknown>): Promise<
Array<{
role: string;
content: {
type: string;
text: string;
resource?: {
uri: string;
text: string;
mimeType: string;
};
};
}>
>;
}

export abstract class MCPPrompt<TArgs extends Record<string, any> = {}>
implements PromptProtocol
{
abstract name: string;
abstract description: string;
protected abstract schema: PromptArgumentSchema<TArgs>;

get promptDefinition() {
return {
name: this.name,
description: this.description,
arguments: Object.entries(this.schema).map(([name, schema]) => ({
name,
description: schema.description,
required: schema.required ?? false,
})),
};
}

protected abstract generateMessages(args: TArgs): Promise<
Array<{
role: string;
content: {
type: string;
text: string;
resource?: {
uri: string;
text: string;
mimeType: string;
};
};
}>
>;

async getMessages(args: Record<string, unknown> = {}) {
const zodSchema = z.object(
Object.fromEntries(
Object.entries(this.schema).map(([key, schema]) => [key, schema.type])
)
);

const validatedArgs = (await zodSchema.parse(args)) as TArgs;
return this.generateMessages(validatedArgs);
}

protected async fetch<T>(url: string, init?: RequestInit): Promise<T> {
const response = await fetch(url, init);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response.json();
}
}