diff --git a/.idea/misc.xml b/.idea/misc.xml index 6f735e5..195deed 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -14,4 +14,4 @@ - + \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 7b38bb4..1bd7372 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -10,6 +10,7 @@ import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive +import kotlin.coroutines.cancellation.CancellationException /** * Options for configuring the MCP client. @@ -100,7 +101,12 @@ public open class Client( notification(InitializedNotification()) } catch (error: Throwable) { close() + if (error !is CancellationException) { + throw IllegalStateException("Error connecting to transport: ${error.message}") + } + throw error + } } diff --git a/src/jvmTest/kotlin/client/ClientTest.kt b/src/jvmTest/kotlin/client/ClientTest.kt index ef60fb9..4630a34 100644 --- a/src/jvmTest/kotlin/client/ClientTest.kt +++ b/src/jvmTest/kotlin/client/ClientTest.kt @@ -6,6 +6,8 @@ import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject import io.modelcontextprotocol.kotlin.sdk.Implementation import InMemoryTransport +import io.mockk.coEvery +import io.mockk.spyk import io.modelcontextprotocol.kotlin.sdk.InitializeRequest import io.modelcontextprotocol.kotlin.sdk.InitializeResult import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage @@ -49,13 +51,13 @@ import kotlin.test.fail class ClientTest { @Test fun `should initialize with matching protocol version`() = runTest { - var initialied = false + var initialised = false val clientTransport = object : AbstractTransport() { override suspend fun start() {} override suspend fun send(message: JSONRPCMessage) { if (message !is JSONRPCRequest) return - initialied = true + initialised = true val result = InitializeResult( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities(), @@ -90,7 +92,7 @@ class ClientTest { ) client.connect(clientTransport) - assertTrue(initialied) + assertTrue(initialised) } @Test @@ -189,6 +191,61 @@ class ClientTest { assertTrue(closed) } + @Test + fun `should reject due to non cancellation exception`() = runTest { + var closed = false + val clientTransport = object : AbstractTransport() { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + + val result = InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0" + ) + ) + + val response = JSONRPCResponse( + id = message.id, + result = result + ) + + _onMessage.invoke(response) + } + + override suspend fun close() { + closed = true + } + } + + val mockClient = spyk( + Client( + clientInfo = Implementation( + name = "test client", + version = "1.0" + ), + options = ClientOptions() + ) + ) + + coEvery{ + mockClient.request(any()) + } throws IllegalStateException("Test error") + + val exception = assertFailsWith { + mockClient.connect(clientTransport) + } + + assertEquals("Error connecting to transport: Test error", exception.message) + + assertTrue(closed) + } + @Test fun `should respect server capabilities`() = runTest { val serverOptions = ServerOptions(