Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable StrictConcurrency checking #483

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// swift-tools-version:5.8
import PackageDescription

let swiftSettings: [SwiftSetting] = [
.enableUpcomingFeature("StrictConcurrency")
]

let package = Package(
name: "postgres-nio",
platforms: [
Expand Down Expand Up @@ -41,23 +45,26 @@ let package = Package(
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "ServiceLifecycle", package: "swift-service-lifecycle"),
]
],
swiftSettings: swiftSettings
),
.target(
name: "_ConnectionPoolModule",
dependencies: [
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "DequeModule", package: "swift-collections"),
],
path: "Sources/ConnectionPoolModule"
path: "Sources/ConnectionPoolModule",
swiftSettings: swiftSettings
),
.testTarget(
name: "PostgresNIOTests",
dependencies: [
.target(name: "PostgresNIO"),
.product(name: "NIOEmbedded", package: "swift-nio"),
.product(name: "NIOTestUtils", package: "swift-nio"),
]
],
swiftSettings: swiftSettings
),
.testTarget(
name: "ConnectionPoolModuleTests",
Expand All @@ -67,14 +74,16 @@ let package = Package(
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
.product(name: "NIOEmbedded", package: "swift-nio"),
]
],
swiftSettings: swiftSettings
),
.testTarget(
name: "IntegrationTests",
dependencies: [
.target(name: "PostgresNIO"),
.product(name: "NIOTestUtils", package: "swift-nio"),
]
],
swiftSettings: swiftSettings
),
]
)
4 changes: 2 additions & 2 deletions Sources/ConnectionPoolModule/ConnectionPool.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
public struct ConnectionAndMetadata<Connection: PooledConnection> {
public struct ConnectionAndMetadata<Connection: PooledConnection>: Sendable {

public var connection: Connection

Expand Down Expand Up @@ -495,7 +495,7 @@ public final class ConnectionPool<
}

@usableFromInline
enum TimerRunResult {
enum TimerRunResult: Sendable {
case timerTriggered
case timerCancelled
case cancellationContinuationFinished
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public protocol ConnectionPoolObservabilityDelegate: Sendable {
func requestQueueDepthChanged(_ newDepth: Int)
}

public struct NoOpConnectionPoolMetrics<ConnectionID: Hashable>: ConnectionPoolObservabilityDelegate {
public struct NoOpConnectionPoolMetrics<ConnectionID: Hashable & Sendable>: ConnectionPoolObservabilityDelegate {
public init(connectionIDType: ConnectionID.Type) {}

public func startedConnecting(id: ConnectionID) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extension PostgresMessage {
/// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size.
/// Values are not unique across all identifiers, meaning some messages will require keeping state to identify.
@available(*, deprecated, message: "Will be removed from public API.")
public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible {
public struct Identifier: Sendable, ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible {
// special
public static let none: Identifier = 0x00
// special
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/Pool/PostgresClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ extension PostgresConnection: PooledConnection {
self.channel.close(mode: .all, promise: nil)
}

public func onClose(_ closure: @escaping ((any Error)?) -> ()) {
public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) {
self.closeFuture.whenComplete { _ in closure(nil) }
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/Utilities/PostgresError+Code.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
extension PostgresError {
public struct Code: ExpressibleByStringLiteral, Equatable {
public struct Code: Sendable, ExpressibleByStringLiteral, Equatable {
// Class 00 — Successful Completion
public static let successfulCompletion: Code = "00000"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import DequeModule

@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
final class MockConnectionFactory<Clock: _Concurrency.Clock> where Clock.Duration == Duration {
final class MockConnectionFactory<Clock: _Concurrency.Clock>: Sendable where Clock.Duration == Duration {
typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator
typealias Request = ConnectionRequest<MockConnection>
typealias KeepAliveBehavior = MockPingPongBehavior
Expand Down
61 changes: 31 additions & 30 deletions Tests/IntegrationTests/PostgresNIOTests.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Logging
@testable import PostgresNIO
import Atomics
import XCTest
import NIOCore
import NIOPosix
Expand Down Expand Up @@ -112,59 +113,59 @@ final class PostgresNIOTests: XCTestCase {
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }

var receivedNotifications: [PostgresMessage.NotificationResponse] = []
let receivedNotifications = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications.append(notification)
receivedNotifications.wrappingIncrement(ordering: .relaxed)
XCTAssertEqual(notification.channel, "example")
XCTAssertEqual(notification.payload, "")
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
// Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications.count, 1)
XCTAssertEqual(receivedNotifications.first?.channel, "example")
XCTAssertEqual(receivedNotifications.first?.payload, "")
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsNonEmptyPayload() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications: [PostgresMessage.NotificationResponse] = []
let receivedNotifications = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications.append(notification)
receivedNotifications.wrappingIncrement(ordering: .relaxed)
XCTAssertEqual(notification.channel, "example")
XCTAssertEqual(notification.payload, "Notification payload example")
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example, 'Notification payload example'").wait())
// Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications.count, 1)
XCTAssertEqual(receivedNotifications.first?.channel, "example")
XCTAssertEqual(receivedNotifications.first?.payload, "Notification payload example")
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsRemoveHandlerWithinHandler() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications = 0
let receivedNotifications = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications += 1
receivedNotifications.wrappingIncrement(ordering: .relaxed)
context.stop()
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications, 1)
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsRemoveHandlerOutsideHandler() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications = 0
let receivedNotifications = ManagedAtomic<Int>(0)
let context = conn?.addListener(channel: "example") { context, notification in
receivedNotifications += 1
receivedNotifications.wrappingIncrement(ordering: .relaxed)
}
XCTAssertNotNil(context)
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
Expand All @@ -173,47 +174,47 @@ final class PostgresNIOTests: XCTestCase {
context?.stop()
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications, 1)
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsMultipleRegisteredHandlers() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications1 = 0
let receivedNotifications1 = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications1 += 1
receivedNotifications1.wrappingIncrement(ordering: .relaxed)
}
var receivedNotifications2 = 0
let receivedNotifications2 = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications2 += 1
receivedNotifications2.wrappingIncrement(ordering: .relaxed)
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications1, 1)
XCTAssertEqual(receivedNotifications2, 1)
XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1)
XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 1)
}

func testNotificationsMultipleRegisteredHandlersRemoval() throws {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications1 = 0
let receivedNotifications1 = ManagedAtomic<Int>(0)
XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in
receivedNotifications1 += 1
receivedNotifications1.wrappingIncrement(ordering: .relaxed)
context.stop()
})
var receivedNotifications2 = 0
let receivedNotifications2 = ManagedAtomic<Int>(0)
XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in
receivedNotifications2 += 1
receivedNotifications2.wrappingIncrement(ordering: .relaxed)
})
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications1, 1)
XCTAssertEqual(receivedNotifications2, 2)
XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1)
XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 2)
}

func testNotificationHandlerFiltersOnChannel() {
Expand Down Expand Up @@ -1283,11 +1284,11 @@ final class PostgresNIOTests: XCTestCase {
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var queries: [[PostgresRow]]?
XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { query in
XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { [eventLoop] query in
let a = query.execute(["a"])
let b = query.execute(["b"])
let c = query.execute(["c"])
return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop)
return EventLoopFuture.whenAllSucceed([a, b, c], on: eventLoop)
}).wait())
XCTAssertEqual(queries?.count, 3)
var resultIterator = queries?.makeIterator()
Expand Down
19 changes: 11 additions & 8 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class PostgresConnectionTests: XCTestCase {
func testSimpleListenConnectionDrops() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup in
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in
taskGroup.addTask {
let events = try await connection.listen("foo")
var iterator = events.makeAsyncIterator()
Expand All @@ -197,7 +197,7 @@ class PostgresConnectionTests: XCTestCase {
_ = try await iterator.next()
XCTFail("Did not expect to not throw")
} catch {
self.logger.error("error", metadata: ["error": "\(error)"])
logger.error("error", metadata: ["error": "\(error)"])
}
}

Expand Down Expand Up @@ -226,10 +226,10 @@ class PostgresConnectionTests: XCTestCase {

func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
let rows = try await connection.query("SELECT 1;", logger: self.logger)
let rows = try await connection.query("SELECT 1;", logger: logger)
var iterator = rows.decode(Int.self).makeAsyncIterator()
let first = try await iterator.next()
XCTAssertEqual(first, 1)
Expand Down Expand Up @@ -286,10 +286,10 @@ class PostgresConnectionTests: XCTestCase {
func testCloseClosesImmediatly() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
try await connection.query("SELECT 1;", logger: self.logger)
try await connection.query("SELECT 1;", logger: logger)
}
}

Expand Down Expand Up @@ -319,8 +319,9 @@ class PostgresConnectionTests: XCTestCase {

func testIfServerJustClosesTheErrorReflectsThat() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
let logger = self.logger

async let response = try await connection.query("SELECT 1;", logger: self.logger)
async let response = try await connection.query("SELECT 1;", logger: logger)

let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, "SELECT 1;")
Expand Down Expand Up @@ -423,6 +424,7 @@ class PostgresConnectionTests: XCTestCase {
case pleaseDontCrash
}
channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash)
try await connection.close()
}

func testSerialExecutionOfSamePreparedStatement() async throws {
Expand Down Expand Up @@ -651,7 +653,8 @@ class PostgresConnectionTests: XCTestCase {
database: "database"
)

async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger)
let logger = self.logger
async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger)
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))))
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
Expand Down
Loading