
Commit 02821fe

[test] Update mock server for Swift 6 compliance (#463)
Rewrite of the MockServer (used only for performance testing and part of a separate target) to comply with Swift 6 language mode and concurrency. The new MockServer now uses `NIOAsyncChannel` and structured concurrency. Instead of adding support for the [MAX_REQUEST](https://github.com/swift-server/swift-aws-lambda-runtime/blob/11756b4e00ca75894826b41666bdae506b6eb496/Sources/AWSLambdaRuntimeCore/LambdaConfiguration.swift#L53) environment variable like v1 did, we implemented support for a `MAX_INVOCATIONS` environment variable in the MockServer itself: it closes the connection and shuts down the server after servicing MAX_INVOCATIONS Lambda requests. This places the MAX_REQUEST penalty on the MockServer rather than on the LambdaRuntimeClient. However, the LambdaRuntimeClient currently does not shut down when the MockServer ends. I created #465 to track this issue. See #377
1 parent ed84609 commit 02821fe
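
The MockServer is configured entirely through environment variables (HOST, PORT, MODE, LOG_LEVEL and the new MAX_INVOCATIONS). Below is a minimal sketch of how a performance-test harness might launch it with those variables; the build path and values are illustrative assumptions, not part of this commit.

```swift
import Foundation

// Launch the MockServer with the environment variables it reads (see MockHTTPServer.swift below).
let server = Process()
server.executableURL = URL(fileURLWithPath: ".build/release/MockServer")  // assumed build location
server.environment = ProcessInfo.processInfo.environment.merging([
    "PORT": "7000",
    "MODE": "string",           // or "json", matching what the handler under test expects
    "MAX_INVOCATIONS": "1000",  // accept 1000 invocations, then shut down
    "LOG_LEVEL": "info",
]) { _, override in override }
try server.run()

// ... drive the Lambda handler under test against 127.0.0.1:7000 ...

server.waitUntilExit()  // returns once MAX_INVOCATIONS invocations have been serviced
```

Because the cap lives in the server, the harness only has to wait for the MockServer process to exit.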

File tree

3 files changed: +301 −180 lines changed

Diff for: Package.swift

+3 −3

@@ -17,7 +17,7 @@ let package = Package(
         .library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]),
     ],
     dependencies: [
-        .package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"),
+        .package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"),
         .package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"),
     ],
     targets: [
@@ -89,11 +89,11 @@ let package = Package(
         .executableTarget(
             name: "MockServer",
             dependencies: [
+                .product(name: "Logging", package: "swift-log"),
                 .product(name: "NIOHTTP1", package: "swift-nio"),
                 .product(name: "NIOCore", package: "swift-nio"),
                 .product(name: "NIOPosix", package: "swift-nio"),
-            ],
-            swiftSettings: [.swiftLanguageMode(.v5)]
+            ]
         ),
     ]
 )

Diff for: Sources/MockServer/MockHTTPServer.swift

+298 −0 (new file)

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftAWSLambdaRuntime open source project
//
// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOHTTP1
import NIOPosix
import Synchronization

// for UUID and Date
#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

@main
struct HttpServer {
    /// The server's host. (default: 127.0.0.1)
    private let host: String
    /// The server's port. (default: 7000)
    private let port: Int
    /// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton)
    private let eventLoopGroup: MultiThreadedEventLoopGroup
    /// The mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string)
    private let mode: Mode
    /// The number of connections this server accepts before shutting down. (default: 1)
    private let maxInvocations: Int
    /// The logger. (control verbosity with the LOG_LEVEL environment variable)
    private let logger: Logger

    static func main() async throws {
        var log = Logger(label: "MockServer")
        log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info

        let server = HttpServer(
            host: env("HOST") ?? "127.0.0.1",
            port: env("PORT").flatMap(Int.init) ?? 7000,
            eventLoopGroup: .singleton,
            mode: env("MODE").flatMap(Mode.init) ?? .string,
            maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1,
            logger: log
        )
        try await server.run()
    }

    /// This method starts the server and handles incoming connections.
    /// The Lambda function will send two HTTP requests over each connection: one for the next invocation and one for the response.
    private func run() async throws {
        let channel = try await ServerBootstrap(group: self.eventLoopGroup)
            .serverChannelOption(.backlog, value: 256)
            .serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
            .childChannelOption(.maxMessagesPerRead, value: 1)
            .bind(
                host: self.host,
                port: self.port
            ) { channel in
                channel.eventLoop.makeCompletedFuture {

                    try channel.pipeline.syncOperations.configureHTTPServerPipeline(
                        withErrorHandling: true
                    )

                    return try NIOAsyncChannel(
                        wrappingChannelSynchronously: channel,
                        configuration: NIOAsyncChannel.Configuration(
                            inboundType: HTTPServerRequestPart.self,
                            outboundType: HTTPServerResponsePart.self
                        )
                    )
                }
            }

        logger.info(
            "Server started and listening",
            metadata: [
                "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")",
                "port": "\(channel.channel.localAddress?.port ?? 0)",
                "maxInvocations": "\(self.maxInvocations)",
            ]
        )

        // This counter is used to track the number of incoming connections.
        // This mock server accepts n TCP connections, then shuts down.
        let connectionCounter = SharedCounter(maxValue: self.maxInvocations)

        // We are handling each incoming connection in a separate child task. It is important
        // to use a discarding task group here which automatically discards finished child tasks.
        // A normal task group retains all child tasks and their outputs in memory until they are
        // consumed by iterating the group or by exiting the group. Since we are never consuming
        // the results of the group, we need the group to automatically discard them; otherwise, this
        // would result in a memory leak over time.
        try await withThrowingDiscardingTaskGroup { group in
            try await channel.executeThenClose { inbound in
                for try await connectionChannel in inbound {

                    let counter = connectionCounter.current()
                    logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"])

                    group.addTask {
                        await self.handleConnection(channel: connectionChannel)
                        logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"])
                    }

                    if connectionCounter.increment() {
                        logger.info(
                            "Maximum number of connections reached, shutting down after current connection",
                            metadata: ["maxConnections": "\(self.maxInvocations)"]
                        )
                        break  // this causes the server to shut down after handling the connection
                    }
                }
            }
        }
        logger.info("Server shutting down")
    }

    /// This method handles a single connection by responding with hard-coded values to a Lambda function's requests.
    /// It handles two requests: one for the next invocation and one for the response.
    /// When the maximum number of requests is reached, it closes the connection.
    private func handleConnection(
        channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
    ) async {

        var requestHead: HTTPRequestHead!
        var requestBody: ByteBuffer?

        // each Lambda invocation results in TWO HTTP requests (next and response)
        let requestCount = SharedCounter(maxValue: 2)

        // Note that this method is non-throwing and we are catching any error.
        // We do this since we don't want to tear down the whole server when a single connection
        // encounters an error.
        do {
            try await channel.executeThenClose { inbound, outbound in
                for try await inboundData in inbound {
                    let requestNumber = requestCount.current()
                    logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"])

                    if case .head(let head) = inboundData {
                        logger.trace("Received request head", metadata: ["head": "\(head)"])
                        requestHead = head
                    }
                    if case .body(let body) = inboundData {
                        logger.trace("Received request body", metadata: ["body": "\(body)"])
                        requestBody = body
                    }
                    if case .end(let end) = inboundData {
                        logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"])

                        precondition(requestHead != nil, "Received .end without .head")
                        let (responseStatus, responseHeaders, responseBody) = self.processRequest(
                            requestHead: requestHead,
                            requestBody: requestBody
                        )

                        try await self.sendResponse(
                            responseStatus: responseStatus,
                            responseHeaders: responseHeaders,
                            responseBody: responseBody,
                            outbound: outbound
                        )

                        requestHead = nil

                        if requestCount.increment() {
                            logger.info(
                                "Maximum number of requests reached, closing this connection",
                                metadata: ["maxRequest": "2"]
                            )
                            break  // this finishes handling requests on this connection
                        }
                    }
                }
            }
        } catch {
            logger.error("Hit error: \(error)")
        }
    }
    /// This function processes the request and returns a hard-coded response (string or JSON, depending on the mode).
    /// We ignore the requestBody.
    private func processRequest(
        requestHead: HTTPRequestHead,
        requestBody: ByteBuffer?
    ) -> (HTTPResponseStatus, [(String, String)], String) {
        var responseStatus: HTTPResponseStatus = .ok
        var responseBody: String = ""
        var responseHeaders: [(String, String)] = []

        logger.trace(
            "Processing request",
            metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"]
        )

        if requestHead.uri.hasSuffix("/next") {
            responseStatus = .accepted

            let requestId = UUID().uuidString
            switch self.mode {
            case .string:
                responseBody = "\"Seb\""  // must be a valid JSON document
            case .json:
                responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }"
            }
            let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000)
            responseHeaders = [
                (AmazonHeaders.requestID, requestId),
                (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"),
                (AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"),
                (AmazonHeaders.deadline, String(deadline)),
            ]
        } else if requestHead.uri.hasSuffix("/response") {
            responseStatus = .accepted
        } else if requestHead.uri.hasSuffix("/error") {
            responseStatus = .ok
        } else {
            responseStatus = .notFound
        }
        logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)")
        return (responseStatus, responseHeaders, responseBody)
    }

    private func sendResponse(
        responseStatus: HTTPResponseStatus,
        responseHeaders: [(String, String)],
        responseBody: String,
        outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
    ) async throws {
        var headers = HTTPHeaders(responseHeaders)
        headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)")
        headers.add(name: "KeepAlive", value: "timeout=1, max=2")

        logger.trace("Writing response head")
        try await outbound.write(
            HTTPServerResponsePart.head(
                HTTPResponseHead(
                    version: .init(major: 1, minor: 1),  // use HTTP 1.1: it keeps the connection alive between requests
                    status: responseStatus,
                    headers: headers
                )
            )
        )
        logger.trace("Writing response body")
        try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody))))
        logger.trace("Writing response end")
        try await outbound.write(HTTPServerResponsePart.end(nil))
    }

    private enum Mode: String {
        case string
        case json
    }

    private static func env(_ name: String) -> String? {
        guard let value = getenv(name) else {
            return nil
        }
        return String(cString: value)
    }

    private enum AmazonHeaders {
        static let requestID = "Lambda-Runtime-Aws-Request-Id"
        static let traceID = "Lambda-Runtime-Trace-Id"
        static let clientContext = "X-Amz-Client-Context"
        static let cognitoIdentity = "X-Amz-Cognito-Identity"
        static let deadline = "Lambda-Runtime-Deadline-Ms"
        static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn"
    }

    private final class SharedCounter: Sendable {
        private let counterMutex = Mutex<Int>(0)
        private let maxValue: Int

        init(maxValue: Int) {
            self.maxValue = maxValue
        }
        func current() -> Int {
            counterMutex.withLock { $0 }
        }
        func increment() -> Bool {
            counterMutex.withLock {
                $0 += 1
                return $0 >= maxValue
            }
        }
    }
}
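
For reference, this is what one serviced invocation looks like from the client side: two HTTP requests, which the mock expects on the same keep-alive connection (URLSession normally reuses it), one to fetch the next invocation and one to post the result. A minimal sketch using Foundation's async URLSession APIs, where available (the real runtime client uses SwiftNIO, and the Runtime-API-style path prefix is illustrative since the mock only matches the /next, /response and /error suffixes):

```swift
import Foundation

// Illustrative path prefix; the mock server only checks the suffix of the URI.
let base = "http://127.0.0.1:7000/2018-06-01/runtime/invocation"

// 1. Fetch the next invocation: the mock answers 202 Accepted with a hard-coded payload
//    and the Lambda-Runtime-Aws-Request-Id / Lambda-Runtime-Deadline-Ms headers.
let (payload, nextResponse) = try await URLSession.shared.data(from: URL(string: "\(base)/next")!)
let requestId = (nextResponse as? HTTPURLResponse)?
    .value(forHTTPHeaderField: "Lambda-Runtime-Aws-Request-Id") ?? "unknown"
print("payload:", String(decoding: payload, as: UTF8.self), "requestId:", requestId)

// 2. Post the handler's result back: the mock accepts any body on a /response URI,
//    then closes the connection because its per-connection request limit (2) is reached.
var post = URLRequest(url: URL(string: "\(base)/\(requestId)/response")!)
post.httpMethod = "POST"
post.httpBody = Data("\"Hello\"".utf8)
_ = try await URLSession.shared.data(for: post)
```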
