Skip to content

Commit 9ca6a0f

Browse files
authored
Merge pull request #4 from QuantGeekDev/feature/prompts
feat: add prompt capabilities
2 parents 6fe825d + 019f409 commit 9ca6a0f

File tree

4 files changed

+292
-8
lines changed

4 files changed

+292
-8
lines changed

src/core/MCPServer.ts

+62-8
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"
33
import {
44
CallToolRequestSchema,
55
ListToolsRequestSchema,
6+
ListPromptsRequestSchema,
7+
GetPromptRequestSchema,
68
} from "@modelcontextprotocol/sdk/types.js";
79
import { ToolLoader } from "./toolLoader.js";
10+
import { PromptLoader } from "./promptLoader.js";
811
import { ToolProtocol } from "../tools/BaseTool.js";
12+
import { PromptProtocol } from "../prompts/BasePrompt.js";
913
import { readFileSync } from "fs";
1014
import { join, dirname } from "path";
1115
import { logger } from "./Logger.js";
@@ -31,7 +35,9 @@ export type ServerCapabilities = {
3135
export class MCPServer {
3236
private server: Server;
3337
private toolsMap: Map<string, ToolProtocol> = new Map();
38+
private promptsMap: Map<string, PromptProtocol> = new Map();
3439
private toolLoader: ToolLoader;
40+
private promptLoader: PromptLoader;
3541
private serverName: string;
3642
private serverVersion: string;
3743
private basePath: string;
@@ -46,6 +52,7 @@ export class MCPServer {
4652
);
4753

4854
this.toolLoader = new ToolLoader(this.basePath);
55+
this.promptLoader = new PromptLoader(this.basePath);
4956

5057
this.server = new Server(
5158
{
@@ -55,6 +62,7 @@ export class MCPServer {
5562
{
5663
capabilities: {
5764
tools: { enabled: true },
65+
prompts: { enabled: false },
5866
},
5967
}
6068
);
@@ -129,20 +137,50 @@ export class MCPServer {
129137

130138
return tool.toolCall(toolRequest);
131139
});
140+
141+
this.server.setRequestHandler(ListPromptsRequestSchema, async () => {
142+
return {
143+
prompts: Array.from(this.promptsMap.values()).map(
144+
(prompt) => prompt.promptDefinition
145+
),
146+
};
147+
});
148+
149+
this.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
150+
const prompt = this.promptsMap.get(request.params.name);
151+
if (!prompt) {
152+
throw new Error(
153+
`Unknown prompt: ${
154+
request.params.name
155+
}. Available prompts: ${Array.from(this.promptsMap.keys()).join(
156+
", "
157+
)}`
158+
);
159+
}
160+
161+
return {
162+
messages: await prompt.getMessages(request.params.arguments),
163+
};
164+
});
132165
}
133166

134167
private async detectCapabilities(): Promise<ServerCapabilities> {
135168
const capabilities: ServerCapabilities = {};
136169

137-
//IK this is unecessary but it'll guide future schema and prompt capability autodiscovery
138-
139170
if (await this.toolLoader.hasTools()) {
140171
capabilities.tools = { enabled: true };
141172
logger.debug("Tools capability enabled");
142173
} else {
143174
logger.debug("No tools found, tools capability disabled");
144175
}
145176

177+
if (await this.promptLoader.hasPrompts()) {
178+
capabilities.prompts = { enabled: true };
179+
logger.debug("Prompts capability enabled");
180+
} else {
181+
logger.debug("No prompts found, prompts capability disabled");
182+
}
183+
146184
return capabilities;
147185
}
148186

@@ -153,19 +191,35 @@ export class MCPServer {
153191
tools.map((tool: ToolProtocol) => [tool.name, tool])
154192
);
155193

194+
const prompts = await this.promptLoader.loadPrompts();
195+
this.promptsMap = new Map(
196+
prompts.map((prompt: PromptProtocol) => [prompt.name, prompt])
197+
);
198+
199+
this.detectCapabilities();
200+
156201
const transport = new StdioServerTransport();
157202
await this.server.connect(transport);
158203

159-
if (tools.length > 0) {
160-
logger.info(
161-
`Started ${this.serverName}@${this.serverVersion} with ${tools.length} tools`
162-
);
204+
if (tools.length > 0 || prompts.length > 0) {
163205
logger.info(
164-
`Available tools: ${Array.from(this.toolsMap.keys()).join(", ")}`
206+
`Started ${this.serverName}@${this.serverVersion} with ${tools.length} tools and ${prompts.length} prompts`
165207
);
208+
if (tools.length > 0) {
209+
logger.info(
210+
`Available tools: ${Array.from(this.toolsMap.keys()).join(", ")}`
211+
);
212+
}
213+
if (prompts.length > 0) {
214+
logger.info(
215+
`Available prompts: ${Array.from(this.promptsMap.keys()).join(
216+
", "
217+
)}`
218+
);
219+
}
166220
} else {
167221
logger.info(
168-
`Started ${this.serverName}@${this.serverVersion} with no tools`
222+
`Started ${this.serverName}@${this.serverVersion} with no tools or prompts`
169223
);
170224
}
171225
} catch (error) {

src/core/promptLoader.ts

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import { PromptProtocol } from "../prompts/BasePrompt.js";
2+
import { join, dirname } from "path";
3+
import { promises as fs } from "fs";
4+
import { logger } from "./Logger.js";
5+
6+
export class PromptLoader {
7+
private readonly PROMPTS_DIR: string;
8+
private readonly EXCLUDED_FILES = ["BasePrompt.js", "*.test.js", "*.spec.js"];
9+
10+
constructor(basePath?: string) {
11+
const mainModulePath = basePath || process.argv[1];
12+
this.PROMPTS_DIR = join(dirname(mainModulePath), "prompts");
13+
logger.debug(
14+
`Initialized PromptLoader with directory: ${this.PROMPTS_DIR}`
15+
);
16+
}
17+
18+
async hasPrompts(): Promise<boolean> {
19+
try {
20+
const stats = await fs.stat(this.PROMPTS_DIR);
21+
if (!stats.isDirectory()) {
22+
logger.debug("Prompts path exists but is not a directory");
23+
return false;
24+
}
25+
26+
const files = await fs.readdir(this.PROMPTS_DIR);
27+
const hasValidFiles = files.some((file) => this.isPromptFile(file));
28+
logger.debug(`Prompts directory has valid files: ${hasValidFiles}`);
29+
return hasValidFiles;
30+
} catch (error) {
31+
logger.debug("No prompts directory found");
32+
return false;
33+
}
34+
}
35+
36+
private isPromptFile(file: string): boolean {
37+
if (!file.endsWith(".js")) return false;
38+
const isExcluded = this.EXCLUDED_FILES.some((pattern) => {
39+
if (pattern.includes("*")) {
40+
const regex = new RegExp(pattern.replace("*", ".*"));
41+
return regex.test(file);
42+
}
43+
return file === pattern;
44+
});
45+
46+
logger.debug(
47+
`Checking file ${file}: ${isExcluded ? "excluded" : "included"}`
48+
);
49+
return !isExcluded;
50+
}
51+
52+
private validatePrompt(prompt: any): prompt is PromptProtocol {
53+
const isValid = Boolean(
54+
prompt &&
55+
typeof prompt.name === "string" &&
56+
prompt.promptDefinition &&
57+
typeof prompt.getMessages === "function"
58+
);
59+
60+
if (isValid) {
61+
logger.debug(`Validated prompt: ${prompt.name}`);
62+
} else {
63+
logger.warn(`Invalid prompt found: missing required properties`);
64+
}
65+
66+
return isValid;
67+
}
68+
69+
async loadPrompts(): Promise<PromptProtocol[]> {
70+
try {
71+
logger.debug(`Attempting to load prompts from: ${this.PROMPTS_DIR}`);
72+
73+
let stats;
74+
try {
75+
stats = await fs.stat(this.PROMPTS_DIR);
76+
} catch (error) {
77+
logger.debug("No prompts directory found");
78+
return [];
79+
}
80+
81+
if (!stats.isDirectory()) {
82+
logger.error(`Path is not a directory: ${this.PROMPTS_DIR}`);
83+
return [];
84+
}
85+
86+
const files = await fs.readdir(this.PROMPTS_DIR);
87+
logger.debug(`Found files in directory: ${files.join(", ")}`);
88+
89+
const prompts: PromptProtocol[] = [];
90+
91+
for (const file of files) {
92+
if (!this.isPromptFile(file)) {
93+
continue;
94+
}
95+
96+
try {
97+
const fullPath = join(this.PROMPTS_DIR, file);
98+
logger.debug(`Attempting to load prompt from: ${fullPath}`);
99+
100+
const importPath = `file://${fullPath}`;
101+
const { default: PromptClass } = await import(importPath);
102+
103+
if (!PromptClass) {
104+
logger.warn(`No default export found in ${file}`);
105+
continue;
106+
}
107+
108+
const prompt = new PromptClass();
109+
if (this.validatePrompt(prompt)) {
110+
prompts.push(prompt);
111+
}
112+
} catch (error) {
113+
logger.error(`Error loading prompt ${file}: ${error}`);
114+
}
115+
}
116+
117+
logger.debug(
118+
`Successfully loaded ${prompts.length} prompts: ${prompts
119+
.map((p) => p.name)
120+
.join(", ")}`
121+
);
122+
return prompts;
123+
} catch (error) {
124+
logger.error(`Failed to load prompts: ${error}`);
125+
return [];
126+
}
127+
}
128+
}

src/index.ts

+7
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,11 @@ export {
55
type ToolInputSchema,
66
type ToolInput,
77
} from "./tools/BaseTool.js";
8+
export {
9+
MCPPrompt,
10+
type PromptProtocol,
11+
type PromptArgumentSchema,
12+
type PromptArguments,
13+
} from "./prompts/BasePrompt.js";
814
export { ToolLoader } from "./core/toolLoader.js";
15+
export { PromptLoader } from "./core/promptLoader.js";

src/prompts/BasePrompt.ts

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import { z } from "zod";
2+
3+
export type PromptArgumentSchema<T> = {
4+
[K in keyof T]: {
5+
type: z.ZodType<T[K]>;
6+
description: string;
7+
required?: boolean;
8+
};
9+
};
10+
11+
export type PromptArguments<T extends PromptArgumentSchema<any>> = {
12+
[K in keyof T]: z.infer<T[K]["type"]>;
13+
};
14+
15+
export interface PromptProtocol {
16+
name: string;
17+
description: string;
18+
promptDefinition: {
19+
name: string;
20+
description: string;
21+
arguments?: Array<{
22+
name: string;
23+
description: string;
24+
required?: boolean;
25+
}>;
26+
};
27+
getMessages(args?: Record<string, unknown>): Promise<
28+
Array<{
29+
role: string;
30+
content: {
31+
type: string;
32+
text: string;
33+
resource?: {
34+
uri: string;
35+
text: string;
36+
mimeType: string;
37+
};
38+
};
39+
}>
40+
>;
41+
}
42+
43+
export abstract class MCPPrompt<TArgs extends Record<string, any> = {}>
44+
implements PromptProtocol
45+
{
46+
abstract name: string;
47+
abstract description: string;
48+
protected abstract schema: PromptArgumentSchema<TArgs>;
49+
50+
get promptDefinition() {
51+
return {
52+
name: this.name,
53+
description: this.description,
54+
arguments: Object.entries(this.schema).map(([name, schema]) => ({
55+
name,
56+
description: schema.description,
57+
required: schema.required ?? false,
58+
})),
59+
};
60+
}
61+
62+
protected abstract generateMessages(args: TArgs): Promise<
63+
Array<{
64+
role: string;
65+
content: {
66+
type: string;
67+
text: string;
68+
resource?: {
69+
uri: string;
70+
text: string;
71+
mimeType: string;
72+
};
73+
};
74+
}>
75+
>;
76+
77+
async getMessages(args: Record<string, unknown> = {}) {
78+
const zodSchema = z.object(
79+
Object.fromEntries(
80+
Object.entries(this.schema).map(([key, schema]) => [key, schema.type])
81+
)
82+
);
83+
84+
const validatedArgs = (await zodSchema.parse(args)) as TArgs;
85+
return this.generateMessages(validatedArgs);
86+
}
87+
88+
protected async fetch<T>(url: string, init?: RequestInit): Promise<T> {
89+
const response = await fetch(url, init);
90+
if (!response.ok) {
91+
throw new Error(`HTTP error! status: ${response.status}`);
92+
}
93+
return response.json();
94+
}
95+
}

0 commit comments

Comments
 (0)