Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added auto-refreshing tool list notification handler to client #239

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 90 additions & 40 deletions src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
ListResourceTemplatesRequest,
ListResourceTemplatesResultSchema,
ListToolsRequest,
ListToolsResult,
ListToolsResultSchema,
LoggingLevel,
Notification,
Expand Down Expand Up @@ -76,7 +77,7 @@ export type ClientOptions = ProtocolOptions & {
export class Client<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result,
ResultT extends Result = Result
> extends Protocol<
ClientRequest | RequestT,
ClientNotification | NotificationT,
Expand All @@ -87,15 +88,38 @@ export class Client<
private _capabilities: ClientCapabilities;
private _instructions?: string;

/**
* Callback for when the server indicates that the tools list has changed.
* Client should typically refresh its list of tools in response.
*/
onToolListChanged?: (tools?: ListToolsResult["tools"]) => void;

/**
* Initializes this client with the given name and version information.
*/
constructor(
private _clientInfo: Implementation,
options?: ClientOptions,
) {
constructor(private _clientInfo: Implementation, options?: ClientOptions) {
super(options);
this._capabilities = options?.capabilities ?? {};

// Set up notification handlers
this.setNotificationHandler(
"notifications/tools/list_changed",
async () => {
// Automatically refresh the tools list when the server indicates a change
try {
// Only refresh if the server supports tools
if (this._serverCapabilities?.tools) {
const result = await this.listTools();
// Call the user's callback with the updated tools list
this.onToolListChanged?.(result.tools);
}
} catch (error) {
console.error("Failed to refresh tools list:", error);
// Still call the callback even if refresh failed
this.onToolListChanged?.(undefined);
}
}
);
}

/**
Expand All @@ -106,7 +130,7 @@ export class Client<
public registerCapabilities(capabilities: ClientCapabilities): void {
if (this.transport) {
throw new Error(
"Cannot register capabilities after connecting to transport",
"Cannot register capabilities after connecting to transport"
);
}

Expand All @@ -115,11 +139,11 @@ export class Client<

protected assertCapability(
capability: keyof ServerCapabilities,
method: string,
method: string
): void {
if (!this._serverCapabilities?.[capability]) {
throw new Error(
`Server does not support ${capability} (required for ${method})`,
`Server does not support ${String(capability)} (required for ${method})`
);
}
}
Expand All @@ -137,7 +161,7 @@ export class Client<
clientInfo: this._clientInfo,
},
},
InitializeResultSchema,
InitializeResultSchema
);

if (result === undefined) {
Expand All @@ -146,7 +170,7 @@ export class Client<

if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) {
throw new Error(
`Server's protocol version is not supported: ${result.protocolVersion}`,
`Server's protocol version is not supported: ${result.protocolVersion}`
);
}

Expand Down Expand Up @@ -191,7 +215,7 @@ export class Client<
case "logging/setLevel":
if (!this._serverCapabilities?.logging) {
throw new Error(
`Server does not support logging (required for ${method})`,
`Server does not support logging (required for ${method})`
);
}
break;
Expand All @@ -200,7 +224,7 @@ export class Client<
case "prompts/list":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
`Server does not support prompts (required for ${method})`
);
}
break;
Expand All @@ -212,7 +236,7 @@ export class Client<
case "resources/unsubscribe":
if (!this._serverCapabilities?.resources) {
throw new Error(
`Server does not support resources (required for ${method})`,
`Server does not support resources (required for ${method})`
);
}

Expand All @@ -221,7 +245,7 @@ export class Client<
!this._serverCapabilities.resources.subscribe
) {
throw new Error(
`Server does not support resource subscriptions (required for ${method})`,
`Server does not support resource subscriptions (required for ${method})`
);
}

Expand All @@ -231,15 +255,15 @@ export class Client<
case "tools/list":
if (!this._serverCapabilities?.tools) {
throw new Error(
`Server does not support tools (required for ${method})`,
`Server does not support tools (required for ${method})`
);
}
break;

case "completion/complete":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
`Server does not support prompts (required for ${method})`
);
}
break;
Expand All @@ -255,13 +279,23 @@ export class Client<
}

protected assertNotificationCapability(
method: NotificationT["method"],
method: NotificationT["method"]
): void {
switch (method as ClientNotification["method"]) {
case "notifications/roots/list_changed":
if (!this._capabilities.roots?.listChanged) {
throw new Error(
`Client does not support roots list changed notifications (required for ${method})`,
`Client does not support roots list changed notifications (required for ${method})`
);
}
break;

case "notifications/tools/list_changed":
if (!this._capabilities.tools?.listChanged) {
throw new Error(
`Client does not support tools capability (required for ${String(
method
)})`
);
}
break;
Expand All @@ -285,15 +319,15 @@ export class Client<
case "sampling/createMessage":
if (!this._capabilities.sampling) {
throw new Error(
`Client does not support sampling capability (required for ${method})`,
`Client does not support sampling capability (required for ${method})`
);
}
break;

case "roots/list":
if (!this._capabilities.roots) {
throw new Error(
`Client does not support roots capability (required for ${method})`,
`Client does not support roots capability (required for ${method})`
);
}
break;
Expand All @@ -312,92 +346,92 @@ export class Client<
return this.request(
{ method: "completion/complete", params },
CompleteResultSchema,
options,
options
);
}

async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) {
return this.request(
{ method: "logging/setLevel", params: { level } },
EmptyResultSchema,
options,
options
);
}

async getPrompt(
params: GetPromptRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "prompts/get", params },
GetPromptResultSchema,
options,
options
);
}

async listPrompts(
params?: ListPromptsRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "prompts/list", params },
ListPromptsResultSchema,
options,
options
);
}

async listResources(
params?: ListResourcesRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/list", params },
ListResourcesResultSchema,
options,
options
);
}

async listResourceTemplates(
params?: ListResourceTemplatesRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/templates/list", params },
ListResourceTemplatesResultSchema,
options,
options
);
}

async readResource(
params: ReadResourceRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/read", params },
ReadResourceResultSchema,
options,
options
);
}

async subscribeResource(
params: SubscribeRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/subscribe", params },
EmptyResultSchema,
options,
options
);
}

async unsubscribeResource(
params: UnsubscribeRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/unsubscribe", params },
EmptyResultSchema,
options,
options
);
}

Expand All @@ -406,27 +440,43 @@ export class Client<
resultSchema:
| typeof CallToolResultSchema
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "tools/call", params },
resultSchema,
options,
options
);
}

async listTools(
params?: ListToolsRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "tools/list", params },
ListToolsResultSchema,
options,
options
);
}

/**
* Registers a callback to be called when the server indicates that
* the tools list has changed. The callback should typically refresh the tools list.
*
* @param callback Function to call when tools list changes
*/
setToolListChangedCallback(
callback: (tools?: ListToolsResult["tools"]) => void
): void {
this.onToolListChanged = callback;
}

async sendRootsListChanged() {
return this.notification({ method: "notifications/roots/list_changed" });
}

async sendToolListChanged() {
return this.notification({ method: "notifications/tools/list_changed" });
}
}