diff --git a/Sources/GRPC/GRPCServerPipelineConfigurator.swift b/Sources/GRPC/GRPCServerPipelineConfigurator.swift index 5854faea9..61155426b 100644 --- a/Sources/GRPC/GRPCServerPipelineConfigurator.swift +++ b/Sources/GRPC/GRPCServerPipelineConfigurator.swift @@ -33,8 +33,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan /// The server configuration. private let configuration: Server.Configuration - /// Reads which we're holding on to before the pipeline is configured. - private var bufferedReads = CircularBuffer() + /// A buffer containing the buffered bytes. + private var buffer: ByteBuffer? /// The current state. private var state: State @@ -212,13 +212,17 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan buffer: ByteBuffer, context: ChannelHandlerContext ) { - if HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer) { + switch HTTPVersionParser.determineHTTPVersion(buffer) { + case .http2: self.configureHTTP2(context: context) - } else if HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer) { + case .http1: self.configureHTTP1(context: context) - } else { + case .unknown: + // Neither H2 nor H1 or the length limit has been exceeded. self.configuration.logger.error("Unable to determine http version, closing") context.close(mode: .all, promise: nil) + case .notEnoughBytes: + () // Try again with more bytes. } } @@ -268,13 +272,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan /// Try to parse the buffered data to determine whether or not HTTP/2 or HTTP/1 should be used. private func tryParsingBufferedData(context: ChannelHandlerContext) { - guard let first = self.bufferedReads.first else { - // No data buffered yet. We'll try when we read. - return + if let buffer = self.buffer { + self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context) } - - let buffer = self.unwrapInboundIn(first) - self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context) } // MARK: - Channel Handler @@ -312,7 +312,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan } internal func channelRead(context: ChannelHandlerContext, data: NIOAny) { - self.bufferedReads.append(data) + var buffer = self.unwrapInboundIn(data) + self.buffer.setOrWriteBuffer(&buffer) switch self.state { case .notConfigured(alpn: .notExpected), @@ -335,8 +336,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan removalToken: ChannelHandlerContext.RemovalToken ) { // Forward any buffered reads. - while let read = self.bufferedReads.popFirst() { - context.fireChannelRead(read) + if let buffer = self.buffer { + self.buffer = nil + context.fireChannelRead(self.wrapInboundOut(buffer)) } context.leavePipeline(removalToken: removalToken) } @@ -375,16 +377,64 @@ struct HTTPVersionParser { /// Determines whether the bytes in the `ByteBuffer` are prefixed with the HTTP/2 client /// connection preface. - static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> Bool { + static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> SubParseResult { let view = buffer.readableBytesView guard view.count >= HTTPVersionParser.http2ClientMagic.count else { // Not enough bytes. - return false + return .notEnoughBytes } let slice = view[view.startIndex ..< view.startIndex.advanced(by: self.http2ClientMagic.count)] - return slice.elementsEqual(HTTPVersionParser.http2ClientMagic) + return slice.elementsEqual(HTTPVersionParser.http2ClientMagic) ? .accepted : .rejected + } + + enum ParseResult: Hashable { + case http1 + case http2 + case unknown + case notEnoughBytes + } + + enum SubParseResult: Hashable { + case accepted + case rejected + case notEnoughBytes + } + + private static let maxLengthToCheck = 1024 + + static func determineHTTPVersion(_ buffer: ByteBuffer) -> ParseResult { + switch Self.prefixedWithHTTP2ConnectionPreface(buffer) { + case .accepted: + return .http2 + + case .notEnoughBytes: + switch Self.prefixedWithHTTP1RequestLine(buffer) { + case .accepted: + // Not enough bytes to check H2, but enough to confirm H1. + return .http1 + case .notEnoughBytes: + // Not enough bytes to check H2 or H1. + return .notEnoughBytes + case .rejected: + // Not enough bytes to check H2 and definitely not H1. + return .notEnoughBytes + } + + case .rejected: + switch Self.prefixedWithHTTP1RequestLine(buffer) { + case .accepted: + // Not H2, but H1 is confirmed. + return .http1 + case .notEnoughBytes: + // Not H2, but not enough bytes to reject H1 yet. + return .notEnoughBytes + case .rejected: + // Not H2 or H1. + return .unknown + } + } } private static let http1_1 = [ @@ -399,29 +449,59 @@ struct HTTPVersionParser { ] /// Determines whether the bytes in the `ByteBuffer` are prefixed with an HTTP/1.1 request line. - static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> Bool { + static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> SubParseResult { var readableBytesView = buffer.readableBytesView + // We don't need to validate the request line, only determine whether we think it's an HTTP1 + // request line. Another handler will parse it properly. + // From RFC 2616 ยง 5.1: // Request-Line = Method SP Request-URI SP HTTP-Version CRLF - // Read off the Method and Request-URI (and spaces). - guard readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil, - readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil else { - return false + // Get through the first space. + guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else { + let tooLong = buffer.readableBytes > Self.maxLengthToCheck + return tooLong ? .rejected : .notEnoughBytes + } + + // Get through the second space. + guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else { + let tooLong = buffer.readableBytes > Self.maxLengthToCheck + return tooLong ? .rejected : .notEnoughBytes + } + + // +2 for \r\n + guard readableBytesView.count >= (Self.http1_1.count + 2) else { + return .notEnoughBytes } - // Read off the HTTP-Version and CR. - guard let versionView = readableBytesView.trimPrefix(to: UInt8(ascii: "\r")) else { - return false + guard let version = readableBytesView.dropPrefix(through: UInt8(ascii: "\r")), + readableBytesView.first == UInt8(ascii: "\n") else { + // If we didn't drop the prefix OR we did and the next byte wasn't '\n', then we had enough + // bytes but the '\r\n' wasn't present: reject this as being HTTP1. + return .rejected + } + + return version.elementsEqual(Self.http1_1) ? .accepted : .rejected + } +} + +extension Collection where Self == Self.SubSequence, Self.Element: Equatable { + /// Drops the prefix off the collection up to and including the first `separator` + /// only if that separator appears in the collection. + /// + /// Returns the prefix up to but not including the separator if it was found, nil otherwise. + mutating func dropPrefix(through separator: Element) -> SubSequence? { + if self.isEmpty { + return nil } - // Check that the LF followed the CR. - guard readableBytesView.first == UInt8(ascii: "\n") else { - return false + guard let separatorIndex = self.firstIndex(of: separator) else { + return nil } - // Now check the HTTP version. - return versionView.elementsEqual(HTTPVersionParser.http1_1) + let prefix = self[..