Skip to content

add support for redirect limits #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 23, 2019
56 changes: 48 additions & 8 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,16 @@ public class HTTPClient {
}

private func execute<Delegate: HTTPClientResponseDelegate>(request: Request,
redirectLimit: Configuration.RedirectLimit.Next? = nil,
delegate: Delegate,
eventLoop delegateEL: EventLoop,
channelEL: EventLoop? = nil,
deadline: NIODeadline? = nil) -> Task<Delegate.Response> {
let redirectHandler: RedirectHandler<Delegate.Response>?
if self.configuration.followRedirects {
redirectHandler = RedirectHandler<Delegate.Response>(request: request) { newRequest in
if let next = redirectLimit ?? self.configuration.followRedirects.next {
redirectHandler = RedirectHandler<Delegate.Response>(limit: next, request: request) { newRequest, limit in
self.execute(request: newRequest,
redirectLimit: limit,
delegate: delegate,
eventLoop: delegateEL,
channelEL: channelEL,
Expand Down Expand Up @@ -317,31 +319,31 @@ public class HTTPClient {
/// - `305: Use Proxy`
/// - `307: Temporary Redirect`
/// - `308: Permanent Redirect`
public var followRedirects: Bool
public var followRedirects: FollowRedirects
/// Default client timeout, defaults to no timeouts.
public var timeout: Timeout
/// Upstream proxy, defaults to no proxy.
public var proxy: Proxy?
/// Ignore TLS unclean shutdown error, defaults to `false`.
public var ignoreUncleanSSLShutdown: Bool

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: FollowRedirects = .disabled, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.init(tlsConfiguration: tlsConfiguration, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
}

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: FollowRedirects = .disabled, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
self.tlsConfiguration = tlsConfiguration
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
public init(certificateVerification: CertificateVerification, followRedirects: FollowRedirects = .disabled, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.init(certificateVerification: certificateVerification, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
public init(certificateVerification: CertificateVerification, followRedirects: FollowRedirects = .disabled, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
self.followRedirects = followRedirects
self.timeout = timeout
Expand Down Expand Up @@ -423,6 +425,41 @@ extension HTTPClient.Configuration {
self.read = read
}
}

/// Specifies redirect limit.
public struct RedirectLimit {
enum Next {
case none
case loop(visited: Set<URL>)
case count(left: Int)
}

var next: Next

/// No redirect limit.
public static let none = RedirectLimit(next: .none)
/// Request execution will be stopped if redirect loop is detected.
public static let detectLoop = RedirectLimit(next: .loop(visited: Set()))
/// Specifies maximum number of redirects for a single request.
public static func count(_ value: Int) -> RedirectLimit { return RedirectLimit(next: .count(left: value)) }
}

/// Specifies redirect processing settings.
public enum FollowRedirects {
/// Redirects are not followed.
case disabled
/// Redirecets are followed with a specified limit.
case enabled(limit: RedirectLimit)

var next: RedirectLimit.Next? {
switch self {
case .disabled:
return nil
case .enabled(let limit):
return limit.next
}
}
}
}

private extension ChannelPipeline {
Expand Down Expand Up @@ -472,6 +509,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case invalidProxyResponse
case contentLengthMissing
case proxyAuthenticationRequired
case redirectLimitReached
}

private var code: Code
Expand Down Expand Up @@ -508,6 +546,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse)
/// Request does not contain `Content-Length` header.
public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing)
/// Proxy Authentication Required
/// Proxy Authentication Required.
public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired)
/// Redirect Limit reached.
public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached)
}
23 changes: 21 additions & 2 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,9 @@ extension TaskHandler: ChannelDuplexHandler {
// MARK: - RedirectHandler

internal struct RedirectHandler<ResponseType> {
let limit: HTTPClient.Configuration.RedirectLimit.Next
let request: HTTPClient.Request
let execute: (HTTPClient.Request) -> HTTPClient.Task<ResponseType>
let execute: (HTTPClient.Request, HTTPClient.Configuration.RedirectLimit.Next) -> HTTPClient.Task<ResponseType>

func redirectTarget(status: HTTPResponseStatus, headers: HTTPHeaders) -> URL? {
switch status {
Expand Down Expand Up @@ -818,6 +819,24 @@ internal struct RedirectHandler<ResponseType> {
}

func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise<ResponseType>) {
let nextLimit: HTTPClient.Configuration.RedirectLimit.Next
switch self.limit {
case .none:
nextLimit = .none
case .count(let left):
if left == 0 {
return promise.fail(HTTPClientError.redirectLimitReached)
}
nextLimit = .count(left: left - 1)
case .loop(let visited):
if visited.contains(redirectURL) {
return promise.fail(HTTPClientError.redirectLimitReached)
}
var visited = visited
visited.insert(redirectURL)
nextLimit = .loop(visited: visited)
}

let originalRequest = self.request

var convertToGet = false
Expand Down Expand Up @@ -847,7 +866,7 @@ internal struct RedirectHandler<ResponseType> {

do {
let newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
return self.execute(newRequest).futureResult.cascade(to: promise)
return self.execute(newRequest, nextLimit).futureResult.cascade(to: promise)
} catch {
return promise.fail(error)
}
Expand Down
10 changes: 10 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ internal final class HttpBinHandler: ChannelInboundHandler {
headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite1":
var headers = HTTPHeaders()
headers.add(name: "Location", value: "/redirect/infinite2")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite2":
var headers = HTTPHeaders()
headers.add(name: "Location", value: "/redirect/infinite1")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
// Since this String is taken from URL.path, the percent encoding has been removed
case "/percent encoded":
if req.method != .GET {
Expand Down
2 changes: 2 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ extension HTTPClientTests {
("testWrongContentLengthForSSLUncleanShutdown", testWrongContentLengthForSSLUncleanShutdown),
("testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown", testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown),
("testEventLoopArgument", testEventLoopArgument),
("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit),
("testCountRedirectLimit", testCountRedirectLimit),
]
}
}
34 changes: 31 additions & 3 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class HTTPClientTests: XCTestCase {
let httpBin = HTTPBin(ssl: false)
let httpsBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: .enabled(limit: .none)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand All @@ -148,7 +148,7 @@ class HTTPClientTests: XCTestCase {
func testHttpHostRedirect() throws {
let httpBin = HTTPBin(ssl: false)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: .enabled(limit: .none)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand Down Expand Up @@ -525,7 +525,7 @@ class HTTPClientTests: XCTestCase {
let httpBin = HTTPBin()
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 5)
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup),
configuration: HTTPClient.Configuration(followRedirects: true))
configuration: HTTPClient.Configuration(followRedirects: .enabled(limit: .none)))
defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully())
Expand Down Expand Up @@ -563,4 +563,32 @@ class HTTPClientTests: XCTestCase {
response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait()
XCTAssertEqual(true, response)
}

func testLoopDetectionRedirectLimit() throws {
let httpBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: .enabled(limit: .detectLoop)))
defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
}

XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.redirectLimitReached)
}
}

func testCountRedirectLimit() throws {
let httpBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: .enabled(limit: .count(5))))
defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
}

XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.redirectLimitReached)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to test case insensitive loops?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean when schema and path are capitalised differently?

}