Skip to content

Implement cancellation notifications and handling #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 15, 2024
55 changes: 55 additions & 0 deletions src/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,58 @@ test("should typecheck", () => {
},
});
});

test("should handle client cancelling a request", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
resources: {},
},
},
);

// Set up server to delay responding to listResources
server.setRequestHandler(
ListResourcesRequestSchema,
async (request, extra) => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return {
resources: [],
};
},
);

const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();

const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {},
},
);

await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);

// Set up abort controller
const controller = new AbortController();

// Issue request but cancel it immediately
const listResourcesPromise = client.listResources(undefined, {
signal: controller.signal,
});
controller.abort("Cancelled by test");

// Request should be rejected
await expect(listResourcesPromise).rejects.toBe("Cancelled by test");
});
41 changes: 21 additions & 20 deletions src/client/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
ProgressCallback,
Protocol,
ProtocolOptions,
RequestOptions,
} from "../shared/protocol.js";
import { Transport } from "../shared/transport.js";
import {
Expand Down Expand Up @@ -244,6 +244,10 @@ export class Client<
// No specific capability required for initialized
break;

case "notifications/cancelled":
// Cancellation notifications are always allowed
break;

case "notifications/progress":
// Progress notifications are always allowed
break;
Expand Down Expand Up @@ -278,14 +282,11 @@ export class Client<
return this.request({ method: "ping" }, EmptyResultSchema);
}

async complete(
params: CompleteRequest["params"],
onprogress?: ProgressCallback,
) {
async complete(params: CompleteRequest["params"], options?: RequestOptions) {
return this.request(
{ method: "completion/complete", params },
CompleteResultSchema,
onprogress,
options,
);
}

Expand All @@ -298,56 +299,56 @@ export class Client<

async getPrompt(
params: GetPromptRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "prompts/get", params },
GetPromptResultSchema,
onprogress,
options,
);
}

async listPrompts(
params?: ListPromptsRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "prompts/list", params },
ListPromptsResultSchema,
onprogress,
options,
);
}

async listResources(
params?: ListResourcesRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "resources/list", params },
ListResourcesResultSchema,
onprogress,
options,
);
}

async listResourceTemplates(
params?: ListResourceTemplatesRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "resources/templates/list", params },
ListResourceTemplatesResultSchema,
onprogress,
options,
);
}

async readResource(
params: ReadResourceRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "resources/read", params },
ReadResourceResultSchema,
onprogress,
options,
);
}

Expand All @@ -370,23 +371,23 @@ export class Client<
resultSchema:
| typeof CallToolResultSchema
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "tools/call", params },
resultSchema,
onprogress,
options,
);
}

async listTools(
params?: ListToolsRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "tools/list", params },
ListToolsResultSchema,
onprogress,
options,
);
}

Expand Down
68 changes: 68 additions & 0 deletions src/server/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,71 @@ test("should typecheck", () => {
},
);
});

test("should handle server cancelling a request", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);

const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);

// Set up client to delay responding to createMessage
client.setRequestHandler(
CreateMessageRequestSchema,
async (_request, extra) => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return {
model: "test",
role: "assistant",
content: {
type: "text",
text: "Test response",
},
};
},
);

const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();

await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);

// Set up abort controller
const controller = new AbortController();

// Issue request but cancel it immediately
const createMessagePromise = server.createMessage(
{
messages: [],
maxTokens: 10,
},
{
signal: controller.signal,
},
);
controller.abort("Cancelled by test");

// Request should be rejected
await expect(createMessagePromise).rejects.toBe("Cancelled by test");
});
14 changes: 9 additions & 5 deletions src/server/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
ProgressCallback,
Protocol,
ProtocolOptions,
RequestOptions,
} from "../shared/protocol.js";
import {
ClientCapabilities,
Expand Down Expand Up @@ -157,6 +157,10 @@ export class Server<
}
break;

case "notifications/cancelled":
// Cancellation notifications are always allowed
break;

case "notifications/progress":
// Progress notifications are always allowed
break;
Expand Down Expand Up @@ -257,23 +261,23 @@ export class Server<

async createMessage(
params: CreateMessageRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "sampling/createMessage", params },
CreateMessageResultSchema,
onprogress,
options,
);
}

async listRoots(
params?: ListRootsRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "roots/list", params },
ListRootsResultSchema,
onprogress,
options,
);
}

Expand Down
Loading