Skip to content

Fix test for misbehaving server #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ class HTTPClientSOCKSTests: XCTestCase {
XCTAssertNoThrow(try socksBin.shutdown())
}

// the server will send a bogus message in response to the clients request
XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait())
// the server will send a bogus message in response to the clients greeting
// this will be first picked up as an invalid protocol
XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) { e in
XCTAssertTrue(e is SOCKSError.InvalidProtocolVersion)
}
}
}
49 changes: 30 additions & 19 deletions Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,40 @@ struct MockSOCKSError: Error, Hashable {
var description: String
}

class TestSOCKSBadServerHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
// just write some nonsense bytes
let buffer = context.channel.allocator.buffer(bytes: [0xAA, 0xBB, 0xCC, 0xDD, 0xEE])
context.writeAndFlush(.init(buffer), promise: nil)
}
}

class MockSOCKSServer {
let channel: Channel

init(expectedURL: String, expectedResponse: String, misbehave: Bool = false, file: String = #file, line: UInt = #line) throws {
let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let bootstrap = ServerBootstrap(group: elg)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.childChannelInitializer { channel in
let handshakeHandler = SOCKSServerHandshakeHandler()
return channel.pipeline.addHandlers([
handshakeHandler,
SOCKSTestHandler(handshakeHandler: handshakeHandler, misbehave: misbehave),
TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line),
])
}
let bootstrap: ServerBootstrap
if misbehave {
bootstrap = ServerBootstrap(group: elg)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.childChannelInitializer { channel in
channel.pipeline.addHandler(TestSOCKSBadServerHandler())
}
} else {
bootstrap = ServerBootstrap(group: elg)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.childChannelInitializer { channel in
let handshakeHandler = SOCKSServerHandshakeHandler()
return channel.pipeline.addHandlers([
handshakeHandler,
SOCKSTestHandler(handshakeHandler: handshakeHandler),
TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line),
])
}
}
self.channel = try bootstrap.bind(host: "localhost", port: 1080).wait()
}

Expand All @@ -49,11 +68,9 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ClientMessage

let handshakeHandler: SOCKSServerHandshakeHandler
let misbehave: Bool

init(handshakeHandler: SOCKSServerHandshakeHandler, misbehave: Bool) {
init(handshakeHandler: SOCKSServerHandshakeHandler) {
self.handshakeHandler = handshakeHandler
self.misbehave = misbehave
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand All @@ -69,12 +86,6 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler {
case .authenticationData:
context.fireErrorCaught(MockSOCKSError(description: "Received authentication data but didn't receive any."))
case .request(let request):
guard !self.misbehave else {
context.writeAndFlush(
.init(ServerMessage.authenticationData(context.channel.allocator.buffer(string: "bad server!"), complete: true)), promise: nil
)
return
}
context.writeAndFlush(.init(
ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType))), promise: nil)
context.channel.pipeline.addHandlers([
Expand Down