Skip to content

Commit

Permalink
Fix default driver logic (#585)
Browse files Browse the repository at this point in the history
Change Dialect usage to Driver / DefaultDriver.
  • Loading branch information
stanislas-m committed Aug 16, 2020
1 parent 81b8a32 commit 0e25cc8
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
18 changes: 9 additions & 9 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ func (p *cockroach) CreateDB() error {
deets := p.ConnectionDetails

// Overwrite dialect to match pgx driver for sql.Open
dialect := deets.Dialect
if dialect == "postgres" {
dialect = "pgx"
driver := p.DefaultDriver()
if p.ConnectionDetails.Driver != "" {
driver = p.ConnectionDetails.Driver
}

db, err := sql.Open(dialect, p.urlWithoutDb())
db, err := sql.Open(driver, p.urlWithoutDb())
if err != nil {
return errors.Wrapf(err, "error creating Cockroach database %s", deets.Database)
}
Expand All @@ -152,12 +152,12 @@ func (p *cockroach) DropDB() error {
deets := p.ConnectionDetails

// Overwrite dialect to match pgx driver for sql.Open
dialect := deets.Dialect
if dialect == "postgres" {
dialect = "pgx"
driver := p.DefaultDriver()
if p.ConnectionDetails.Driver != "" {
driver = p.ConnectionDetails.Driver
}

db, err := sql.Open(dialect, p.urlWithoutDb())
db, err := sql.Open(driver, p.urlWithoutDb())
if err != nil {
return errors.Wrapf(err, "error dropping Cockroach database %s", deets.Database)
}
Expand Down Expand Up @@ -221,7 +221,7 @@ func (p *cockroach) DumpSchema(w io.Writer) error {
}

func (p *cockroach) LoadSchema(r io.Reader) error {
return genericLoadSchema(p.ConnectionDetails, p.MigrationURL(), r)
return genericLoadSchema(p.ConnectionDetails, p.DefaultDriver(), p.MigrationURL(), r)
}

func (p *cockroach) TruncateAll(tx *Connection) error {
Expand Down
8 changes: 6 additions & 2 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ func genericSelectMany(s store, models *Model, query Query) error {
return nil
}

func genericLoadSchema(deets *ConnectionDetails, migrationURL string, r io.Reader) error {
func genericLoadSchema(deets *ConnectionDetails, defaultDriver, migrationURL string, r io.Reader) error {
// Open DB connection on the target DB
db, err := sqlx.Open(deets.Dialect, migrationURL)
driver := defaultDriver
if deets.Driver != "" {
driver = deets.Driver
}
db, err := sqlx.Open(driver, migrationURL)
if err != nil {
return errors.WithMessage(err, fmt.Sprintf("unable to load schema for %s", deets.Database))
}
Expand Down
2 changes: 1 addition & 1 deletion dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (m *mysql) DumpSchema(w io.Writer) error {

// LoadSchema executes a schema sql file against the configured database.
func (m *mysql) LoadSchema(r io.Reader) error {
return genericLoadSchema(m.ConnectionDetails, m.MigrationURL(), r)
return genericLoadSchema(m.ConnectionDetails, m.DefaultDriver(), m.MigrationURL(), r)
}

// TruncateAll truncates all tables for the given connection.
Expand Down
19 changes: 9 additions & 10 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,12 @@ func (p *postgresql) CreateDB() error {
// createdb -h db -p 5432 -U postgres enterprise_development
deets := p.ConnectionDetails

// Overwrite dialect to match pgx driver for sql.Open
dialect := deets.Dialect
if dialect == "postgres" {
dialect = "pgx"
driver := p.DefaultDriver()
if p.ConnectionDetails.Driver != "" {
driver = p.ConnectionDetails.Driver
}

db, err := sql.Open(dialect, p.urlWithoutDb())
db, err := sql.Open(driver, p.urlWithoutDb())
if err != nil {
return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
}
Expand All @@ -142,12 +141,12 @@ func (p *postgresql) DropDB() error {
deets := p.ConnectionDetails

// Overwrite dialect to match pgx driver for sql.Open
dialect := deets.Dialect
if dialect == "postgres" {
dialect = "pgx"
driver := p.DefaultDriver()
if p.ConnectionDetails.Driver != "" {
driver = p.ConnectionDetails.Driver
}

db, err := sql.Open(dialect, p.urlWithoutDb())
db, err := sql.Open(driver, p.urlWithoutDb())
if err != nil {
return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
}
Expand Down Expand Up @@ -211,7 +210,7 @@ func (p *postgresql) DumpSchema(w io.Writer) error {

// LoadSchema executes a schema sql file against the configured database.
func (p *postgresql) LoadSchema(r io.Reader) error {
return genericLoadSchema(p.ConnectionDetails, p.MigrationURL(), r)
return genericLoadSchema(p.ConnectionDetails, p.DefaultDriver(), p.MigrationURL(), r)
}

// TruncateAll truncates all tables for the given connection.
Expand Down

0 comments on commit 0e25cc8

Please sign in to comment.