|
14 | 14 |
|
15 | 15 | @testable import AsyncHTTPClient
|
16 | 16 | import NIO
|
| 17 | +import NIOConcurrencyHelpers |
17 | 18 | import NIOFoundationCompat
|
18 | 19 | import NIOHTTP1
|
19 | 20 | import NIOHTTPCompression
|
@@ -1546,4 +1547,88 @@ class HTTPClientTests: XCTestCase {
|
1546 | 1547 | }
|
1547 | 1548 | }
|
1548 | 1549 | }
|
| 1550 | + |
| 1551 | + func testWeRecoverFromServerThatClosesTheConnectionOnUs() { |
| 1552 | + final class ServerThatAcceptsThenRejects: ChannelInboundHandler { |
| 1553 | + typealias InboundIn = HTTPServerRequestPart |
| 1554 | + typealias OutboundOut = HTTPServerResponsePart |
| 1555 | + |
| 1556 | + let requestNumber: NIOAtomic<Int> |
| 1557 | + let connectionNumber: NIOAtomic<Int> |
| 1558 | + |
| 1559 | + init(requestNumber: NIOAtomic<Int>, connectionNumber: NIOAtomic<Int>) { |
| 1560 | + self.requestNumber = requestNumber |
| 1561 | + self.connectionNumber = connectionNumber |
| 1562 | + } |
| 1563 | + |
| 1564 | + func channelActive(context: ChannelHandlerContext) { |
| 1565 | + _ = self.connectionNumber.add(1) |
| 1566 | + } |
| 1567 | + |
| 1568 | + func channelRead(context: ChannelHandlerContext, data: NIOAny) { |
| 1569 | + let req = self.unwrapInboundIn(data) |
| 1570 | + |
| 1571 | + switch req { |
| 1572 | + case .head, .body: |
| 1573 | + () |
| 1574 | + case .end: |
| 1575 | + let last = self.requestNumber.add(1) |
| 1576 | + switch last { |
| 1577 | + case 0, 2: |
| 1578 | + context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), |
| 1579 | + promise: nil) |
| 1580 | + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) |
| 1581 | + case 1: |
| 1582 | + context.close(promise: nil) |
| 1583 | + default: |
| 1584 | + XCTFail("did not expect request \(last + 1)") |
| 1585 | + } |
| 1586 | + } |
| 1587 | + } |
| 1588 | + } |
| 1589 | + |
| 1590 | + let requestNumber = NIOAtomic<Int>.makeAtomic(value: 0) |
| 1591 | + let connectionNumber = NIOAtomic<Int>.makeAtomic(value: 0) |
| 1592 | + let sharedStateServerHandler = ServerThatAcceptsThenRejects(requestNumber: requestNumber, |
| 1593 | + connectionNumber: connectionNumber) |
| 1594 | + var maybeServer: Channel? |
| 1595 | + XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: self.group) |
| 1596 | + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) |
| 1597 | + .childChannelInitializer { channel in |
| 1598 | + channel.pipeline.configureHTTPServerPipeline().flatMap { |
| 1599 | + // We're deliberately adding a handler which is shared between multiple channels. This is normally |
| 1600 | + // very verboten but this handler is specially crafted to tolerate this. |
| 1601 | + channel.pipeline.addHandler(sharedStateServerHandler) |
| 1602 | + } |
| 1603 | + } |
| 1604 | + .bind(host: "127.0.0.1", port: 0) |
| 1605 | + .wait()) |
| 1606 | + guard let server = maybeServer else { |
| 1607 | + XCTFail("couldn't create server") |
| 1608 | + return |
| 1609 | + } |
| 1610 | + defer { |
| 1611 | + XCTAssertNoThrow(try server.close().wait()) |
| 1612 | + } |
| 1613 | + |
| 1614 | + let url = "http://127.0.0.1:\(server.localAddress!.port!)" |
| 1615 | + let client = HTTPClient(eventLoopGroupProvider: .shared(self.group)) |
| 1616 | + defer { |
| 1617 | + XCTAssertNoThrow(try client.syncShutdown()) |
| 1618 | + } |
| 1619 | + |
| 1620 | + XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load()) |
| 1621 | + XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load()) |
| 1622 | + XCTAssertNoThrow(XCTAssertEqual(.ok, try client.get(url: url).wait().status)) |
| 1623 | + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) |
| 1624 | + XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load()) |
| 1625 | + XCTAssertThrowsError(try client.get(url: url).wait().status) { error in |
| 1626 | + XCTAssertEqual(.remoteConnectionClosed, error as? HTTPClientError) |
| 1627 | + } |
| 1628 | + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) |
| 1629 | + XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load()) |
| 1630 | + XCTAssertNoThrow(XCTAssertEqual(.ok, try client.get(url: url).wait().status)) |
| 1631 | + XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load()) |
| 1632 | + XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load()) |
| 1633 | + } |
1549 | 1634 | }
|
0 commit comments