From d2559e8e992a810ee85efd479c213a5cb55fe6c9 Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Tue, 27 Jun 2023 22:16:56 +0200 Subject: [PATCH] feat: Support "LIMIT ?" in UPDATE and DELETE for MySQL (#2365) PostgreSQL doesn't support UPDATE-LIMIT issue #2131 --- internal/compiler/find_params.go | 8 +++++ .../endtoend/testdata/limit/mysql/go/db.go | 31 +++++++++++++++++++ .../testdata/limit/mysql/go/models.go | 11 +++++++ .../testdata/limit/mysql/go/query.sql.go | 28 +++++++++++++++++ .../endtoend/testdata/limit/mysql/query.sql | 7 +++++ .../endtoend/testdata/limit/mysql/sqlc.json | 12 +++++++ internal/engine/dolphin/convert.go | 12 +++++-- internal/sql/ast/delete_stmt.go | 1 + internal/sql/ast/update_stmt.go | 1 + internal/sql/astutils/walk.go | 6 ++++ 10 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 internal/endtoend/testdata/limit/mysql/go/db.go create mode 100644 internal/endtoend/testdata/limit/mysql/go/models.go create mode 100644 internal/endtoend/testdata/limit/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/limit/mysql/query.sql create mode 100644 internal/endtoend/testdata/limit/mysql/sqlc.json diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 62eb2fb02f..40217d15d0 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -69,6 +69,11 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { case *ast.CallStmt: p.parent = n.FuncCall + case *ast.DeleteStmt: + if n.LimitCount != nil { + p.limitCount = n.LimitCount + } + case *ast.FuncCall: p.parent = node @@ -129,6 +134,9 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } p.seen[ref.Location] = struct{}{} } + if n.LimitCount != nil { + p.limitCount = n.LimitCount + } case *ast.RangeVar: p.rangeVar = n diff --git a/internal/endtoend/testdata/limit/mysql/go/db.go b/internal/endtoend/testdata/limit/mysql/go/db.go new file mode 100644 index 0000000000..8c5b31f933 --- /dev/null +++ b/internal/endtoend/testdata/limit/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.18.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/limit/mysql/go/models.go b/internal/endtoend/testdata/limit/mysql/go/models.go new file mode 100644 index 0000000000..4c877ffd43 --- /dev/null +++ b/internal/endtoend/testdata/limit/mysql/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.18.0 + +package querytest + +import () + +type Foo struct { + Bar bool +} diff --git a/internal/endtoend/testdata/limit/mysql/go/query.sql.go b/internal/endtoend/testdata/limit/mysql/go/query.sql.go new file mode 100644 index 0000000000..0c8978fb17 --- /dev/null +++ b/internal/endtoend/testdata/limit/mysql/go/query.sql.go @@ -0,0 +1,28 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.18.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const limitMe = `-- name: LimitMe :exec +UPDATE foo SET bar='baz' LIMIT ? +` + +func (q *Queries) LimitMe(ctx context.Context, limit int32) error { + _, err := q.db.ExecContext(ctx, limitMe, limit) + return err +} + +const limitMeToo = `-- name: LimitMeToo :exec +DELETE FROM foo LIMIT ? +` + +func (q *Queries) LimitMeToo(ctx context.Context, limit int32) error { + _, err := q.db.ExecContext(ctx, limitMeToo, limit) + return err +} diff --git a/internal/endtoend/testdata/limit/mysql/query.sql b/internal/endtoend/testdata/limit/mysql/query.sql new file mode 100644 index 0000000000..4723273c5c --- /dev/null +++ b/internal/endtoend/testdata/limit/mysql/query.sql @@ -0,0 +1,7 @@ +CREATE TABLE foo (bar bool not null); + +-- name: LimitMe :exec +UPDATE foo SET bar='baz' LIMIT ?; + +-- name: LimitMeToo :exec +DELETE FROM foo LIMIT ?; diff --git a/internal/endtoend/testdata/limit/mysql/sqlc.json b/internal/endtoend/testdata/limit/mysql/sqlc.json new file mode 100644 index 0000000000..7676c3bc51 --- /dev/null +++ b/internal/endtoend/testdata/limit/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index fc5c4211c3..bd642c55ed 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -327,12 +327,16 @@ func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt { relations := &ast.List{} convertToRangeVarList(rels, relations) - return &ast.DeleteStmt{ + stmt := &ast.DeleteStmt{ Relations: relations, WhereClause: c.convert(n.Where), ReturningList: &ast.List{}, WithClause: c.convertWithClause(n.With), } + if n.Limit != nil { + stmt.LimitCount = c.convert(n.Limit.Count) + } + return stmt } func (c *cc) convertDropTableStmt(n *pcast.DropTableStmt) ast.Node { @@ -574,7 +578,7 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt { for _, a := range n.List { list.Items = append(list.Items, c.convertAssignment(a)) } - return &ast.UpdateStmt{ + stmt := &ast.UpdateStmt{ Relations: relations, TargetList: list, WhereClause: c.convert(n.Where), @@ -582,6 +586,10 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt { ReturningList: &ast.List{}, WithClause: c.convertWithClause(n.With), } + if n.Limit != nil { + stmt.LimitCount = c.convert(n.Limit.Count) + } + return stmt } func (c *cc) convertValueExpr(n *driver.ValueExpr) *ast.A_Const { diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index 36403f134b..45b6a35869 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -4,6 +4,7 @@ type DeleteStmt struct { Relations *List UsingClause *List WhereClause Node + LimitCount Node ReturningList *List WithClause *WithClause } diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index 517d0b420b..745b91b617 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -5,6 +5,7 @@ type UpdateStmt struct { TargetList *List WhereClause Node FromClause *List + LimitCount Node ReturningList *List WithClause *WithClause } diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 60110c46c3..4bae5629ed 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1068,6 +1068,9 @@ func Walk(f Visitor, node ast.Node) { if n.WhereClause != nil { Walk(f, n.WhereClause) } + if n.LimitCount != nil { + Walk(f, n.LimitCount) + } if n.ReturningList != nil { Walk(f, n.ReturningList) } @@ -2038,6 +2041,9 @@ func Walk(f Visitor, node ast.Node) { if n.FromClause != nil { Walk(f, n.FromClause) } + if n.LimitCount != nil { + Walk(f, n.LimitCount) + } if n.ReturningList != nil { Walk(f, n.ReturningList) }