Skip to content

Commit

Permalink
Ensure pool runs until all connections are closed
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett committed Oct 28, 2023
1 parent 468ae25 commit 0c11d37
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 20 deletions.
15 changes: 9 additions & 6 deletions Sources/ConnectionPoolModule/ConnectionPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ public final class ConnectionPool<
@usableFromInline
enum NewPoolActions: Sendable {
case makeConnection(StateMachine.ConnectionRequest)
case closeConnection(Connection)
case runKeepAlive(Connection)

case scheduleTimer(StateMachine.Timer)
Expand Down Expand Up @@ -342,9 +341,6 @@ public final class ConnectionPool<
case .runKeepAlive(let connection):
self.runKeepAlive(connection, in: &taskGroup)

case .closeConnection(let connection):
self.closeConnection(connection)

case .scheduleTimer(let timer):
self.runTimer(timer, in: &taskGroup)
}
Expand Down Expand Up @@ -427,8 +423,15 @@ public final class ConnectionPool<
do {
let bundle = try await self.factory(request.connectionID, self)
self.connectionEstablished(bundle)
bundle.connection.onClose {
self.connectionDidClose(bundle.connection, error: $0)

// after the connection has been established, we keep the task open. This ensures
// that the pools run method can not be exited before all connections have been
// closed.
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
bundle.connection.onClose {
self.connectionDidClose(bundle.connection, error: $0)
continuation.resume()
}
}
} catch {
self.connectionEstablishFailed(error, for: request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ extension PoolStateMachine {
var result = TinyFastSequence<Request>()
result.reserveCapacity(Int(max))
var popped = 0
while let requestID = self.queue.popFirst(), popped < max {
while popped < max, let requestID = self.queue.popFirst() {
if let requestIndex = self.requests.index(forKey: requestID) {
popped += 1
result.append(self.requests.remove(at: requestIndex).value)
Expand Down
24 changes: 15 additions & 9 deletions Sources/ConnectionPoolModule/PoolStateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,24 @@ struct PoolStateMachine<

@inlinable
mutating func connectionClosed(_ connection: Connection) -> Action {
self.cacheNoMoreConnectionsAllowed = false
switch self.poolState {
case .running, .shuttingDown(graceful: true):
self.cacheNoMoreConnectionsAllowed = false

let closedConnectionAction = self.connections.connectionClosed(connection.id)
let closedConnectionAction = self.connections.connectionClosed(connection.id)

let connectionAction: ConnectionAction
if let newRequest = closedConnectionAction.newConnectionRequest {
connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel)
} else {
connectionAction = .cancelTimers(closedConnectionAction.timersToCancel)
}
let connectionAction: ConnectionAction
if let newRequest = closedConnectionAction.newConnectionRequest {
connectionAction = .makeConnection(newRequest, closedConnectionAction.timersToCancel)
} else {
connectionAction = .cancelTimers(closedConnectionAction.timersToCancel)
}

return .init(request: .none, connection: connectionAction)

return .init(request: .none, connection: connectionAction)
case .shuttingDown(graceful: false), .shutDown:
return .none()
}
}

struct CleanupAction {
Expand Down
44 changes: 40 additions & 4 deletions Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
@testable import _ConnectionPoolModule
import Atomics
import XCTest
import NIOEmbedded

Expand Down Expand Up @@ -52,7 +53,14 @@ final class ConnectionPoolTests: XCTestCase {
}

taskGroup.cancelAll()

XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0)
for connection in factory.runningConnections {
connection.closeIfClosing()
}
}

XCTAssertEqual(factory.runningConnections.count, 0)
}

func testShutdownPoolWhileConnectionIsBeingCreated() async {
Expand Down Expand Up @@ -155,34 +163,62 @@ final class ConnectionPoolTests: XCTestCase {
try await factory.makeConnection(id: $0, for: $1)
}

let hasFinished = ManagedAtomic(false)
let createdConnections = ManagedAtomic(0)
let iterations = 10_000

// the same connection is reused 1000 times

await withThrowingTaskGroup(of: Void.self) { taskGroup in
await withTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
await pool.run()
XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original)
}

taskGroup.addTask {
var usedConnectionIDs = Set<Int>()
for _ in 0..<config.maximumConnectionHardLimit {
await factory.nextConnectAttempt { connectionID in
XCTAssertTrue(usedConnectionIDs.insert(connectionID).inserted)
createdConnections.wrappingIncrement(ordering: .relaxed)
return 1
}
}


XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0)
}

for _ in 0..<10_000 {
let (stream, continuation) = AsyncStream.makeStream(of: Void.self)

for _ in 0..<iterations {
taskGroup.addTask {
let leasedConnection = try await pool.leaseConnection()
pool.releaseConnection(leasedConnection)
do {
let leasedConnection = try await pool.leaseConnection()
pool.releaseConnection(leasedConnection)
} catch {
XCTFail("Unexpected error: \(error)")
}
continuation.yield()
}
}

var leaseReleaseIterator = stream.makeAsyncIterator()
for _ in 0..<iterations {
_ = await leaseReleaseIterator.next()
}

taskGroup.cancelAll()

XCTAssertFalse(hasFinished.load(ordering: .relaxed))
for connection in factory.runningConnections {
connection.closeIfClosing()
}
}

XCTAssertEqual(createdConnections.load(ordering: .relaxed), config.maximumConnectionHardLimit)
XCTAssert(hasFinished.load(ordering: .relaxed))
XCTAssertEqual(factory.runningConnections.count, 0)
}
}

Expand Down
17 changes: 17 additions & 0 deletions Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,18 @@ final class MockConnectionFactory<Clock: _Concurrency.Clock> where Clock.Duratio
var attempts = Deque<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)>()

var waiter = Deque<CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>>()

var runningConnections = [ConnectionID: Connection]()
}

var pendingConnectionAttemptsCount: Int {
self.stateBox.withLockedValue { $0.attempts.count }
}

var runningConnections: [Connection] {
self.stateBox.withLockedValue { Array($0.runningConnections.values) }
}

func makeConnection(
id: Int,
for pool: ConnectionPool<MockConnection, Int, ConnectionIDGenerator, ConnectionRequest<MockConnection>, Int, MockPingPongBehavior, NoOpConnectionPoolMetrics<Int>, Clock>
Expand Down Expand Up @@ -137,6 +143,17 @@ final class MockConnectionFactory<Clock: _Concurrency.Clock> where Clock.Duratio
do {
let streamCount = try await closure(connectionID)
let connection = MockConnection(id: connectionID)

connection.onClose { _ in
self.stateBox.withLockedValue { state in
_ = state.runningConnections.removeValue(forKey: connectionID)
}
}

self.stateBox.withLockedValue { state in
_ = state.runningConnections[connectionID] = connection
}

continuation.resume(returning: (connection, streamCount))
return connection
} catch {
Expand Down

0 comments on commit 0c11d37

Please sign in to comment.