Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quote identifiers #383

Merged
merged 4 commits into from
Aug 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions columns/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ func (c Column) UpdateString() string {
return fmt.Sprintf("%s = :%s", c.Name, c.Name)
}

// QuotedUpdateString returns quoted the SQL statement to UPDATE the column.
func (c Column) QuotedUpdateString(quoter quoter) string {
return fmt.Sprintf("%s = :%s", quoter.Quote(c.Name), c.Name)
}

// SetSelectSQL sets a custom SELECT statement for the column.
func (c *Column) SetSelectSQL(s string) {
c.SelectSQL = s
Expand Down
14 changes: 14 additions & 0 deletions columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ func (c Columns) Readable() *ReadableColumns {
return w
}

type quoter interface {
Quote(key string) string
}

// QuotedString gives the columns list quoted with the given quoter function.
func (c Columns) QuotedString(quoter quoter) string {
var xs []string
for _, t := range c.Cols {
xs = append(xs, quoter.Quote(t.Name))
}
sort.Strings(xs)
return strings.Join(xs, ", ")
}

func (c Columns) String() string {
var xs []string
for _, t := range c.Cols {
Expand Down
10 changes: 10 additions & 0 deletions columns/writeable_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ func (c WriteableColumns) UpdateString() string {
sort.Strings(xs)
return strings.Join(xs, ", ")
}

// QuotedUpdateString returns the quoted SQL column list part of the UPDATE query.
func (c Columns) QuotedUpdateString(quoter quoter) string {
var xs []string
for _, t := range c.Cols {
xs = append(xs, t.QuotedUpdateString(quoter))
}
sort.Strings(xs)
return strings.Join(xs, ", ")
}
16 changes: 16 additions & 0 deletions columns/writeable_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ func Test_Columns_UpdateString(t *testing.T) {
}
}

type testQuoter struct{}

func (testQuoter) Quote(col string) string {
return `"` + col + `"`
}

func Test_Columns_QuotedUpdateString(t *testing.T) {
r := require.New(t)
q := testQuoter{}
for _, f := range []interface{}{foo{}, &foo{}} {
c := columns.ForStruct(f, "foo")
u := c.Writeable().QuotedUpdateString(q)
r.Equal(u, "\"LastName\" = :LastName, \"write\" = :write")
}
}

func Test_Columns_WriteableString(t *testing.T) {
r := require.New(t)
for _, f := range []interface{}{foo{}, &foo{}} {
Expand Down
16 changes: 8 additions & 8 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", model.TableName(), w.String(), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", model.TableName())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
}
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
Expand All @@ -91,15 +91,15 @@ func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
model.setID(id.ID)
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return genericCreate(s, model, cols)
return genericCreate(s, model, cols, p)
}

func (p *cockroach) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols)
return genericUpdate(s, model, cols, p)
}

func (p *cockroach) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID()))
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", p.Quote(model.TableName()), model.whereID()))
_, err := genericExec(s, stmt, model.ID())
return err
}
Expand All @@ -120,7 +120,7 @@ func (p *cockroach) CreateDB() error {
return errors.Wrapf(err, "error creating Cockroach database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("CREATE DATABASE \"%s\"", deets.Database)
query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand All @@ -139,7 +139,7 @@ func (p *cockroach) DropDB() error {
return errors.Wrapf(err, "error dropping Cockroach database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("DROP DATABASE \"%s\" CASCADE;", deets.Database)
query := fmt.Sprintf("DROP DATABASE %s CASCADE;", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand Down Expand Up @@ -223,7 +223,7 @@ func (p *cockroach) TruncateAll(tx *Connection) error {
//! work around for current limitation of DDL and DML at the same transaction.
// it should be fixed when cockroach support it or with other approach.
// https://www.cockroachlabs.com/docs/stable/known-limitations.html#schema-changes-within-transactions
if err := tx.RawQuery(fmt.Sprintf("delete from %s", t.TableName)).Exec(); err != nil {
if err := tx.RawQuery(fmt.Sprintf("delete from %s", p.Quote(t.TableName))).Exec(); err != nil {
return err
}
}
Expand Down
18 changes: 11 additions & 7 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func genericCreate(s store, model *Model, cols columns.Columns) error {
type quoter interface {
Quote(key string) string
}

func genericCreate(s store, model *Model, cols columns.Columns, quoter quoter) error {
keyType := model.PrimaryKeyType()
switch keyType {
case "int", "int64":
var id int64
w := cols.Writeable()
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
log(logging.SQL, query)
res, err := s.NamedExec(query, model.Value)
if err != nil {
Expand Down Expand Up @@ -68,7 +72,7 @@ func genericCreate(s store, model *Model, cols columns.Columns) error {
}
w := cols.Writeable()
w.Add("id")
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
if err != nil {
Expand All @@ -86,8 +90,8 @@ func genericCreate(s store, model *Model, cols columns.Columns) error {
return errors.Errorf("can not use %s as a primary key type!", keyType)
}

func genericUpdate(s store, model *Model, cols columns.Columns) error {
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s", model.TableName(), cols.Writeable().UpdateString(), model.whereNamedID())
func genericUpdate(s store, model *Model, cols columns.Columns, quoter quoter) error {
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s", quoter.Quote(model.TableName()), cols.Writeable().QuotedUpdateString(quoter), model.whereNamedID())
log(logging.SQL, stmt, model.ID())
_, err := s.NamedExec(stmt, model.Value)
if err != nil {
Expand All @@ -96,8 +100,8 @@ func genericUpdate(s store, model *Model, cols columns.Columns) error {
return nil
}

func genericDestroy(s store, model *Model) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID())
func genericDestroy(s store, model *Model, quoter quoter) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s", quoter.Quote(model.TableName()), model.whereID())
_, err := genericExec(s, stmt, model.ID())
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ func (m *mysql) MigrationURL() string {
}

func (m *mysql) Create(s store, model *Model, cols columns.Columns) error {
return errors.Wrap(genericCreate(s, model, cols), "mysql create")
return errors.Wrap(genericCreate(s, model, cols, m), "mysql create")
}

func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
return errors.Wrap(genericUpdate(s, model, cols), "mysql update")
return errors.Wrap(genericUpdate(s, model, cols, m), "mysql update")
}

func (m *mysql) Destroy(s store, model *Model) error {
return errors.Wrap(genericDestroy(s, model), "mysql destroy")
return errors.Wrap(genericDestroy(s, model, m), "mysql destroy")
}

func (m *mysql) SelectOne(s store, model *Model, query Query) error {
Expand Down
14 changes: 7 additions & 7 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", model.TableName(), w.String(), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", model.TableName())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
}
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
Expand All @@ -77,15 +77,15 @@ func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
model.setID(id.ID)
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return genericCreate(s, model, cols)
return genericCreate(s, model, cols, p)
}

func (p *postgresql) Update(s store, model *Model, cols columns.Columns) error {
return genericUpdate(s, model, cols)
return genericUpdate(s, model, cols, p)
}

func (p *postgresql) Destroy(s store, model *Model) error {
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID()))
stmt := p.TranslateSQL(fmt.Sprintf("DELETE FROM %s WHERE %s", p.Quote(model.TableName()), model.whereID()))
_, err := genericExec(s, stmt, model.ID())
if err != nil {
return err
Expand All @@ -109,7 +109,7 @@ func (p *postgresql) CreateDB() error {
return errors.Wrapf(err, "error creating PostgreSQL database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("CREATE DATABASE \"%s\"", deets.Database)
query := fmt.Sprintf("CREATE DATABASE %s", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand All @@ -128,7 +128,7 @@ func (p *postgresql) DropDB() error {
return errors.Wrapf(err, "error dropping PostgreSQL database %s", deets.Database)
}
defer db.Close()
query := fmt.Sprintf("DROP DATABASE \"%s\"", deets.Database)
query := fmt.Sprintf("DROP DATABASE %s", p.Quote(deets.Database))
log(logging.SQL, query)

_, err = db.Exec(query)
Expand Down
12 changes: 6 additions & 6 deletions dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", m.Quote(model.TableName()), w.QuotedString(m), w.SymbolizedString())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", model.TableName())
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", m.Quote(model.TableName()))
}
log(logging.SQL, query)
res, err := s.NamedExec(query, model.Value)
Expand All @@ -81,19 +81,19 @@ func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
}
return nil
}
return errors.Wrap(genericCreate(s, model, cols), "sqlite create")
return errors.Wrap(genericCreate(s, model, cols, m), "sqlite create")
})
}

func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error {
return m.locker(m.smGil, func() error {
return errors.Wrap(genericUpdate(s, model, cols), "sqlite update")
return errors.Wrap(genericUpdate(s, model, cols, m), "sqlite update")
})
}

func (m *sqlite) Destroy(s store, model *Model) error {
return m.locker(m.smGil, func() error {
return errors.Wrap(genericDestroy(s, model), "sqlite destroy")
return errors.Wrap(genericDestroy(s, model, m), "sqlite destroy")
})
}

Expand Down Expand Up @@ -205,7 +205,7 @@ func (m *sqlite) TruncateAll(tx *Connection) error {
}
stmts := []string{}
for _, n := range names {
stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", n.Name))
stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", m.Quote(n.Name)))
}
return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
}
Expand Down