Skip to content

Commit

Permalink
Use pointer when passing the connection manager around (#3152)
Browse files Browse the repository at this point in the history
As otherwise existing connections aren't reused.
  • Loading branch information
S7evinK committed Jul 19, 2023
1 parent a01faee commit 297479e
Show file tree
Hide file tree
Showing 31 changed files with 148 additions and 84 deletions.
2 changes: 1 addition & 1 deletion cmd/dendrite-demo-pinecone/monolith/monolith.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (p *P2PMonolith) SetupPinecone(sk ed25519.PrivateKey) {
}

func (p *P2PMonolith) SetupDendrite(
processCtx *process.ProcessContext, cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers,
processCtx *process.ProcessContext, cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers,
port int, enableRelaying bool, enableMetrics bool, enableWebsockets bool) {

p.port = port
Expand Down
2 changes: 1 addition & 1 deletion federationapi/federationapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func AddPublicRoutes(
func NewInternalAPI(
processContext *process.ProcessContext,
dendriteCfg *config.Dendrite,
cm sqlutil.Connections,
cm *sqlutil.Connections,
natsInstance *jetstream.NATSInstance,
federation fclient.FederationClient,
rsAPI roomserverAPI.FederationRoomserverAPI,
Expand Down
2 changes: 1 addition & 1 deletion federationapi/storage/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Database struct {
}

// NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) {
func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion federationapi/storage/sqlite3/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Database struct {
}

// NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) {
func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion federationapi/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
)

// NewDatabase opens a new database
func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) {
func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName)
Expand Down
16 changes: 9 additions & 7 deletions internal/sqlutil/connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,26 @@ type Connections struct {
processContext *process.ProcessContext
}

func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) Connections {
return Connections{
func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) *Connections {
return &Connections{
globalConfig: globalConfig,
processContext: processCtx,
}
}

func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) {
writer := NewDummyWriter()
if dbProperties.ConnectionString.IsSQLite() {
writer = NewExclusiveWriter()
}
var err error
if dbProperties.ConnectionString == "" {
// if no connectionString was provided, try the global one
dbProperties = &c.globalConfig
}
if dbProperties.ConnectionString != "" || c.db == nil {

writer := NewDummyWriter()
if dbProperties.ConnectionString.IsSQLite() {
writer = NewExclusiveWriter()
}

if dbProperties.ConnectionString != "" && c.db == nil {
// Open a new database connection using the supplied config.
c.db, err = Open(dbProperties, writer)
if err != nil {
Expand Down
146 changes: 104 additions & 42 deletions internal/sqlutil/connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,113 @@ import (

"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
)

func TestConnectionManager(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})

dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}
db, writer, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}

switch dbType {
case test.DBTypeSQLite:
_, ok := writer.(*sqlutil.ExclusiveWriter)
if !ok {
t.Fatalf("expected exclusive writer")
}
case test.DBTypePostgres:
_, ok := writer.(*sqlutil.DummyWriter)
if !ok {
t.Fatalf("expected dummy writer")
}
}

// test global db pool
dbGlobal, writerGlobal, err := cm.Connection(&config.DatabaseOptions{})
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(db, dbGlobal) {
t.Fatalf("expected database connection to be reused")
}
if !reflect.DeepEqual(writer, writerGlobal) {
t.Fatalf("expected database writer to be reused")
}

// test invalid connection string configured
cm2 := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
_, _, err = cm2.Connection(&config.DatabaseOptions{ConnectionString: "http://"})
if err == nil {
t.Fatal("expected an error but got none")
}

t.Run("component defined connection string", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})

dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}
db, writer, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}

switch dbType {
case test.DBTypeSQLite:
_, ok := writer.(*sqlutil.ExclusiveWriter)
if !ok {
t.Fatalf("expected exclusive writer")
}
case test.DBTypePostgres:
_, ok := writer.(*sqlutil.DummyWriter)
if !ok {
t.Fatalf("expected dummy writer")
}
}

// reuse existing connection
db2, writer2, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(db, db2) {
t.Fatalf("expected database connection to be reused")
}
if !reflect.DeepEqual(writer, writer2) {
t.Fatalf("expected database writer to be reused")
}
})
})

t.Run("global connection pool", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{ConnectionString: config.DataSource(conStr)})

dbProps := &config.DatabaseOptions{}
db, writer, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}

switch dbType {
case test.DBTypeSQLite:
_, ok := writer.(*sqlutil.ExclusiveWriter)
if !ok {
t.Fatalf("expected exclusive writer")
}
case test.DBTypePostgres:
_, ok := writer.(*sqlutil.DummyWriter)
if !ok {
t.Fatalf("expected dummy writer")
}
}

// reuse existing connection
db2, writer2, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(db, db2) {
t.Fatalf("expected database connection to be reused")
}
if !reflect.DeepEqual(writer, writer2) {
t.Fatalf("expected database writer to be reused")
}
})
})

t.Run("shutdown", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)

processCtx := process.NewProcessContext()
cm := sqlutil.NewConnectionManager(processCtx, config.DatabaseOptions{ConnectionString: config.DataSource(conStr)})

dbProps := &config.DatabaseOptions{}
_, _, err := cm.Connection(dbProps)
if err != nil {
t.Fatal(err)
}

processCtx.ShutdownDendrite()
processCtx.WaitForComponentsToFinish()
})
})

// test invalid connection string configured
cm2 := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
_, _, err := cm2.Connection(&config.DatabaseOptions{ConnectionString: "http://"})
if err == nil {
t.Fatal("expected an error but got none")
}
}
2 changes: 1 addition & 1 deletion mediaapi/mediaapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component.
func AddPublicRoutes(
mediaRouter *mux.Router,
cm sqlutil.Connections,
cm *sqlutil.Connections,
cfg *config.Dendrite,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
Expand Down
2 changes: 1 addition & 1 deletion mediaapi/storage/postgres/mediaapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

// NewDatabase opens a postgres database.
func NewDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
func NewDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, writer, err := conMan.Connection(dbProperties)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion mediaapi/storage/sqlite3/mediaapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

// NewDatabase opens a SQLIte database.
func NewDatabase(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
func NewDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, writer, err := conMan.Connection(dbProperties)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion mediaapi/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

// NewMediaAPIDatasource opens a database connection.
func NewMediaAPIDatasource(conMan sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) {
func NewMediaAPIDatasource(conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(conMan, dbProperties)
Expand Down
2 changes: 1 addition & 1 deletion relayapi/relayapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func AddPublicRoutes(

func NewRelayInternalAPI(
dendriteCfg *config.Dendrite,
cm sqlutil.Connections,
cm *sqlutil.Connections,
fedClient fclient.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
Expand Down
2 changes: 1 addition & 1 deletion relayapi/storage/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Database struct {

// NewDatabase opens a new database
func NewDatabase(
conMan sqlutil.Connections,
conMan *sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(spec.ServerName) bool,
Expand Down
2 changes: 1 addition & 1 deletion relayapi/storage/sqlite3/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Database struct {

// NewDatabase opens a new database
func NewDatabase(
conMan sqlutil.Connections,
conMan *sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(spec.ServerName) bool,
Expand Down
2 changes: 1 addition & 1 deletion relayapi/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (

// NewDatabase opens a new database
func NewDatabase(
conMan sqlutil.Connections,
conMan *sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(spec.ServerName) bool,
Expand Down
2 changes: 1 addition & 1 deletion roomserver/roomserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
func NewInternalAPI(
processContext *process.ProcessContext,
cfg *config.Dendrite,
cm sqlutil.Connections,
cm *sqlutil.Connections,
natsInstance *jetstream.NATSInstance,
caches caching.RoomServerCaches,
enableMetrics bool,
Expand Down
2 changes: 1 addition & 1 deletion roomserver/storage/postgres/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Database struct {
}

// Open a postgres database.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database
var err error
db, writer, err := conMan.Connection(dbProperties)
Expand Down
2 changes: 1 addition & 1 deletion roomserver/storage/sqlite3/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Database struct {
}

// Open a sqlite database.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database
var err error
db, writer, err := conMan.Connection(dbProperties)
Expand Down
2 changes: 1 addition & 1 deletion roomserver/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
)

// Open opens a database connection.
func Open(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) {
func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(ctx, conMan, dbProperties, cache)
Expand Down
2 changes: 1 addition & 1 deletion setup/monolith.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (m *Monolith) AddAllPublicRoutes(
processCtx *process.ProcessContext,
cfg *config.Dendrite,
routers httputil.Routers,
cm sqlutil.Connections,
cm *sqlutil.Connections,
natsInstance *jetstream.NATSInstance,
caches *caching.Caches,
enableMetrics bool,
Expand Down
2 changes: 1 addition & 1 deletion setup/mscs/msc2836/msc2836.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsRespons

// Enable this MSC
func Enable(
cfg *config.Dendrite, cm sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
) error {
db, err := NewDatabase(cm, &cfg.MSCs.Database)
Expand Down
6 changes: 3 additions & 3 deletions setup/mscs/msc2836/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ type DB struct {
}

// NewDatabase loads the database for msc2836
func NewDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
func NewDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
if dbOpts.ConnectionString.IsPostgres() {
return newPostgresDatabase(conMan, dbOpts)
}
return newSQLiteDatabase(conMan, dbOpts)
}

func newPostgresDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
func newPostgresDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{}
var err error
if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil {
Expand Down Expand Up @@ -144,7 +144,7 @@ func newPostgresDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOpti
return &d, err
}

func newSQLiteDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
func newSQLiteDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
d := DB{}
var err error
if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil {
Expand Down
Loading

0 comments on commit 297479e

Please sign in to comment.