Skip to content

Commit

Permalink
Quote identifiers (#383)
Browse files Browse the repository at this point in the history
* WIP quote identifiers

* Handle update strings

* Fix SQLite driver
  • Loading branch information
stanislas-m committed Aug 15, 2019
1 parent 77a6d94 commit 79000c1
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 31 deletions.
5 changes: 5 additions & 0 deletions columns/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ func (c Column) UpdateString() string {
return fmt.Sprintf("%s = :%s", c.Name, c.Name)
}

// QuotedUpdateString returns quoted the SQL statement to UPDATE the column.
func (c Column) QuotedUpdateString(quoter quoter) string {
return fmt.Sprintf("%s = :%s", quoter.Quote(c.Name), c.Name)
}

// SetSelectSQL sets a custom SELECT statement for the column.
func (c *Column) SetSelectSQL(s string) {
c.SelectSQL = s
Expand Down
14 changes: 14 additions & 0 deletions columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ func (c Columns) Readable() *ReadableColumns {
return w
}

type quoter interface {
Quote(key string) string
}

// QuotedString gives the columns list quoted with the given quoter function.
func (c Columns) QuotedString(quoter quoter) string {
var xs []string
for _, t := range c.Cols {
xs = append(xs, quoter.Quote(t.Name))
}
sort.Strings(xs)
return strings.Join(xs, ", ")
}

func (c Columns) String() string {
var xs []string
for _, t := range c.Cols {
Expand Down
10 changes: 10 additions & 0 deletions columns/writeable_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ func (c WriteableColumns) UpdateString() string {
sort.Strings(xs)
return strings.Join(xs, ", ")
}

// QuotedUpdateString returns the quoted SQL column list part of the UPDATE query.
func (c Columns) QuotedUpdateString(quoter quoter) string {
var xs []string
for _, t := range c.Cols {
xs = append(xs, t.QuotedUpdateString(quoter))
}
sort.Strings(xs)
return strings.Join(xs, ", ")
}
16 changes: 16 additions & 0 deletions columns/writeable_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ func Test_Columns_UpdateString(t *testing.T) {
}
}

type testQuoter struct{}

func (testQuoter) Quote(col string) string {
return `"` + col + `"`
}

func Test_Columns_QuotedUpdateString(t *testing.T) {
r := require.New(t)
q := testQuoter{}
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
u := c.Writeable().QuotedUpdateString(q)
r.Equal(u, "\"LastName\" = :LastName, \"write\" = :write")
}
}

func Test_Columns_WriteableString(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
Expand Down
16 changes: 8 additions & 8 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", model.TableName(), w.String(), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", model.TableName())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
}
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
Expand All @@ -91,15 +91,15 @@ func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
model.setID(id.ID)
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return genericCreate(s, model, cols)
return genericCreate(s, model, cols, p)
}

func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols)
return genericUpdate(s, model, cols, p)
}

func (p *cockroach) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID()))
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", p.Quote(model.TableName()), model.whereID()))
_, err := genericExec(s, stmt, model.ID())
return err
}
Expand All @@ -120,7 +120,7 @@ func (p *cockroach) CreateDB() error {
return errors.Wrapf(err, "error creating Cockroach database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("CREATE DATABASE \"%s\"", deets.Database)
query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand All @@ -139,7 +139,7 @@ func (p *cockroach) DropDB() error {
return errors.Wrapf(err, "error dropping Cockroach database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("DROP DATABASE \"%s\" CASCADE;", deets.Database)
query := fmt.Sprintf("DROP DATABASE %s CASCADE;", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand Down Expand Up @@ -223,7 +223,7 @@ func (p *cockroach) TruncateAll(tx *Connection) error {
//! work around for current limitation of DDL and DML at the same transaction.
// it should be fixed when cockroach support it or with other approach.
// https://www.cockroachlabs.com/docs/stable/known-limitations.html#schema-changes-within-transactions
if err := tx.RawQuery(fmt.Sprintf("delete from %s", t.TableName)).Exec(); err != nil {
if err := tx.RawQuery(fmt.Sprintf("delete from %s", p.Quote(t.TableName))).Exec(); err != nil {
return err
}
}
Expand Down
18 changes: 11 additions & 7 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func genericCreate(s store, model *Model, cols columns.Columns) error {
type quoter interface {
Quote(key string) string
}

func genericCreate(s store, model *Model, cols columns.Columns, quoter quoter) error {
keyType := model.PrimaryKeyType()
switch keyType {
case "int", "int64":
var id int64
w := cols.Writeable()
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
log(logging.SQL, query)
res, err := s.NamedExec(query, model.Value)
if err != nil {
Expand Down Expand Up @@ -68,7 +72,7 @@ func genericCreate(s store, model *Model, cols columns.Columns) error {
}
w := cols.Writeable()
w.Add("id")
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
if err != nil {
Expand All @@ -86,8 +90,8 @@ func genericCreate(s store, model *Model, cols columns.Columns) error {
return errors.Errorf("can not use %s as a primary key type!", keyType)
}

func genericUpdate(s store, model *Model, cols columns.Columns) error {
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s", model.TableName(), cols.Writeable().UpdateString(), model.whereNamedID())
func genericUpdate(s store, model *Model, cols columns.Columns, quoter quoter) error {
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s", quoter.Quote(model.TableName()), cols.Writeable().QuotedUpdateString(quoter), model.whereNamedID())
log(logging.SQL, stmt, model.ID())
_, err := s.NamedExec(stmt, model.Value)
if err != nil {
Expand All @@ -96,8 +100,8 @@ func genericUpdate(s store, model *Model, cols columns.Columns) error {
return nil
}

func genericDestroy(s store, model *Model) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID())
func genericDestroy(s store, model *Model, quoter quoter) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s", quoter.Quote(model.TableName()), model.whereID())
_, err := genericExec(s, stmt, model.ID())
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ func (m *mysql) MigrationURL() string {
}

func (m *mysql) Create(s store, model *Model, cols columns.Columns) error {
return errors.Wrap(genericCreate(s, model, cols), "mysql create")
return errors.Wrap(genericCreate(s, model, cols, m), "mysql create")
}

func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
return errors.Wrap(genericUpdate(s, model, cols), "mysql update")
return errors.Wrap(genericUpdate(s, model, cols, m), "mysql update")
}

func (m *mysql) Destroy(s store, model *Model) error {
return errors.Wrap(genericDestroy(s, model), "mysql destroy")
return errors.Wrap(genericDestroy(s, model, m), "mysql destroy")
}

func (m *mysql) SelectOne(s store, model *Model, query Query) error {
Expand Down
14 changes: 7 additions & 7 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", model.TableName(), w.String(), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", model.TableName())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
}
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
Expand All @@ -77,15 +77,15 @@ func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
model.setID(id.ID)
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return genericCreate(s, model, cols)
return genericCreate(s, model, cols, p)
}

func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols)
return genericUpdate(s, model, cols, p)
}

func (p *postgresql) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID()))
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", p.Quote(model.TableName()), model.whereID()))
_, err := genericExec(s, stmt, model.ID())
if err != nil {
return err
Expand All @@ -109,7 +109,7 @@ func (p *postgresql) CreateDB() error {
return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("CREATE DATABASE \"%s\"", deets.Database)
query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand All @@ -128,7 +128,7 @@ func (p *postgresql) DropDB() error {
return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("DROP DATABASE \"%s\"", deets.Database)
query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand Down
12 changes: 6 additions & 6 deletions dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", m.Quote(model.TableName()), w.QuotedString(m), w.SymbolizedString())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", model.TableName())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", m.Quote(model.TableName()))
}
log(logging.SQL, query)
res, err := s.NamedExec(query, model.Value)
Expand All @@ -81,19 +81,19 @@ func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
}
return nil
}
return errors.Wrap(genericCreate(s, model, cols), "sqlite create")
return errors.Wrap(genericCreate(s, model, cols, m), "sqlite create")
})
}

func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error {
return m.locker(m.smGil, func() error {
return errors.Wrap(genericUpdate(s, model, cols), "sqlite update")
return errors.Wrap(genericUpdate(s, model, cols, m), "sqlite update")
})
}

func (m *sqlite) Destroy(s store, model *Model) error {
return m.locker(m.smGil, func() error {
return errors.Wrap(genericDestroy(s, model), "sqlite destroy")
return errors.Wrap(genericDestroy(s, model, m), "sqlite destroy")
})
}

Expand Down Expand Up @@ -205,7 +205,7 @@ func (m *sqlite) TruncateAll(tx *Connection) error {
}
stmts := []string{}
for _, n := range names {
stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", n.Name))
stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", m.Quote(n.Name)))
}
return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
}
Expand Down

0 comments on commit 79000c1

Please sign in to comment.