Skip to content

add support for redirect limits #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 23, 2019
57 changes: 50 additions & 7 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,20 @@ public class HTTPClient {
channelEL: EventLoop? = nil,
deadline: NIODeadline? = nil) -> Task<Delegate.Response> {
let redirectHandler: RedirectHandler<Delegate.Response>?
if self.configuration.followRedirects {
switch self.configuration.redirectConfiguration.configuration {
case .follow(let max, let allowCycles):
var request = request
if request.redirectState == nil {
request.redirectState = .init(count: max, visited: allowCycles ? nil : Set())
}
redirectHandler = RedirectHandler<Delegate.Response>(request: request) { newRequest in
self.execute(request: newRequest,
delegate: delegate,
eventLoop: delegateEL,
channelEL: channelEL,
deadline: deadline)
}
} else {
case .disallow:
redirectHandler = nil
}

Expand Down Expand Up @@ -325,7 +330,7 @@ public class HTTPClient {
/// - `305: Use Proxy`
/// - `307: Temporary Redirect`
/// - `308: Permanent Redirect`
public var followRedirects: Bool
public var redirectConfiguration: RedirectConfiguration
/// Default client timeout, defaults to no timeouts.
public var timeout: Timeout
/// Upstream proxy, defaults to no proxy.
Expand All @@ -336,27 +341,27 @@ public class HTTPClient {
public var ignoreUncleanSSLShutdown: Bool

public init(tlsConfiguration: TLSConfiguration? = nil,
followRedirects: Bool = false,
redirectConfiguration: RedirectConfiguration? = nil,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled) {
self.tlsConfiguration = tlsConfiguration
self.followRedirects = followRedirects
self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration()
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
self.decompression = decompression
}

public init(certificateVerification: CertificateVerification,
followRedirects: Bool = false,
redirectConfiguration: RedirectConfiguration? = nil,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled) {
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
self.followRedirects = followRedirects
self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration()
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
Expand Down Expand Up @@ -439,6 +444,38 @@ extension HTTPClient.Configuration {
self.read = read
}
}

/// Specifies redirect processing settings.
public struct RedirectConfiguration {
enum Configuration {
/// Redirects are not followed.
case disallow
/// Redirects are followed with a specified limit.
case follow(max: Int, allowCycles: Bool)
}

var configuration: Configuration

init() {
self.configuration = .follow(max: 5, allowCycles: false)
}

init(configuration: Configuration) {
self.configuration = configuration
}

/// Redirects are not followed.
public static let disallow = RedirectConfiguration(configuration: .disallow)

/// Redirects are followed with a specified limit.
///
/// - parameters:
/// - max: The maximum number of allowed redirects.
/// - allowCycles: Whether cycles are allowed.
///
/// - warning: Cycle detection will keep all visited URLs in memory which means a malicious server could use this as a denial-of-service vector.
public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { return .init(configuration: .follow(max: max, allowCycles: allowCycles)) }
}
}

private extension ChannelPipeline {
Expand Down Expand Up @@ -488,6 +525,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case invalidProxyResponse
case contentLengthMissing
case proxyAuthenticationRequired
case redirectLimitReached
case redirectCycleDetected
}

private var code: Code
Expand Down Expand Up @@ -526,4 +565,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing)
/// Proxy Authentication Required.
public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired)
/// Redirect Limit reached.
public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached)
/// Redirect Cycle detected.
public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected)
}
32 changes: 31 additions & 1 deletion Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ extension HTTPClient {
/// Request body, defaults to no body.
public var body: Body?

struct RedirectState {
var count: Int
var visited: Set<URL>?
}

var redirectState: RedirectState?

/// Create HTTP request.
///
/// - parameters:
Expand Down Expand Up @@ -152,6 +159,8 @@ extension HTTPClient {
self.host = host
self.headers = headers
self.body = body

self.redirectState = nil
}

/// Whether request will be executed using secure socket.
Expand Down Expand Up @@ -813,6 +822,26 @@ internal struct RedirectHandler<ResponseType> {
}

func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise<ResponseType>) {
var nextState: HTTPClient.Request.RedirectState?
if var state = request.redirectState {
guard state.count > 0 else {
return promise.fail(HTTPClientError.redirectLimitReached)
}

state.count -= 1

if var visited = state.visited {
guard !visited.contains(redirectURL) else {
return promise.fail(HTTPClientError.redirectCycleDetected)
}

visited.insert(redirectURL)
state.visited = visited
}

nextState = state
}

let originalRequest = self.request

var convertToGet = false
Expand Down Expand Up @@ -841,7 +870,8 @@ internal struct RedirectHandler<ResponseType> {
}

do {
let newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
var newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
newRequest.redirectState = nextState
return self.execute(newRequest).futureResult.cascade(to: promise)
} catch {
return promise.fail(error)
Expand Down
10 changes: 10 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ internal final class HttpBinHandler: ChannelInboundHandler {
headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite1":
var headers = HTTPHeaders()
headers.add(name: "Location", value: "/redirect/infinite2")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite2":
var headers = HTTPHeaders()
headers.add(name: "Location", value: "/redirect/infinite1")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
// Since this String is taken from URL.path, the percent encoding has been removed
case "/percent encoded":
if req.method != .GET {
Expand Down
2 changes: 2 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ extension HTTPClientTests {
("testEventLoopArgument", testEventLoopArgument),
("testDecompression", testDecompression),
("testDecompressionLimit", testDecompressionLimit),
("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit),
("testCountRedirectLimit", testCountRedirectLimit),
]
}
}
38 changes: 35 additions & 3 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class HTTPClientTests: XCTestCase {
let httpBin = HTTPBin(ssl: false)
let httpsBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand All @@ -149,7 +149,7 @@ class HTTPClientTests: XCTestCase {
func testHttpHostRedirect() throws {
let httpBin = HTTPBin(ssl: false)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand Down Expand Up @@ -526,7 +526,7 @@ class HTTPClientTests: XCTestCase {
let httpBin = HTTPBin()
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 5)
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup),
configuration: HTTPClient.Configuration(followRedirects: true))
configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true)))
defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully())
Expand Down Expand Up @@ -568,6 +568,7 @@ class HTTPClientTests: XCTestCase {
func testDecompression() throws {
let httpBin = HTTPBin(compress: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .none)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
Expand Down Expand Up @@ -603,6 +604,7 @@ class HTTPClientTests: XCTestCase {
func testDecompressionLimit() throws {
let httpBin = HTTPBin(compress: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .ratio(10))))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
Expand All @@ -626,4 +628,34 @@ class HTTPClientTests: XCTestCase {
XCTFail("Unexptected error: \(error)")
}
}

func testLoopDetectionRedirectLimit() throws {
let httpBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
}

XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.redirectCycleDetected)
}
}

func testCountRedirectLimit() throws {
let httpBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: true)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
}

XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.redirectLimitReached)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to test case insensitive loops?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean when schema and path are capitalised differently?

}