Skip to content

Commit 3e51f70

Browse files
authored
Merge pull request #1154 from ahoppen/ahoppen/dont-crash-if-message-has-missing-field
Don’t crash sourcekit-lsp if a known message is missing a field
2 parents d5e3dbd + be42621 commit 3e51f70

File tree

5 files changed

+189
-68
lines changed

5 files changed

+189
-68
lines changed

Sources/LSPTestSupport/TestJSONRPCConnection.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public final class TestServer: MessageHandler {
197197

198198
private let testMessageRegistry = MessageRegistry(
199199
requests: [EchoRequest.self, EchoError.self],
200-
notifications: [EchoNotification.self]
200+
notifications: [EchoNotification.self, ShowMessageNotification.self]
201201
)
202202

203203
#if compiler(<5.11)

Sources/LanguageServerProtocolJSONRPC/JSONRPCConnection.swift

Lines changed: 140 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public final class JSONRPCConnection: Connection {
5656
/// - `init`: Reference to `JSONRPCConnection` trivially can't have escaped to other isolation domains yet.
5757
/// - `start`: Is required to be called in the same serial region as the initializer, so
5858
/// `JSONRPCConnection` can't have escaped to other isolation domains yet.
59-
/// - `_close`: Synchronized on `queue`.
59+
/// - `closeAssumingOnQueue`: Synchronized on `queue`.
6060
/// - `readyToSend`: Synchronized on `queue`.
6161
/// - `deinit`: Can also only trivially be called once.
6262
private nonisolated(unsafe) var state: State
@@ -230,6 +230,131 @@ public final class JSONRPCConnection: Connection {
230230
}
231231
}
232232

233+
/// Send a notification to the client that informs the user about a message decoding error and tells them to file an
234+
/// issue.
235+
///
236+
/// `message` describes what has gone wrong to the user.
237+
///
238+
/// - Important: Must be called on `queue`
239+
private func sendMessageDecodingErrorNotificationToClient(message: String) {
240+
dispatchPrecondition(condition: .onQueue(queue))
241+
let showMessage = ShowMessageNotification(
242+
type: .error,
243+
message: """
244+
\(message). Please run 'sourcekit-lsp diagnose' to file an issue.
245+
"""
246+
)
247+
self.send(.notification(showMessage))
248+
}
249+
250+
/// Decode a single JSONRPC message from the given `messageBytes`.
251+
///
252+
/// `messageBytes` should be valid JSON, ie. this is the message sent from the client without the `Content-Length`
253+
/// header.
254+
///
255+
/// If an error occurs during message parsing, this tries to recover as gracefully as possible and returns `nil`.
256+
/// Callers should consider the message handled and ignore it when this function returns `nil`.
257+
///
258+
/// - Important: Must be called on `queue`
259+
private func decodeJSONRPCMessage(messageBytes: Slice<UnsafeBufferPointer<UInt8>>) -> JSONRPCMessage? {
260+
dispatchPrecondition(condition: .onQueue(queue))
261+
let decoder = JSONDecoder()
262+
263+
// Set message registry to use for model decoding.
264+
decoder.userInfo[.messageRegistryKey] = messageRegistry
265+
266+
// Setup callback for response type.
267+
decoder.userInfo[.responseTypeCallbackKey] = { (id: RequestID) -> ResponseType.Type? in
268+
guard let outstanding = self.outstandingRequests[id] else {
269+
logger.error("Unknown request for \(id, privacy: .public)")
270+
return nil
271+
}
272+
return outstanding.responseType
273+
}
274+
275+
do {
276+
let pointer = UnsafeMutableRawPointer(mutating: UnsafeBufferPointer(rebasing: messageBytes).baseAddress!)
277+
return try decoder.decode(
278+
JSONRPCMessage.self,
279+
from: Data(bytesNoCopy: pointer, count: messageBytes.count, deallocator: .none)
280+
)
281+
} catch let error as MessageDecodingError {
282+
logger.fault("Failed to decode message: \(error.forLogging)")
283+
logger.fault("Malformed message: \(String(bytes: messageBytes, encoding: .utf8) ?? "<invalid UTF-8>")")
284+
285+
// We failed to decode the message. Under those circumstances try to behave as LSP-conforming as possible.
286+
// Always log at the fault level so that we know something is going wrong from the logs.
287+
//
288+
// The pattern below is to handle the message in the best possible way and then `return nil` to acknowledge the
289+
// handling. That way the compiler enforces that we handle all code paths.
290+
switch error.messageKind {
291+
case .request:
292+
if let id = error.id {
293+
// If we know it was a request and we have the request ID, simply reply to the request and tell the client
294+
// that we couldn't parse it. That complies with LSP that all requests should eventually get a response.
295+
logger.fault(
296+
"Replying to request \(id, privacy: .public) with error response because we failed to decode the request"
297+
)
298+
self.send(.errorResponse(ResponseError(error), id: id))
299+
return nil
300+
}
301+
// If we don't know the ID of the request, ignore it and show a notification to the user.
302+
// That way the user at least knows that something is going wrong even if the client never gets a response
303+
// for the request.
304+
logger.fault("Ignoring request because we failed to decode the request and don't have a request ID")
305+
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a request")
306+
return nil
307+
case .response:
308+
if let id = error.id {
309+
if let outstanding = self.outstandingRequests.removeValue(forKey: id) {
310+
// If we received a response to a request we sent to the client, assume that the client responded with an
311+
// error. That complies with LSP that all requests should eventually get a response.
312+
logger.fault(
313+
"Assuming an error response to request \(id, privacy: .public) because response from client could not be decoded"
314+
)
315+
outstanding.replyHandler(.failure(ResponseError(error)))
316+
return nil
317+
}
318+
// If there's an error in the response but we don't even know about the request, we can ignore it.
319+
logger.fault(
320+
"Ignoring response to request \(id, privacy: .public) because it could not be decoded and given request ID is unknown"
321+
)
322+
return nil
323+
}
324+
// And if we can't even recover the ID the response is for, we drop it. This means that whichever code in
325+
// sourcekit-lsp sent the request will probably never get a reply but there's nothing we can do about that.
326+
// Ideally requests sent from sourcekit-lsp to the client would have some kind of timeout anyway.
327+
logger.fault("Ignoring response because its request ID could not be recovered")
328+
return nil
329+
case .notification:
330+
if error.code == .methodNotFound {
331+
// If we receive a notification we don't know about, this might be a client sending a new LSP notification
332+
// that we don't know about. It can't be very critical so we ignore it without bothering the user with an
333+
// error notification.
334+
logger.fault("Ignoring notification because we don't know about it's method")
335+
return nil
336+
}
337+
// Ignoring any other notification might result in corrupted behavior. For example, ignoring a
338+
// `textDocument/didChange` will result in an out-of-sync state between the editor and sourcekit-lsp.
339+
// Warn the user about the error.
340+
logger.fault("Ignoring notification that may cause corrupted behavior")
341+
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a notification")
342+
return nil
343+
case .unknown:
344+
// We don't know what has gone wrong. This could be any level of badness. Inform the user about it.
345+
logger.fault("Ignoring unknown message")
346+
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode a message")
347+
return nil
348+
}
349+
} catch {
350+
// We don't know what has gone wrong. This could be any level of badness. Inform the user about it and ignore the
351+
// message.
352+
logger.fault("Ignoring unknown message")
353+
sendMessageDecodingErrorNotificationToClient(message: "sourcekit-lsp failed to decode an unknown message")
354+
return nil
355+
}
356+
}
357+
233358
/// Whether we can send messages in the current state.
234359
///
235360
/// - parameter shouldLog: Whether to log an info message if not ready.
@@ -250,69 +375,30 @@ public final class JSONRPCConnection: Connection {
250375
/// - Important: Must be called on `queue`
251376
func parseAndHandleMessages(from bytes: UnsafeBufferPointer<UInt8>) -> UnsafeBufferPointer<UInt8>.SubSequence {
252377
dispatchPrecondition(condition: .onQueue(queue))
253-
let decoder = JSONDecoder()
254-
255-
// Set message registry to use for model decoding.
256-
decoder.userInfo[.messageRegistryKey] = messageRegistry
257-
258-
// Setup callback for response type.
259-
decoder.userInfo[.responseTypeCallbackKey] =
260-
{ id in
261-
guard let outstanding = self.outstandingRequests[id] else {
262-
logger.error("Unknown request for \(id, privacy: .public)")
263-
return nil
264-
}
265-
return outstanding.responseType
266-
} as JSONRPCMessage.ResponseTypeCallback
267378

268379
var bytes = bytes[...]
269380

270381
MESSAGE_LOOP: while true {
382+
// Split the messages based on the Content-Length header.
383+
let messageBytes: Slice<UnsafeBufferPointer<UInt8>>
271384
do {
272-
guard let ((messageBytes, _), rest) = try bytes.jsonrpcSplitMessage() else {
385+
guard let (header: _, message: message, rest: rest) = try bytes.jsonrpcSplitMessage() else {
273386
return bytes
274387
}
388+
messageBytes = message
275389
bytes = rest
276-
277-
let pointer = UnsafeMutableRawPointer(mutating: UnsafeBufferPointer(rebasing: messageBytes).baseAddress!)
278-
let message = try decoder.decode(
279-
JSONRPCMessage.self,
280-
from: Data(bytesNoCopy: pointer, count: messageBytes.count, deallocator: .none)
281-
)
282-
283-
handle(message)
284-
} catch let error as MessageDecodingError {
285-
switch error.messageKind {
286-
case .request:
287-
if let id = error.id {
288-
queue.async {
289-
self.send(.errorResponse(ResponseError(error), id: id))
290-
}
291-
continue MESSAGE_LOOP
292-
}
293-
case .response:
294-
if let id = error.id {
295-
if let outstanding = self.outstandingRequests.removeValue(forKey: id) {
296-
outstanding.replyHandler(.failure(ResponseError(error)))
297-
} else {
298-
logger.error("error in response to unknown request \(id, privacy: .public) \(error.forLogging)")
299-
}
300-
continue MESSAGE_LOOP
301-
}
302-
case .notification:
303-
if error.code == .methodNotFound {
304-
logger.error("ignoring unknown notification \(error.forLogging)")
305-
continue MESSAGE_LOOP
306-
}
307-
case .unknown:
308-
break
309-
}
310-
// FIXME: graceful shutdown?
311-
fatalError("fatal error encountered decoding message \(error)")
312390
} catch {
313-
// FIXME: graceful shutdown?
314-
fatalError("fatal error encountered decoding message \(error)")
391+
// We failed to parse the message header. There isn't really much we can do to recover because we lost our
392+
// anchor in the stream where new messages start. Crashing and letting ourselves be restarted by the client is
393+
// probably the best option.
394+
sendMessageDecodingErrorNotificationToClient(message: "Failed to find next message in connection to editor")
395+
fatalError("fatal error encountered while splitting JSON RPC messages \(error)")
396+
}
397+
398+
guard let message = decodeJSONRPCMessage(messageBytes: messageBytes) else {
399+
continue
315400
}
401+
handle(message)
316402
}
317403
}
318404

Sources/LanguageServerProtocolJSONRPC/MessageSplitting.swift

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import LanguageServerProtocol
1515
public struct JSONRPCMessageHeader: Hashable {
1616
static let contentLengthKey: [UInt8] = [UInt8]("Content-Length".utf8)
1717
static let separator: [UInt8] = [UInt8]("\r\n".utf8)
18-
static let colon: UInt8 = ":".utf8.first!
18+
static let colon: UInt8 = UInt8(ascii: ":")
1919
static let invalidKeyBytes: [UInt8] = [colon] + separator
2020

2121
public var contentLength: Int? = nil
@@ -25,21 +25,29 @@ public struct JSONRPCMessageHeader: Hashable {
2525
}
2626
}
2727

28-
extension RandomAccessCollection where Element == UInt8 {
29-
30-
/// Returns the first message range and header in `self`, or nil.
31-
public func jsonrpcSplitMessage()
32-
throws -> ((SubSequence, header: JSONRPCMessageHeader), SubSequence)?
33-
{
28+
extension RandomAccessCollection<UInt8> {
29+
/// Tries to parse a single message from this collection of bytes.
30+
///
31+
/// If an entire message could be found, returns
32+
/// - header (representing `Content-Length:<length>\r\n\r\n`)
33+
/// - message: The data that represents the actual message as JSON
34+
/// - rest: The remaining bytes that haven't weren't part of the first message in this collection
35+
///
36+
/// If a `Content-Length` header could be found but the collection doesn't have enough bytes for the entire message
37+
/// (eg. because the `Content-Length` header has been transmitted yet but not the entire message), returns `nil`.
38+
/// Callers should call this method again once more data is available.
39+
@_spi(Testing)
40+
public func jsonrpcSplitMessage() throws -> (header: JSONRPCMessageHeader, message: SubSequence, rest: SubSequence)? {
3441
guard let (header, rest) = try jsonrcpParseHeader() else { return nil }
3542
guard let contentLength = header.contentLength else {
3643
throw MessageDecodingError.parseError("missing Content-Length header")
3744
}
3845
if contentLength > rest.count { return nil }
39-
return ((rest.prefix(contentLength), header: header), rest.dropFirst(contentLength))
46+
return (header: header, message: rest.prefix(contentLength), rest: rest.dropFirst(contentLength))
4047
}
4148

42-
public func jsonrcpParseHeader() throws -> (JSONRPCMessageHeader, SubSequence)? {
49+
@_spi(Testing)
50+
public func jsonrcpParseHeader() throws -> (header: JSONRPCMessageHeader, rest: SubSequence)? {
4351
var header = JSONRPCMessageHeader()
4452
var slice = self[...]
4553
while let (kv, rest) = try slice.jsonrpcParseHeaderField() {
@@ -62,6 +70,7 @@ extension RandomAccessCollection where Element == UInt8 {
6270
return nil
6371
}
6472

73+
@_spi(Testing)
6574
public func jsonrpcParseHeaderField() throws -> ((key: SubSequence, value: SubSequence)?, SubSequence)? {
6675
if starts(with: JSONRPCMessageHeader.separator) {
6776
return (nil, dropFirst(JSONRPCMessageHeader.separator.count))
@@ -85,11 +94,9 @@ extension RandomAccessCollection where Element == UInt8 {
8594
}
8695

8796
extension RandomAccessCollection where Element: Equatable {
88-
8997
/// Returns the first index where the specified subsequence appears or nil.
9098
@inlinable
9199
public func firstIndex(of pattern: some RandomAccessCollection<Element>) -> Index? {
92-
93100
if pattern.isEmpty {
94101
return startIndex
95102
}

Tests/LanguageServerProtocolJSONRPCTests/ConnectionTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,32 @@ class ConnectionTests: XCTestCase {
279279
}
280280
}
281281
}
282+
283+
func testMessageWithMissingParameter() async throws {
284+
let expectation = self.expectation(description: "Received ShowMessageNotification")
285+
await connection.client.appendOneShotNotificationHandler { (note: ShowMessageNotification) in
286+
XCTAssertEqual(note.type, .error)
287+
expectation.fulfill()
288+
}
289+
290+
let messageContents = """
291+
{
292+
"method": "test_server/echo_note",
293+
"jsonrpc": "2.0",
294+
"params": {}
295+
}
296+
"""
297+
connection.clientToServerConnection.send(message: messageContents)
298+
299+
try await self.fulfillmentOfOrThrow([expectation])
300+
}
301+
}
302+
303+
fileprivate extension JSONRPCConnection {
304+
func send(message: String) {
305+
let messageWithHeader = "Content-Length: \(message.utf8.count)\r\n\r\n\(message)".data(using: .utf8)!
306+
messageWithHeader.withUnsafeBytes { bytes in
307+
send(_rawData: DispatchData(bytes: bytes))
308+
}
309+
}
282310
}

Tests/LanguageServerProtocolJSONRPCTests/MessageParsingTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
import LanguageServerProtocol
14-
import LanguageServerProtocolJSONRPC
14+
@_spi(Testing) import LanguageServerProtocolJSONRPC
1515
import XCTest
1616

1717
final class MessageParsingTests: XCTestCase {
@@ -25,7 +25,7 @@ final class MessageParsingTests: XCTestCase {
2525
line: UInt = #line
2626
) throws {
2727
let bytes: [UInt8] = [UInt8](string.utf8)
28-
guard let ((content, header), rest) = try bytes.jsonrpcSplitMessage() else {
28+
guard let (header, content, rest) = try bytes.jsonrpcSplitMessage() else {
2929
XCTAssert(restLen == nil, "expected non-empty field", file: file, line: line)
3030
return
3131
}

0 commit comments

Comments
 (0)