diff --git a/.idea/misc.xml b/.idea/misc.xml
index 6f735e5..195deed 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -14,4 +14,4 @@
-
+
\ No newline at end of file
diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt
index 7b38bb4..1bd7372 100644
--- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt
+++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt
@@ -10,6 +10,7 @@ import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
+import kotlin.coroutines.cancellation.CancellationException
/**
* Options for configuring the MCP client.
@@ -100,7 +101,12 @@ public open class Client(
notification(InitializedNotification())
} catch (error: Throwable) {
close()
+ if (error !is CancellationException) {
+ throw IllegalStateException("Error connecting to transport: ${error.message}")
+ }
+
throw error
+
}
}
diff --git a/src/jvmTest/kotlin/client/ClientTest.kt b/src/jvmTest/kotlin/client/ClientTest.kt
index ef60fb9..4630a34 100644
--- a/src/jvmTest/kotlin/client/ClientTest.kt
+++ b/src/jvmTest/kotlin/client/ClientTest.kt
@@ -6,6 +6,8 @@ import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult
import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject
import io.modelcontextprotocol.kotlin.sdk.Implementation
import InMemoryTransport
+import io.mockk.coEvery
+import io.mockk.spyk
import io.modelcontextprotocol.kotlin.sdk.InitializeRequest
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
@@ -49,13 +51,13 @@ import kotlin.test.fail
class ClientTest {
@Test
fun `should initialize with matching protocol version`() = runTest {
- var initialied = false
+ var initialised = false
val clientTransport = object : AbstractTransport() {
override suspend fun start() {}
override suspend fun send(message: JSONRPCMessage) {
if (message !is JSONRPCRequest) return
- initialied = true
+ initialised = true
val result = InitializeResult(
protocolVersion = LATEST_PROTOCOL_VERSION,
capabilities = ServerCapabilities(),
@@ -90,7 +92,7 @@ class ClientTest {
)
client.connect(clientTransport)
- assertTrue(initialied)
+ assertTrue(initialised)
}
@Test
@@ -189,6 +191,61 @@ class ClientTest {
assertTrue(closed)
}
+ @Test
+ fun `should reject due to non cancellation exception`() = runTest {
+ var closed = false
+ val clientTransport = object : AbstractTransport() {
+ override suspend fun start() {}
+
+ override suspend fun send(message: JSONRPCMessage) {
+ if (message !is JSONRPCRequest) return
+ check(message.method == Method.Defined.Initialize.value)
+
+ val result = InitializeResult(
+ protocolVersion = LATEST_PROTOCOL_VERSION,
+ capabilities = ServerCapabilities(),
+ serverInfo = Implementation(
+ name = "test",
+ version = "1.0"
+ )
+ )
+
+ val response = JSONRPCResponse(
+ id = message.id,
+ result = result
+ )
+
+ _onMessage.invoke(response)
+ }
+
+ override suspend fun close() {
+ closed = true
+ }
+ }
+
+ val mockClient = spyk(
+ Client(
+ clientInfo = Implementation(
+ name = "test client",
+ version = "1.0"
+ ),
+ options = ClientOptions()
+ )
+ )
+
+ coEvery{
+ mockClient.request(any())
+ } throws IllegalStateException("Test error")
+
+ val exception = assertFailsWith {
+ mockClient.connect(clientTransport)
+ }
+
+ assertEquals("Error connecting to transport: Test error", exception.message)
+
+ assertTrue(closed)
+ }
+
@Test
fun `should respect server capabilities`() = runTest {
val serverOptions = ServerOptions(