Skip to content

Commit 4d5fb39

Browse files
authored
Merge pull request #2 from QuantGeekDev/feature/mcp-tools
feat: add MCPTool abstraction
2 parents 1d1cfef + 3458e78 commit 4d5fb39

File tree

4 files changed

+123
-25
lines changed

4 files changed

+123
-25
lines changed

Diff for: src/core/MCPServer.ts

+12-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {
55
ListToolsRequestSchema,
66
} from "@modelcontextprotocol/sdk/types.js";
77
import { ToolLoader } from "./toolLoader.js";
8-
import { BaseTool } from "../tools/BaseTool.js";
8+
import { ToolProtocol } from "../tools/BaseTool.js";
99
import { readFileSync } from "fs";
1010
import { join, dirname } from "path";
1111
import { logger } from "./Logger.js";
@@ -17,7 +17,7 @@ export interface MCPServerConfig {
1717

1818
export class MCPServer {
1919
private server: Server;
20-
private toolsMap: Map<string, BaseTool> = new Map();
20+
private toolsMap: Map<string, ToolProtocol> = new Map();
2121
private toolLoader: ToolLoader;
2222
private serverName: string;
2323
private serverVersion: string;
@@ -106,14 +106,22 @@ export class MCPServer {
106106
).join(", ")}`
107107
);
108108
}
109-
return tool.toolCall(request);
109+
110+
const toolRequest = {
111+
params: request.params,
112+
method: "tools/call" as const,
113+
};
114+
115+
return tool.toolCall(toolRequest);
110116
});
111117
}
112118

113119
async start() {
114120
try {
115121
const tools = await this.toolLoader.loadTools();
116-
this.toolsMap = new Map(tools.map((tool: BaseTool) => [tool.name, tool]));
122+
this.toolsMap = new Map(
123+
tools.map((tool: ToolProtocol) => [tool.name, tool])
124+
);
117125

118126
const transport = new StdioServerTransport();
119127
await this.server.connect(transport);

Diff for: src/core/toolLoader.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { BaseTool } from "../tools/BaseTool.js";
1+
import { ToolProtocol } from "../tools/BaseTool.js";
22
import { join, dirname } from "path";
33
import { promises as fs } from "fs";
44
import { logger } from "./Logger.js";
@@ -29,7 +29,7 @@ export class ToolLoader {
2929
return !isExcluded;
3030
}
3131

32-
private validateTool(tool: any): tool is BaseTool {
32+
private validateTool(tool: any): tool is ToolProtocol {
3333
const isValid = Boolean(
3434
tool &&
3535
typeof tool.name === "string" &&
@@ -46,7 +46,7 @@ export class ToolLoader {
4646
return isValid;
4747
}
4848

49-
async loadTools(): Promise<BaseTool[]> {
49+
async loadTools(): Promise<ToolProtocol[]> {
5050
try {
5151
logger.debug(`Attempting to load tools from: ${this.TOOLS_DIR}`);
5252

@@ -66,7 +66,7 @@ export class ToolLoader {
6666
const files = await fs.readdir(this.TOOLS_DIR);
6767
logger.debug(`Found files in directory: ${files.join(", ")}`);
6868

69-
const tools: BaseTool[] = [];
69+
const tools: ToolProtocol[] = [];
7070

7171
for (const file of files) {
7272
if (!this.isToolFile(file)) {

Diff for: src/index.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
export { MCPServer, type MCPServerConfig } from "./core/MCPServer.js";
2-
export { BaseTool, BaseToolImplementation } from "./tools/BaseTool.js";
2+
export {
3+
MCPTool,
4+
type ToolProtocol,
5+
type ToolInputSchema,
6+
type ToolInput,
7+
} from "./tools/BaseTool.js";
38
export { ToolLoader } from "./core/toolLoader.js";

Diff for: src/tools/BaseTool.ts

+101-16
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,117 @@
1-
import {
2-
CallToolRequestSchema,
3-
Tool,
4-
} from "@modelcontextprotocol/sdk/types.js";
51
import { z } from "zod";
2+
import { Tool as SDKTool } from "@modelcontextprotocol/sdk/types.js";
63

7-
export interface BaseTool {
4+
export type ToolInputSchema<T> = {
5+
[K in keyof T]: {
6+
type: z.ZodType<T[K]>;
7+
description: string;
8+
};
9+
};
10+
11+
export type ToolInput<T extends ToolInputSchema<any>> = {
12+
[K in keyof T]: z.infer<T[K]["type"]>;
13+
};
14+
15+
export interface ToolProtocol extends SDKTool {
816
name: string;
9-
toolDefinition: Tool;
10-
toolCall(request: z.infer<typeof CallToolRequestSchema>): Promise<any>;
17+
description: string;
18+
toolDefinition: {
19+
name: string;
20+
description: string;
21+
inputSchema: {
22+
type: "object";
23+
properties?: Record<string, unknown>;
24+
};
25+
};
26+
toolCall(request: {
27+
params: { name: string; arguments?: Record<string, unknown> };
28+
}): Promise<{
29+
content: Array<{ type: string; text: string }>;
30+
}>;
1131
}
1232

13-
export abstract class BaseToolImplementation implements BaseTool {
33+
export abstract class MCPTool<TInput extends Record<string, any> = {}>
34+
implements ToolProtocol
35+
{
1436
abstract name: string;
15-
abstract toolDefinition: Tool;
16-
abstract toolCall(
17-
request: z.infer<typeof CallToolRequestSchema>
18-
): Promise<any>;
37+
abstract description: string;
38+
protected abstract schema: ToolInputSchema<TInput>;
39+
[key: string]: unknown;
1940

20-
protected createSuccessResponse(data: any) {
41+
get inputSchema(): { type: "object"; properties?: Record<string, unknown> } {
42+
return {
43+
type: "object" as const,
44+
properties: Object.fromEntries(
45+
Object.entries(this.schema).map(([key, schema]) => [
46+
key,
47+
{
48+
type: this.getJsonSchemaType(schema.type),
49+
description: schema.description,
50+
},
51+
])
52+
),
53+
};
54+
}
55+
56+
get toolDefinition() {
57+
return {
58+
name: this.name,
59+
description: this.description,
60+
inputSchema: this.inputSchema,
61+
};
62+
}
63+
64+
protected abstract execute(input: TInput): Promise<unknown>;
65+
66+
async toolCall(request: {
67+
params: { name: string; arguments?: Record<string, unknown> };
68+
}) {
69+
try {
70+
const args = request.params.arguments || {};
71+
const validatedInput = await this.validateInput(args);
72+
const result = await this.execute(validatedInput);
73+
return this.createSuccessResponse(result);
74+
} catch (error) {
75+
return this.createErrorResponse(error as Error);
76+
}
77+
}
78+
79+
private async validateInput(args: Record<string, unknown>): Promise<TInput> {
80+
const zodSchema = z.object(
81+
Object.fromEntries(
82+
Object.entries(this.schema).map(([key, schema]) => [key, schema.type])
83+
)
84+
);
85+
86+
return zodSchema.parse(args) as TInput;
87+
}
88+
89+
private getJsonSchemaType(zodType: z.ZodType<any>): string {
90+
if (zodType instanceof z.ZodString) return "string";
91+
if (zodType instanceof z.ZodNumber) return "number";
92+
if (zodType instanceof z.ZodBoolean) return "boolean";
93+
if (zodType instanceof z.ZodArray) return "array";
94+
if (zodType instanceof z.ZodObject) return "object";
95+
return "string";
96+
}
97+
98+
protected createSuccessResponse(data: unknown) {
2199
return {
22100
content: [{ type: "text", text: JSON.stringify(data) }],
23101
};
24102
}
25103

26-
protected createErrorResponse(error: Error | string) {
27-
const message = error instanceof Error ? error.message : error;
104+
protected createErrorResponse(error: Error) {
28105
return {
29-
content: [{ type: "error", text: message }],
106+
content: [{ type: "error", text: error.message }],
30107
};
31108
}
109+
110+
protected async fetch<T>(url: string, init?: RequestInit): Promise<T> {
111+
const response = await fetch(url, init);
112+
if (!response.ok) {
113+
throw new Error(`HTTP error! status: ${response.status}`);
114+
}
115+
return response.json();
116+
}
32117
}

0 commit comments

Comments
 (0)