diff --git a/src/client/stdio.test.ts b/src/client/stdio.test.ts index 646f9ea5..c3e716b6 100644 --- a/src/client/stdio.test.ts +++ b/src/client/stdio.test.ts @@ -1,5 +1,6 @@ import { JSONRPCMessage } from "../types.js"; import { StdioClientTransport, StdioServerParameters } from "./stdio.js"; +import { ChildProcess } from "node:child_process"; const serverParameters: StdioServerParameters = { command: "/usr/bin/tee", @@ -22,6 +23,77 @@ test("should start then close cleanly", async () => { expect(didClose).toBeTruthy(); }); +test("should gracefully terminate the process", async () => { + const killSpy = jest.spyOn(ChildProcess.prototype, "kill"); + + jest.spyOn(global, "setTimeout").mockImplementationOnce((callback) => { + if (typeof callback === "function") { + callback(); + } + return 1 as unknown as NodeJS.Timeout; + }); + + const client = new StdioClientTransport(serverParameters); + + const mockProcess = { + kill: jest.fn(), + exitCode: null, + once: jest.fn().mockImplementation((event, handler) => { + if ( + mockProcess.kill.mock.calls.length === 2 && + (event === "exit" || event === "close") + ) { + setTimeout(() => handler(), 0); + } + return mockProcess; + }), + }; + + // @ts-expect-error accessing private property for testing + client._process = mockProcess; + + await client.close(); + + expect(mockProcess.kill).toHaveBeenNthCalledWith(1, "SIGTERM"); + expect(mockProcess.kill).toHaveBeenNthCalledWith(2, "SIGKILL"); + expect(mockProcess.kill).toHaveBeenCalledTimes(2); + + killSpy.mockRestore(); +}); + +test("should exit cleanly if SIGTERM works", async () => { + const client = new StdioClientTransport(serverParameters); + + const callbacks: Record = {}; + + const mockProcess = { + kill: jest.fn(), + exitCode: null, + once: jest.fn((event, callback) => { + callbacks[event] = callback; + return mockProcess; + }), + } as unknown as ChildProcess; + + // @ts-expect-error accessing private property for testing + client._process = mockProcess; + + // @ts-expect-error accessing private property for testing + client._abortController = { abort: jest.fn() }; + + const closePromise = client.close(); + + expect(mockProcess.kill).toHaveBeenCalledWith("SIGTERM"); + expect(mockProcess.once).toHaveBeenCalledWith("exit", expect.any(Function)); + + callbacks.exit && callbacks.exit(); + + await closePromise; + + expect(mockProcess.kill).toHaveBeenCalledWith("SIGTERM"); + expect(mockProcess.kill).toHaveBeenCalledTimes(1); +}); + test("should read messages", async () => { const client = new StdioClientTransport(serverParameters); client.onerror = (error) => { diff --git a/src/client/stdio.ts b/src/client/stdio.ts index 9e35293d..36ca001e 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -45,20 +45,20 @@ export type StdioServerParameters = { export const DEFAULT_INHERITED_ENV_VARS = process.platform === "win32" ? [ - "APPDATA", - "HOMEDRIVE", - "HOMEPATH", - "LOCALAPPDATA", - "PATH", - "PROCESSOR_ARCHITECTURE", - "SYSTEMDRIVE", - "SYSTEMROOT", - "TEMP", - "USERNAME", - "USERPROFILE", - ] + "APPDATA", + "HOMEDRIVE", + "HOMEPATH", + "LOCALAPPDATA", + "PATH", + "PROCESSOR_ARCHITECTURE", + "SYSTEMDRIVE", + "SYSTEMROOT", + "TEMP", + "USERNAME", + "USERPROFILE", + ] : /* list inspired by the default env inheritance of sudo */ - ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]; + ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]; /** * Returns a default environment object including only environment variables deemed safe to inherit. @@ -112,7 +112,7 @@ export class StdioClientTransport implements Transport { async start(): Promise { if (this._process) { throw new Error( - "StdioClientTransport already started! If using Client class, note that connect() calls start() automatically." + "StdioClientTransport already started! If using Client class, note that connect() calls start() automatically.", ); } @@ -127,7 +127,7 @@ export class StdioClientTransport implements Transport { signal: this._abortController.signal, windowsHide: process.platform === "win32" && isElectron(), cwd: this._serverParams.cwd, - } + }, ); this._process.on("error", (error) => { @@ -201,8 +201,35 @@ export class StdioClientTransport implements Transport { async close(): Promise { this._abortController.abort(); + + const taskProcess = this._process; this._process = undefined; this._readBuffer.clear(); + + if (!taskProcess || taskProcess.exitCode !== null) { + return; + } + + taskProcess.kill("SIGTERM"); + + const exited = await Promise.race([ + new Promise((resolve) => { + const onExit = () => resolve(true); + taskProcess.once("exit", onExit); + taskProcess.once("close", onExit); + }), + new Promise((resolve) => { + setTimeout(() => resolve(false), 1000); + }), + ]); + + if (!exited) { + taskProcess.kill("SIGKILL"); + await new Promise((resolve) => { + taskProcess.once("exit", resolve); + taskProcess.once("close", resolve); + }); + } } send(message: JSONRPCMessage): Promise {