Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[test] add a unit test for the LambdaHTTPServer Pool #500

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions [email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ let package = Package(
.byName(name: "AWSLambdaRuntime"),
.product(name: "NIOTestUtils", package: "swift-nio"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
],
swiftSettings: [
.define("FoundationJSONSupport"),
.define("ServiceLifecycleSupport"),
.define("LocalServerSupport"),
]
),
// for perf testing
Expand Down
6 changes: 3 additions & 3 deletions Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the SwiftAWSLambdaRuntime open source project
//
// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand Down Expand Up @@ -75,7 +75,7 @@ extension Lambda {
/// 1. POST /invoke - the client posts the event to the lambda function
///
/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
private struct LambdaHTTPServer {
internal struct LambdaHTTPServer {
private let invocationEndpoint: String

private let invocationPool = Pool<LocalServerInvocation>()
Expand Down Expand Up @@ -425,7 +425,7 @@ private struct LambdaHTTPServer {
/// A shared data structure to store the current invocation or response requests and the continuation objects.
/// This data structure is shared between instances of the HTTPHandler
/// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function).
private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
internal final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
typealias Element = T

enum State: ~Copyable {
Expand Down
150 changes: 150 additions & 0 deletions Tests/AWSLambdaRuntimeTests/PoolTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftAWSLambdaRuntime open source project
//
// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Testing

@testable import AWSLambdaRuntime

struct PoolTests {

@Test
func testBasicPushAndIteration() async throws {
let pool = LambdaHTTPServer.Pool<String>()

// Push values
await pool.push("first")
await pool.push("second")

// Iterate and verify order
var values = [String]()
for try await value in pool {
values.append(value)
if values.count == 2 { break }
}

#expect(values == ["first", "second"])
}

@Test
func testCancellation() async throws {
let pool = LambdaHTTPServer.Pool<String>()

// Create a task that will be cancelled
let task = Task {
for try await _ in pool {
Issue.record("Should not receive any values after cancellation")
}
}

// Cancel the task immediately
task.cancel()

// This should complete without receiving any values
try await task.value
}

@Test
func testConcurrentPushAndIteration() async throws {
let pool = LambdaHTTPServer.Pool<Int>()
let iterations = 1000
var receivedValues = Set<Int>()

// Start consumer task first
let consumer = Task {
var count = 0
for try await value in pool {
receivedValues.insert(value)
count += 1
if count >= iterations { break }
}
}

// Create multiple producer tasks
try await withThrowingTaskGroup(of: Void.self) { group in
for i in 0..<iterations {
group.addTask {
await pool.push(i)
}
}
try await group.waitForAll()
}

// Wait for consumer to complete
try await consumer.value

// Verify all values were received exactly once
#expect(receivedValues.count == iterations)
#expect(Set(0..<iterations) == receivedValues)
}

@Test
func testPushToWaitingConsumer() async throws {
let pool = LambdaHTTPServer.Pool<String>()
let expectedValue = "test value"

// Start a consumer that will wait for a value
let consumer = Task {
for try await value in pool {
#expect(value == expectedValue)
break
}
}

// Give consumer time to start waiting
try await Task.sleep(nanoseconds: 100_000_000) // 0.1 seconds

// Push a value
await pool.push(expectedValue)

// Wait for consumer to complete
try await consumer.value
}

@Test
func testStressTest() async throws {
let pool = LambdaHTTPServer.Pool<Int>()
let producerCount = 10
let messagesPerProducer = 1000
var receivedValues = [Int]()

// Start consumer
let consumer = Task {
var count = 0
for try await value in pool {
receivedValues.append(value)
count += 1
if count >= producerCount * messagesPerProducer { break }
}
}

// Create multiple producers
try await withThrowingTaskGroup(of: Void.self) { group in
for p in 0..<producerCount {
group.addTask {
for i in 0..<messagesPerProducer {
await pool.push(p * messagesPerProducer + i)
}
}
}
try await group.waitForAll()
}

// Wait for consumer to complete
try await consumer.value

// Verify we received all values
#expect(receivedValues.count == producerCount * messagesPerProducer)
#expect(Set(receivedValues).count == producerCount * messagesPerProducer)
}
}
Loading