Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support additional connection parameters #361

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,19 @@ extension PostgresConnection {
/// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`.
/// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default).
public var requireBackendKeyData: Bool


/// Additional parameters to send to the server on startup. The name value pairs are added to the initial
/// startup message that the client sends to the server.
public var additionalStartupParameters: [(String, String)]

/// Create an options structure with default values.
///
/// Most users should not need to adjust the defaults.
public init() {
self.connectTimeout = .seconds(10)
self.tlsServerName = nil
self.requireBackendKeyData = true
self.additionalStartupParameters = []
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1113,11 +1113,19 @@ struct SendPrepareStatement {
let query: String
}

struct AuthContext: Equatable, CustomDebugStringConvertible {
let username: String
let password: String?
let database: String?

struct AuthContext: CustomDebugStringConvertible {
var username: String
var password: String?
var database: String?
var additionalParameters: [(String, String)]

init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) {
self.username = username
self.password = password
self.database = database
self.additionalParameters = additionalParameters
}

var debugDescription: String {
"""
AuthContext(username: \(String(reflecting: self.username)), \
Expand All @@ -1127,6 +1135,22 @@ struct AuthContext: Equatable, CustomDebugStringConvertible {
}
}

extension AuthContext: Equatable {
static func ==(lhs: Self, rhs: Self) -> Bool {
guard lhs.username == rhs.username
&& lhs.password == rhs.password
&& lhs.database == rhs.database
&& lhs.additionalParameters.count == rhs.additionalParameters.count
else {
return false
}

return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in
lhs.0 == rhs.0 && lhs.1 == rhs.1
}
}
}

enum PasswordAuthencationMode: Equatable {
case cleartext
case md5(salt: UInt32)
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
case .wait:
break
case .sendStartupMessage(let authContext):
self.encoder.startup(user: authContext.username, database: authContext.database)
self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
case .sendSSLRequest:
self.encoder.ssl()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct PostgresFrontendMessageEncoder {
self.buffer = buffer
}

mutating func startup(user: String, database: String?) {
mutating func startup(user: String, database: String?, options: [(String, String)]) {
self.clearIfNeeded()
self.buffer.psqlLengthPrefixed { buffer in
buffer.writeInteger(Self.startupVersionThree)
Expand All @@ -37,6 +37,13 @@ struct PostgresFrontendMessageEncoder {
buffer.writeNullTerminatedString(database)
}

// we don't send replication parameters, as the default is false and this is what we
// need for a client
for (key, value) in options {
buffer.writeNullTerminatedString(key)
buffer.writeNullTerminatedString(value)
}

buffer.writeInteger(UInt8(0))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
case 196608:
var user: String?
var database: String?
var options: String?
var options = [(String, String)]()

while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex {
let value = messageSlice.readNullTerminatedString()

Expand All @@ -51,11 +51,10 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
case "database":
database = value

case "options":
options = value

default:
break
if let value = value {
options.append((name, value))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ enum PostgresFrontendMessage: Equatable {
static let requestCode: Int32 = 80877103
}

struct Startup: Hashable {
struct Startup: Equatable {
static let versionThree: Int32 = 0x00_03_00_00

/// Creates a `Startup` with "3.0" as the protocol version.
Expand All @@ -119,7 +119,7 @@ enum PostgresFrontendMessage: Equatable {
/// The protocol version number is followed by one or more pairs of parameter
/// name and value strings. A zero byte is required as a terminator after
/// the last name/value pair. `user` is required, others are optional.
struct Parameters: Hashable {
struct Parameters: Equatable {
enum Replication {
case `true`
case `false`
Expand All @@ -136,12 +136,33 @@ enum PostgresFrontendMessage: Equatable {
/// of setting individual run-time parameters.) Spaces within this string are
/// considered to separate arguments, unless escaped with a
/// backslash (\); write \\ to represent a literal backslash.
var options: String?
var options: [(String, String)]

/// Used to connect in streaming replication mode, where a small set of
/// replication commands can be issued instead of SQL statements. Value
/// can be true, false, or database, and the default is false.
var replication: Replication

static func ==(lhs: Self, rhs: Self) -> Bool {
guard lhs.user == rhs.user
&& lhs.database == rhs.database
&& lhs.replication == rhs.replication
&& lhs.options.count == rhs.options.count
else {
return false
}

var lhsIterator = lhs.options.makeIterator()
var rhsIterator = rhs.options.makeIterator()

while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() {
guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else {
return false
}
}
return true
}

}

var parameters: Parameters
Expand Down
41 changes: 39 additions & 2 deletions Tests/PostgresNIOTests/New/Messages/StartupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class StartupTests: XCTestCase {
let user = "test"
let database = "abc123"

encoder.startup(user: user, database: database)
encoder.startup(user: user, database: database, options: [])
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
Expand All @@ -32,7 +32,7 @@ class StartupTests: XCTestCase {

let user = "test"

encoder.startup(user: user, database: nil)
encoder.startup(user: user, database: nil, options: [])
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
Expand All @@ -44,4 +44,41 @@ class StartupTests: XCTestCase {

XCTAssertEqual(byteBuffer.readableBytes, 0)
}

func testStartupMessageWithAdditionalOptions() {
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
var byteBuffer = ByteBuffer()

let user = "test"
let database = "abc123"

encoder.startup(user: user, database: database, options: [("some", "options")])
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))

XCTAssertEqual(byteBuffer.readableBytes, 0)
}
}

extension PostgresFrontendMessage.Startup.Parameters.Replication {
var stringValue: String {
switch self {
case .true:
return "true"
case .false:
return "false"
case .database:
return "replication"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ class PostgresChannelHandlerTests: XCTestCase {

XCTAssertEqual(startup.parameters.user, config.username)
XCTAssertEqual(startup.parameters.database, config.database)
XCTAssertEqual(startup.parameters.options, nil)
XCTAssertEqual(startup.parameters.replication, .false)

XCTAssert(startup.parameters.options.isEmpty)

XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok)))
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))))
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle)))
Expand Down Expand Up @@ -209,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase {

XCTAssertEqual(startup.parameters.user, config.username)
XCTAssertEqual(startup.parameters.database, config.database)
XCTAssertEqual(startup.parameters.options, nil)
XCTAssert(startup.parameters.options.isEmpty)
XCTAssertEqual(startup.parameters.replication, .false)

var buffer = ByteBuffer()
Expand Down Expand Up @@ -282,7 +281,7 @@ extension AuthContext {
PostgresFrontendMessage.Startup.Parameters(
user: self.username,
database: self.database,
options: nil,
options: self.additionalParameters,
replication: .false
)
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class PostgresConnectionTests: XCTestCase {

async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger)
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false))))
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))))
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
Expand Down