diff --git a/Sources/OpenAPIRuntime/Interface/ErrorHandlingMiddleware.swift b/Sources/OpenAPIRuntime/Interface/ErrorHandlingMiddleware.swift new file mode 100644 index 00000000..55113ce5 --- /dev/null +++ b/Sources/OpenAPIRuntime/Interface/ErrorHandlingMiddleware.swift @@ -0,0 +1,98 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes + +/// An opt-in error handling middleware that converts an error to an HTTP response. +/// +/// Inclusion of ``ErrorHandlingMiddleware`` should be accompanied by conforming errors to the ``HTTPResponseConvertible`` protocol. +/// Errors not conforming to ``HTTPResponseConvertible`` are converted to a response with the 500 status code. +/// +/// ## Example usage +/// +/// 1. Create an error type that conforms to the ``HTTPResponseConvertible`` protocol: +/// +/// ```swift +/// extension MyAppError: HTTPResponseConvertible { +/// var httpStatus: HTTPResponse.Status { +/// switch self { +/// case .invalidInputFormat: +/// .badRequest +/// case .authorizationError: +/// .forbidden +/// } +/// } +/// } +/// ``` +/// +/// 2. Opt into the ``ErrorHandlingMiddleware`` while registering the handler: +/// +/// ```swift +/// let handler = RequestHandler() +/// try handler.registerHandlers(on: transport, middlewares: [ErrorHandlingMiddleware()]) +/// ``` +/// - Note: The placement of ``ErrorHandlingMiddleware`` in the middleware chain is important. It should be determined based on the specific needs of each application. Consider the order of execution and dependencies between middlewares. +public struct ErrorHandlingMiddleware: ServerMiddleware { + /// Creates a new middleware. + public init() {} + // swift-format-ignore: AllPublicDeclarationsHaveDocumentation + public func intercept( + _ request: HTTPTypes.HTTPRequest, + body: OpenAPIRuntime.HTTPBody?, + metadata: OpenAPIRuntime.ServerRequestMetadata, + operationID: String, + next: @Sendable (HTTPTypes.HTTPRequest, OpenAPIRuntime.HTTPBody?, OpenAPIRuntime.ServerRequestMetadata) + async throws -> (HTTPTypes.HTTPResponse, OpenAPIRuntime.HTTPBody?) + ) async throws -> (HTTPTypes.HTTPResponse, OpenAPIRuntime.HTTPBody?) { + do { return try await next(request, body, metadata) } catch { + if let serverError = error as? ServerError, + let appError = serverError.underlyingError as? (any HTTPResponseConvertible) + { + return ( + HTTPResponse(status: appError.httpStatus, headerFields: appError.httpHeaderFields), + appError.httpBody + ) + } else { + return (HTTPResponse(status: .internalServerError), nil) + } + } + } +} + +/// A value that can be converted to an HTTP response and body. +/// +/// Conform your error type to this protocol to convert it to an `HTTPResponse` and ``HTTPBody``. +/// +/// Used by ``ErrorHandlingMiddleware``. +public protocol HTTPResponseConvertible { + + /// An HTTP status to return in the response. + var httpStatus: HTTPResponse.Status { get } + + /// The HTTP header fields of the response. + /// This is optional as default values are provided in the extension. + var httpHeaderFields: HTTPTypes.HTTPFields { get } + + /// The body of the HTTP response. + var httpBody: OpenAPIRuntime.HTTPBody? { get } +} + +extension HTTPResponseConvertible { + + // swift-format-ignore: AllPublicDeclarationsHaveDocumentation + public var httpHeaderFields: HTTPTypes.HTTPFields { [:] } + + // swift-format-ignore: AllPublicDeclarationsHaveDocumentation + public var httpBody: OpenAPIRuntime.HTTPBody? { nil } +} diff --git a/Tests/OpenAPIRuntimeTests/Interface/Test_ErrorHandlingMiddleware.swift b/Tests/OpenAPIRuntimeTests/Interface/Test_ErrorHandlingMiddleware.swift new file mode 100644 index 00000000..91977a31 --- /dev/null +++ b/Tests/OpenAPIRuntimeTests/Interface/Test_ErrorHandlingMiddleware.swift @@ -0,0 +1,144 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes + +import XCTest +@_spi(Generated) @testable import OpenAPIRuntime + +final class Test_ErrorHandlingMiddlewareTests: XCTestCase { + static let mockRequest: HTTPRequest = .init(soar_path: "http://abc.com", method: .get) + static let mockBody: HTTPBody = HTTPBody("hello") + static let errorHandlingMiddleware = ErrorHandlingMiddleware() + + func testSuccessfulRequest() async throws { + let response = try await Test_ErrorHandlingMiddlewareTests.errorHandlingMiddleware.intercept( + Test_ErrorHandlingMiddlewareTests.mockRequest, + body: Test_ErrorHandlingMiddlewareTests.mockBody, + metadata: .init(), + operationID: "testop", + next: getNextMiddleware(failurePhase: .never) + ) + XCTAssertEqual(response.0.status, .ok) + } + + func testError_conformingToProtocol_convertedToResponse() async throws { + let (response, responseBody) = try await Test_ErrorHandlingMiddlewareTests.errorHandlingMiddleware.intercept( + Test_ErrorHandlingMiddlewareTests.mockRequest, + body: Test_ErrorHandlingMiddlewareTests.mockBody, + metadata: .init(), + operationID: "testop", + next: getNextMiddleware(failurePhase: .convertibleError) + ) + XCTAssertEqual(response.status, .badGateway) + XCTAssertEqual(response.headerFields, [.contentType: "application/json"]) + XCTAssertEqual(responseBody, testHTTPBody) + } + + func testError_conformingToProtocolWithoutAllValues_convertedToResponse() async throws { + let (response, responseBody) = try await Test_ErrorHandlingMiddlewareTests.errorHandlingMiddleware.intercept( + Test_ErrorHandlingMiddlewareTests.mockRequest, + body: Test_ErrorHandlingMiddlewareTests.mockBody, + metadata: .init(), + operationID: "testop", + next: getNextMiddleware(failurePhase: .partialConvertibleError) + ) + XCTAssertEqual(response.status, .badRequest) + XCTAssertEqual(response.headerFields, [:]) + XCTAssertEqual(responseBody, nil) + } + + func testError_notConformingToProtocol_returns500() async throws { + let (response, responseBody) = try await Test_ErrorHandlingMiddlewareTests.errorHandlingMiddleware.intercept( + Test_ErrorHandlingMiddlewareTests.mockRequest, + body: Test_ErrorHandlingMiddlewareTests.mockBody, + metadata: .init(), + operationID: "testop", + next: getNextMiddleware(failurePhase: .nonConvertibleError) + ) + XCTAssertEqual(response.status, .internalServerError) + XCTAssertEqual(response.headerFields, [:]) + XCTAssertEqual(responseBody, nil) + } + + private func getNextMiddleware(failurePhase: MockErrorMiddleware_Next.FailurePhase) -> @Sendable ( + HTTPTypes.HTTPRequest, OpenAPIRuntime.HTTPBody?, OpenAPIRuntime.ServerRequestMetadata + ) async throws -> (HTTPTypes.HTTPResponse, OpenAPIRuntime.HTTPBody?) { + let mockNext: + @Sendable (HTTPTypes.HTTPRequest, OpenAPIRuntime.HTTPBody?, OpenAPIRuntime.ServerRequestMetadata) + async throws -> (HTTPTypes.HTTPResponse, OpenAPIRuntime.HTTPBody?) = { request, body, metadata in + try await MockErrorMiddleware_Next(failurePhase: failurePhase) + .intercept( + request, + body: body, + metadata: metadata, + operationID: "testop", + next: { _, _, _ in (HTTPResponse.init(status: .ok), nil) } + ) + } + return mockNext + } +} + +struct MockErrorMiddleware_Next: ServerMiddleware { + enum FailurePhase { + case never + case convertibleError + case nonConvertibleError + case partialConvertibleError + } + var failurePhase: FailurePhase = .never + + @Sendable func intercept( + _ request: HTTPRequest, + body: HTTPBody?, + metadata: ServerRequestMetadata, + operationID: String, + next: (HTTPRequest, HTTPBody?, ServerRequestMetadata) async throws -> (HTTPResponse, HTTPBody?) + ) async throws -> (HTTPResponse, HTTPBody?) { + var error: (any Error)? + switch failurePhase { + case .never: break + case .convertibleError: error = ConvertibleError() + case .nonConvertibleError: error = NonConvertibleError() + case .partialConvertibleError: error = PartialConvertibleError() + } + if let underlyingError = error { + throw ServerError( + operationID: operationID, + request: request, + requestBody: body, + requestMetadata: metadata, + causeDescription: "", + underlyingError: underlyingError + ) + } + let (response, responseBody) = try await next(request, body, metadata) + return (response, responseBody) + } +} + +struct ConvertibleError: Error, HTTPResponseConvertible { + var httpStatus: HTTPTypes.HTTPResponse.Status = HTTPResponse.Status.badGateway + var httpHeaderFields: HTTPFields = [.contentType: "application/json"] + var httpBody: OpenAPIRuntime.HTTPBody? = testHTTPBody +} + +struct PartialConvertibleError: Error, HTTPResponseConvertible { + var httpStatus: HTTPTypes.HTTPResponse.Status = HTTPResponse.Status.badRequest +} + +struct NonConvertibleError: Error {} + +let testHTTPBody = HTTPBody(try! JSONEncoder().encode(["error", " test error"]))