diff --git a/internal/endtoend/case_test.go b/internal/endtoend/case_test.go new file mode 100644 index 0000000000..367b9dd158 --- /dev/null +++ b/internal/endtoend/case_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +type Testcase struct { + Name string + Path string + ConfigName string + Stderr []byte + Exec *Exec +} + +type Exec struct { + Command string `json:"command"` + Contexts []string `json:"contexts"` + Process string `json:"process"` + Env map[string]string `json:"env"` +} + +func parseStderr(t *testing.T, dir, testctx string) []byte { + t.Helper() + paths := []string{ + filepath.Join(dir, "stderr", fmt.Sprintf("%s.txt", testctx)), + filepath.Join(dir, "stderr.txt"), + } + for _, path := range paths { + if _, err := os.Stat(path); !os.IsNotExist(err) { + blob, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + return blob + } + } + return nil +} + +func parseExec(t *testing.T, dir string) *Exec { + t.Helper() + path := filepath.Join(dir, "exec.json") + if _, err := os.Stat(path); os.IsNotExist(err) { + return nil + } + var e Exec + blob, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(blob, &e); err != nil { + t.Fatal(err) + } + if e.Command == "" { + e.Command = "generate" + } + return &e +} + +func FindTests(t *testing.T, root, testctx string) []*Testcase { + var tcs []*Testcase + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.Name() == "sqlc.json" || info.Name() == "sqlc.yaml" || info.Name() == "sqlc.yml" { + dir := filepath.Dir(path) + tcs = append(tcs, &Testcase{ + Path: dir, + Name: strings.TrimPrefix(dir, root+string(filepath.Separator)), + ConfigName: info.Name(), + Stderr: parseStderr(t, dir, testctx), + Exec: parseExec(t, dir), + }) + return filepath.SkipDir + } + return nil + }) + if err != nil { + t.Fatal(err) + } + return tcs +} diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index e2ba13f277..39694e9437 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -3,8 +3,6 @@ package main import ( "bytes" "context" - "encoding/json" - "fmt" "os" osexec "os/exec" "path/filepath" @@ -97,20 +95,6 @@ func TestReplay(t *testing.T) { // t.Parallel() ctx := context.Background() - var dirs []string - err := filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if info.Name() == "sqlc.json" || info.Name() == "sqlc.yaml" || info.Name() == "sqlc.yml" { - dirs = append(dirs, filepath.Dir(path)) - return filepath.SkipDir - } - return nil - }) - if err != nil { - t.Fatal(err) - } contexts := map[string]textContext{ "base": { @@ -135,24 +119,29 @@ func TestReplay(t *testing.T) { }, } - for _, replay := range dirs { - tc := replay - for name, testctx := range contexts { - name := name - testctx := testctx + for name, testctx := range contexts { + name := name + testctx := testctx - if !testctx.Enabled() { - continue - } + if !testctx.Enabled() { + continue + } - t.Run(filepath.Join(name, tc), func(t *testing.T) { + for _, replay := range FindTests(t, "testdata", name) { + tc := replay + t.Run(filepath.Join(name, tc.Name), func(t *testing.T) { t.Parallel() + var stderr bytes.Buffer var output map[string]string var err error - path, _ := filepath.Abs(tc) - args := parseExec(t, path) + path, _ := filepath.Abs(tc.Path) + args := tc.Exec + if args == nil { + args = &Exec{Command: "generate"} + } + expected := string(tc.Stderr) if args.Process != "" { _, err := osexec.LookPath(args.Process) @@ -167,7 +156,6 @@ func TestReplay(t *testing.T) { } } - expected := expectedStderr(t, path, name) opts := cmd.Options{ Env: cmd.Env{ Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), @@ -263,50 +251,6 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { } } -func expectedStderr(t *testing.T, dir, testctx string) string { - t.Helper() - paths := []string{ - filepath.Join(dir, "stderr", fmt.Sprintf("%s.txt", testctx)), - filepath.Join(dir, "stderr.txt"), - } - for _, path := range paths { - if _, err := os.Stat(path); !os.IsNotExist(err) { - blob, err := os.ReadFile(path) - if err != nil { - t.Fatal(err) - } - return string(blob) - } - } - return "" -} - -type exec struct { - Command string `json:"command"` - Process string `json:"process"` - Contexts []string `json:"contexts"` - Env map[string]string `json:"env"` -} - -func parseExec(t *testing.T, dir string) exec { - t.Helper() - var e exec - path := filepath.Join(dir, "exec.json") - if _, err := os.Stat(path); !os.IsNotExist(err) { - blob, err := os.ReadFile(path) - if err != nil { - t.Fatal(err) - } - if err := json.Unmarshal(blob, &e); err != nil { - t.Fatal(err) - } - } - if e.Command == "" { - e.Command = "generate" - } - return e -} - func BenchmarkReplay(b *testing.B) { ctx := context.Background() var dirs []string diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go new file mode 100644 index 0000000000..22b5f1392d --- /dev/null +++ b/internal/endtoend/fmt_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + pg_query "github.com/pganalyze/pg_query_go/v4" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestFormat(t *testing.T) { + t.Parallel() + parse := postgresql.NewParser() + for _, tc := range FindTests(t, "testdata", "base") { + tc := tc + + if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) { + continue + } + + q := filepath.Join(tc.Path, "query.sql") + if _, err := os.Stat(q); os.IsNotExist(err) { + continue + } + + t.Run(tc.Name, func(t *testing.T) { + contents, err := os.ReadFile(q) + if err != nil { + t.Fatal(err) + } + for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) { + if len(query) <= 1 { + continue + } + query := query + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + expected, err := pg_query.Fingerprint(string(query)) + if err != nil { + t.Fatal(err) + } + stmts, err := parse.Parse(bytes.NewReader(query)) + if err != nil { + t.Fatal(err) + } + if len(stmts) != 1 { + t.Fatal("expected one statement") + } + if false { + r, err := pg_query.Parse(string(query)) + debug.Dump(r, err) + } + + out := ast.Format(stmts[0].Raw) + actual, err := pg_query.Fingerprint(out) + if err != nil { + t.Error(err) + } + if expected != actual { + debug.Dump(stmts[0].Raw) + t.Errorf("- %s", expected) + t.Errorf("- %s", string(query)) + t.Errorf("+ %s", actual) + t.Errorf("+ %s", out) + } + }) + } + }) + } +} diff --git a/internal/endtoend/testdata/join_where_clause/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/join_where_clause/postgresql/pgx/v5/query.sql index 6887585544..776cd41ced 100644 --- a/internal/endtoend/testdata/join_where_clause/postgresql/pgx/v5/query.sql +++ b/internal/endtoend/testdata/join_where_clause/postgresql/pgx/v5/query.sql @@ -14,4 +14,4 @@ WHERE owner = $1; SELECT foo.* FROM foo CROSS JOIN bar -WHERE bar.id = $2 AND owner = $1; +WHERE bar.id = $2 AND owner = $1; \ No newline at end of file diff --git a/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/go/query.sql.go new file mode 100644 index 0000000000..a1c398553e --- /dev/null +++ b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/go/query.sql.go @@ -0,0 +1,39 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio, gender FROM authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Bio, + &i.Gender, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/query.sql b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/query.sql index f14191f894..86636e2655 100644 --- a/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/query.sql +++ b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v4/query.sql @@ -1,11 +1,2 @@ -CREATE TABLE authors ( - id BIGSERIAL PRIMARY KEY, - name TEXT NOT NULL, - bio TEXT -); - -ALTER TABLE authors ADD COLUMN gender INTEGER NULL; - -CREATE MATERIALIZED VIEW authors_names as SELECT name from authors; - +-- name: ListAuthors :many SELECT * FROM authors; diff --git a/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/go/query.sql.go new file mode 100644 index 0000000000..a1c398553e --- /dev/null +++ b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/go/query.sql.go @@ -0,0 +1,39 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio, gender FROM authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Bio, + &i.Gender, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/query.sql index f14191f894..86636e2655 100644 --- a/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/query.sql +++ b/internal/endtoend/testdata/materialized_views/postgresql/pgx/v5/query.sql @@ -1,11 +1,2 @@ -CREATE TABLE authors ( - id BIGSERIAL PRIMARY KEY, - name TEXT NOT NULL, - bio TEXT -); - -ALTER TABLE authors ADD COLUMN gender INTEGER NULL; - -CREATE MATERIALIZED VIEW authors_names as SELECT name from authors; - +-- name: ListAuthors :many SELECT * FROM authors; diff --git a/internal/endtoend/testdata/materialized_views/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/materialized_views/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..bef31c8b9c --- /dev/null +++ b/internal/endtoend/testdata/materialized_views/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,42 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio, gender FROM authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Bio, + &i.Gender, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/materialized_views/postgresql/stdlib/query.sql b/internal/endtoend/testdata/materialized_views/postgresql/stdlib/query.sql index f14191f894..86636e2655 100644 --- a/internal/endtoend/testdata/materialized_views/postgresql/stdlib/query.sql +++ b/internal/endtoend/testdata/materialized_views/postgresql/stdlib/query.sql @@ -1,11 +1,2 @@ -CREATE TABLE authors ( - id BIGSERIAL PRIMARY KEY, - name TEXT NOT NULL, - bio TEXT -); - -ALTER TABLE authors ADD COLUMN gender INTEGER NULL; - -CREATE MATERIALIZED VIEW authors_names as SELECT name from authors; - +-- name: ListAuthors :many SELECT * FROM authors; diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 75ae3ca344..c1ac83381c 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -438,12 +438,21 @@ func translate(node *nodes.Node) (ast.Node, error) { if err != nil { return nil, err } + + primary := false + for _, con := range item.ColumnDef.Constraints { + if constraint, ok := con.Node.(*nodes.Node_Constraint); ok { + primary = constraint.Constraint.Contype == nodes.ConstrType_CONSTR_PRIMARY + } + } + create.Cols = append(create.Cols, &ast.ColumnDef{ - Colname: item.ColumnDef.Colname, - TypeName: rel.TypeName(), - IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname], - IsArray: isArray(item.ColumnDef.TypeName), - ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds), + Colname: item.ColumnDef.Colname, + TypeName: rel.TypeName(), + IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname], + IsArray: isArray(item.ColumnDef.TypeName), + ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds), + PrimaryKey: primary, }) } } diff --git a/internal/sql/ast/a_const.go b/internal/sql/ast/a_const.go index 720dca4a11..ec1d780945 100644 --- a/internal/sql/ast/a_const.go +++ b/internal/sql/ast/a_const.go @@ -8,3 +8,16 @@ type A_Const struct { func (n *A_Const) Pos() int { return n.Location } + +func (n *A_Const) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if _, ok := n.Val.(*String); ok { + buf.WriteString("'") + buf.astFormat(n.Val) + buf.WriteString("'") + } else { + buf.astFormat(n.Val) + } +} diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index 415dd1e23f..b0b7f75367 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -11,3 +11,24 @@ type A_Expr struct { func (n *A_Expr) Pos() int { return n.Location } + +func (n *A_Expr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Lexpr) + buf.WriteString(" ") + switch n.Kind { + case A_Expr_Kind_IN: + buf.WriteString(" IN (") + buf.astFormat(n.Rexpr) + buf.WriteString(")") + case A_Expr_Kind_LIKE: + buf.WriteString(" LIKE ") + buf.astFormat(n.Rexpr) + default: + buf.astFormat(n.Name) + buf.WriteString(" ") + buf.astFormat(n.Rexpr) + } +} diff --git a/internal/sql/ast/a_expr_kind.go b/internal/sql/ast/a_expr_kind.go index 50fc6bc6bb..53a237896b 100644 --- a/internal/sql/ast/a_expr_kind.go +++ b/internal/sql/ast/a_expr_kind.go @@ -2,6 +2,11 @@ package ast type A_Expr_Kind uint +const ( + A_Expr_Kind_IN A_Expr_Kind = 7 + A_Expr_Kind_LIKE A_Expr_Kind = 8 +) + func (n *A_Expr_Kind) Pos() int { return 0 } diff --git a/internal/sql/ast/a_star.go b/internal/sql/ast/a_star.go index accd0f7dd8..a43b2ab5b7 100644 --- a/internal/sql/ast/a_star.go +++ b/internal/sql/ast/a_star.go @@ -6,3 +6,10 @@ type A_Star struct { func (n *A_Star) Pos() int { return 0 } + +func (n *A_Star) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteRune('*') +} diff --git a/internal/sql/ast/alias.go b/internal/sql/ast/alias.go index 7c6302a5ef..55965b55c9 100644 --- a/internal/sql/ast/alias.go +++ b/internal/sql/ast/alias.go @@ -8,3 +8,17 @@ type Alias struct { func (n *Alias) Pos() int { return 0 } + +func (n *Alias) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Aliasname != nil { + buf.WriteString(*n.Aliasname) + } + if items(n.Colnames) { + buf.WriteString("(") + buf.astFormat((n.Colnames)) + buf.WriteString(")") + } +} diff --git a/internal/sql/ast/alter_table_cmd.go b/internal/sql/ast/alter_table_cmd.go index 3c6be340cd..80fad95eaf 100644 --- a/internal/sql/ast/alter_table_cmd.go +++ b/internal/sql/ast/alter_table_cmd.go @@ -39,3 +39,17 @@ type AlterTableCmd struct { func (n *AlterTableCmd) Pos() int { return 0 } + +func (n *AlterTableCmd) Format(buf *TrackedBuffer) { + if n == nil { + return + } + switch n.Subtype { + case AT_AddColumn: + buf.WriteString(" ADD COLUMN ") + case AT_DropColumn: + buf.WriteString(" DROP COLUMN ") + } + + buf.astFormat(n.Def) +} diff --git a/internal/sql/ast/alter_table_stmt.go b/internal/sql/ast/alter_table_stmt.go index 245d7c6821..5d4a22f50e 100644 --- a/internal/sql/ast/alter_table_stmt.go +++ b/internal/sql/ast/alter_table_stmt.go @@ -12,3 +12,13 @@ type AlterTableStmt struct { func (n *AlterTableStmt) Pos() int { return 0 } + +func (n *AlterTableStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("ALTER TABLE ") + buf.astFormat(n.Relation) + buf.astFormat(n.Table) + buf.astFormat(n.Cmds) +} diff --git a/internal/sql/ast/bool_expr.go b/internal/sql/ast/bool_expr.go index 41ddba949b..6d15276a05 100644 --- a/internal/sql/ast/bool_expr.go +++ b/internal/sql/ast/bool_expr.go @@ -10,3 +10,22 @@ type BoolExpr struct { func (n *BoolExpr) Pos() int { return n.Location } + +func (n *BoolExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("(") + if items(n.Args) { + switch n.Boolop { + case BoolExprTypeAnd: + buf.join(n.Args, " AND ") + case BoolExprTypeOr: + buf.join(n.Args, " OR ") + case BoolExprTypeNot: + buf.WriteString(" NOT ") + buf.astFormat(n.Args) + } + } + buf.WriteString(")") +} diff --git a/internal/sql/ast/boolean.go b/internal/sql/ast/boolean.go index cf193f2c12..522af84868 100644 --- a/internal/sql/ast/boolean.go +++ b/internal/sql/ast/boolean.go @@ -1,5 +1,7 @@ package ast +import "fmt" + type Boolean struct { Boolval bool } @@ -7,3 +9,14 @@ type Boolean struct { func (n *Boolean) Pos() int { return 0 } + +func (n *Boolean) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Boolval { + fmt.Fprintf(buf, "true") + } else { + fmt.Fprintf(buf, "false") + } +} diff --git a/internal/sql/ast/call_stmt.go b/internal/sql/ast/call_stmt.go index 252bfb3169..5267a1ff3f 100644 --- a/internal/sql/ast/call_stmt.go +++ b/internal/sql/ast/call_stmt.go @@ -10,3 +10,8 @@ func (n *CallStmt) Pos() int { } return n.FuncCall.Pos() } + +func (n *CallStmt) Format(buf *TrackedBuffer) { + buf.WriteString("CALL ") + buf.astFormat(n.FuncCall) +} diff --git a/internal/sql/ast/case_expr.go b/internal/sql/ast/case_expr.go index c23ffae2a4..1da54f0d78 100644 --- a/internal/sql/ast/case_expr.go +++ b/internal/sql/ast/case_expr.go @@ -13,3 +13,14 @@ type CaseExpr struct { func (n *CaseExpr) Pos() int { return n.Location } + +func (n *CaseExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("CASE ") + buf.astFormat(n.Args) + buf.WriteString(" ELSE ") + buf.astFormat(n.Defresult) + buf.WriteString(" END ") +} diff --git a/internal/sql/ast/case_when.go b/internal/sql/ast/case_when.go index 9b8a488955..b036411d54 100644 --- a/internal/sql/ast/case_when.go +++ b/internal/sql/ast/case_when.go @@ -10,3 +10,13 @@ type CaseWhen struct { func (n *CaseWhen) Pos() int { return n.Location } + +func (n *CaseWhen) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("WHEN ") + buf.astFormat(n.Expr) + buf.WriteString(" THEN ") + buf.astFormat(n.Result) +} diff --git a/internal/sql/ast/coalesce_expr.go b/internal/sql/ast/coalesce_expr.go index 513b495445..cbf7025748 100644 --- a/internal/sql/ast/coalesce_expr.go +++ b/internal/sql/ast/coalesce_expr.go @@ -11,3 +11,12 @@ type CoalesceExpr struct { func (n *CoalesceExpr) Pos() int { return n.Location } + +func (n *CoalesceExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("COALESCE(") + buf.astFormat(n.Args) + buf.WriteString(")") +} diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index c4cd372437..f9504eefc7 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -9,6 +9,7 @@ type ColumnDef struct { ArrayDims int Vals *List Length *int + PrimaryKey bool // From pg.ColumnDef Inhcount int @@ -30,3 +31,18 @@ type ColumnDef struct { func (n *ColumnDef) Pos() int { return n.Location } + +func (n *ColumnDef) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString(n.Colname) + buf.WriteString(" ") + buf.astFormat(n.TypeName) + if n.PrimaryKey { + buf.WriteString(" PRIMARY KEY") + } else if n.IsNotNull { + buf.WriteString(" NOT NULL") + } + buf.astFormat(n.Constraints) +} diff --git a/internal/sql/ast/column_ref.go b/internal/sql/ast/column_ref.go index 891fa163f7..e95b844896 100644 --- a/internal/sql/ast/column_ref.go +++ b/internal/sql/ast/column_ref.go @@ -1,5 +1,7 @@ package ast +import "strings" + type ColumnRef struct { Name string @@ -11,3 +13,26 @@ type ColumnRef struct { func (n *ColumnRef) Pos() int { return n.Location } + +func (n *ColumnRef) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + if n.Fields != nil { + var items []string + for _, item := range n.Fields.Items { + switch nn := item.(type) { + case *String: + if nn.Str == "user" { + items = append(items, `"user"`) + } else { + items = append(items, nn.Str) + } + case *A_Star: + items = append(items, "*") + } + } + buf.WriteString(strings.Join(items, ".")) + } +} diff --git a/internal/sql/ast/common_table_expr.go b/internal/sql/ast/common_table_expr.go index d5ae01f040..f2edddff79 100644 --- a/internal/sql/ast/common_table_expr.go +++ b/internal/sql/ast/common_table_expr.go @@ -1,5 +1,9 @@ package ast +import ( + "fmt" +) + type CommonTableExpr struct { Ctename *string Aliascolnames *List @@ -16,3 +20,14 @@ type CommonTableExpr struct { func (n *CommonTableExpr) Pos() int { return n.Location } + +func (n *CommonTableExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Ctename != nil { + fmt.Fprintf(buf, " %s AS (", *n.Ctename) + } + buf.astFormat(n.Ctequery) + buf.WriteString(")") +} diff --git a/internal/sql/ast/create_table_stmt.go b/internal/sql/ast/create_table_stmt.go index 7273ffa852..ce88a1b244 100644 --- a/internal/sql/ast/create_table_stmt.go +++ b/internal/sql/ast/create_table_stmt.go @@ -12,3 +12,20 @@ type CreateTableStmt struct { func (n *CreateTableStmt) Pos() int { return 0 } + +func (n *CreateTableStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("CREATE TABLE ") + buf.astFormat(n.Name) + + buf.WriteString("(") + for i, col := range n.Cols { + if i > 0 { + buf.WriteString(", ") + } + buf.astFormat(col) + } + buf.WriteString(")") +} diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index 45b6a35869..d77f043a12 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -12,3 +12,34 @@ type DeleteStmt struct { func (n *DeleteStmt) Pos() int { return 0 } + +func (n *DeleteStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + if n.WithClause != nil { + buf.astFormat(n.WithClause) + buf.WriteString(" ") + } + + buf.WriteString("DELETE FROM ") + if items(n.Relations) { + buf.astFormat(n.Relations) + } + + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + + if set(n.LimitCount) { + buf.WriteString(" LIMIT ") + buf.astFormat(n.LimitCount) + } + + if items(n.ReturningList) { + buf.WriteString(" RETURNING ") + buf.astFormat(n.ReturningList) + } +} diff --git a/internal/sql/ast/float.go b/internal/sql/ast/float.go index 8e5ef10f97..fee8655bbe 100644 --- a/internal/sql/ast/float.go +++ b/internal/sql/ast/float.go @@ -7,3 +7,10 @@ type Float struct { func (n *Float) Pos() int { return 0 } + +func (n *Float) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString(n.Str) +} diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index f3feb82225..2bfe961b50 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -17,3 +17,17 @@ type FuncCall struct { func (n *FuncCall) Pos() int { return n.Location } + +func (n *FuncCall) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Func) + buf.WriteString("(") + if n.AggStar { + buf.WriteString("*") + } else { + buf.astFormat(n.Args) + } + buf.WriteString(")") +} diff --git a/internal/sql/ast/func_name.go b/internal/sql/ast/func_name.go index e8b93a752c..29b8e0fa61 100644 --- a/internal/sql/ast/func_name.go +++ b/internal/sql/ast/func_name.go @@ -9,3 +9,16 @@ type FuncName struct { func (n *FuncName) Pos() int { return 0 } + +func (n *FuncName) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Schema != "" { + buf.WriteString(n.Schema) + buf.WriteString(".") + } + if n.Name != "" { + buf.WriteString(n.Name) + } +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 12ee24846c..3cdf854091 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -13,3 +13,37 @@ type InsertStmt struct { func (n *InsertStmt) Pos() int { return 0 } + +func (n *InsertStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + if n.WithClause != nil { + buf.astFormat(n.WithClause) + buf.WriteString(" ") + } + + buf.WriteString("INSERT INTO ") + if n.Relation != nil { + buf.astFormat(n.Relation) + } + if items(n.Cols) { + buf.WriteString(" (") + buf.astFormat(n.Cols) + buf.WriteString(") ") + } + + if set(n.SelectStmt) { + buf.astFormat(n.SelectStmt) + } + + if n.OnConflictClause != nil { + buf.WriteString(" ON CONFLICT DO NOTHING ") + } + + if items(n.ReturningList) { + buf.WriteString(" RETURNING ") + buf.astFormat(n.ReturningList) + } +} diff --git a/internal/sql/ast/integer.go b/internal/sql/ast/integer.go index a00e906b22..e9f911add2 100644 --- a/internal/sql/ast/integer.go +++ b/internal/sql/ast/integer.go @@ -1,5 +1,7 @@ package ast +import "strconv" + type Integer struct { Ival int64 } @@ -7,3 +9,10 @@ type Integer struct { func (n *Integer) Pos() int { return 0 } + +func (n *Integer) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString(strconv.FormatInt(n.Ival, 10)) +} diff --git a/internal/sql/ast/join_expr.go b/internal/sql/ast/join_expr.go index 86e38d2d3b..e316869560 100644 --- a/internal/sql/ast/join_expr.go +++ b/internal/sql/ast/join_expr.go @@ -14,3 +14,29 @@ type JoinExpr struct { func (n *JoinExpr) Pos() int { return 0 } + +func (n *JoinExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Larg) + switch n.Jointype { + case JoinTypeLeft: + buf.WriteString(" LEFT JOIN ") + case JoinTypeInner: + buf.WriteString(" INNER JOIN ") + default: + buf.WriteString(" JOIN ") + } + buf.astFormat(n.Rarg) + buf.WriteString(" ON ") + if n.Jointype == JoinTypeInner { + if set(n.Quals) { + buf.astFormat(n.Quals) + } else { + buf.WriteString("TRUE") + } + } else { + buf.astFormat(n.Quals) + } +} diff --git a/internal/sql/ast/list.go b/internal/sql/ast/list.go index ae49b0c429..1c89d55339 100644 --- a/internal/sql/ast/list.go +++ b/internal/sql/ast/list.go @@ -7,3 +7,10 @@ type List struct { func (n *List) Pos() int { return 0 } + +func (n *List) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.join(n, ",") +} diff --git a/internal/sql/ast/listen_stmt.go b/internal/sql/ast/listen_stmt.go index cbd51dd90a..79c1b132c1 100644 --- a/internal/sql/ast/listen_stmt.go +++ b/internal/sql/ast/listen_stmt.go @@ -7,3 +7,13 @@ type ListenStmt struct { func (n *ListenStmt) Pos() int { return 0 } + +func (n *ListenStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("LISTEN ") + if n.Conditionname != nil { + buf.WriteString(*n.Conditionname) + } +} diff --git a/internal/sql/ast/locking_clause.go b/internal/sql/ast/locking_clause.go index 5800a03806..11a9159de2 100644 --- a/internal/sql/ast/locking_clause.go +++ b/internal/sql/ast/locking_clause.go @@ -9,3 +9,16 @@ type LockingClause struct { func (n *LockingClause) Pos() int { return 0 } + +func (n *LockingClause) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("FOR ") + switch n.Strength { + case 3: + buf.WriteString("SHARE") + case 5: + buf.WriteString("UPDATE") + } +} diff --git a/internal/sql/ast/multi_assign_ref.go b/internal/sql/ast/multi_assign_ref.go index ef0d5554c3..16302b4e4c 100644 --- a/internal/sql/ast/multi_assign_ref.go +++ b/internal/sql/ast/multi_assign_ref.go @@ -9,3 +9,10 @@ type MultiAssignRef struct { func (n *MultiAssignRef) Pos() int { return 0 } + +func (n *MultiAssignRef) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Source) +} diff --git a/internal/sql/ast/named_arg_expr.go b/internal/sql/ast/named_arg_expr.go index 1c802bdd26..e37427826e 100644 --- a/internal/sql/ast/named_arg_expr.go +++ b/internal/sql/ast/named_arg_expr.go @@ -11,3 +11,14 @@ type NamedArgExpr struct { func (n *NamedArgExpr) Pos() int { return n.Location } + +func (n *NamedArgExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Name != nil { + buf.WriteString(*n.Name) + } + buf.WriteString(" => ") + buf.astFormat(n.Arg) +} diff --git a/internal/sql/ast/notify_stmt.go b/internal/sql/ast/notify_stmt.go index ef3058df56..0c50a11123 100644 --- a/internal/sql/ast/notify_stmt.go +++ b/internal/sql/ast/notify_stmt.go @@ -8,3 +8,18 @@ type NotifyStmt struct { func (n *NotifyStmt) Pos() int { return 0 } + +func (n *NotifyStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("NOTIFY ") + if n.Conditionname != nil { + buf.WriteString(*n.Conditionname) + } + if n.Payload != nil { + buf.WriteString(", '") + buf.WriteString(*n.Payload) + buf.WriteString("'") + } +} diff --git a/internal/sql/ast/null.go b/internal/sql/ast/null.go index 92abc76c02..380c8e7372 100644 --- a/internal/sql/ast/null.go +++ b/internal/sql/ast/null.go @@ -6,3 +6,6 @@ type Null struct { func (n *Null) Pos() int { return 0 } +func (n *Null) Format(buf *TrackedBuffer) { + buf.WriteString("NULL") +} diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index d0f486cf85..8bd724993d 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -1,5 +1,7 @@ package ast +import "fmt" + type ParamRef struct { Number int Location int @@ -9,3 +11,10 @@ type ParamRef struct { func (n *ParamRef) Pos() int { return n.Location } + +func (n *ParamRef) Format(buf *TrackedBuffer) { + if n == nil { + return + } + fmt.Fprintf(buf, "$%d", n.Number) +} diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go new file mode 100644 index 0000000000..867a53a177 --- /dev/null +++ b/internal/sql/ast/print.go @@ -0,0 +1,81 @@ +package ast + +import ( + "strings" + + "github.com/sqlc-dev/sqlc/internal/debug" +) + +type formatter interface { + Format(*TrackedBuffer) +} + +type TrackedBuffer struct { + *strings.Builder +} + +// NewTrackedBuffer creates a new TrackedBuffer. +func NewTrackedBuffer() *TrackedBuffer { + buf := &TrackedBuffer{ + Builder: new(strings.Builder), + } + return buf +} + +func (t *TrackedBuffer) astFormat(n Node) { + if ft, ok := n.(formatter); ok { + ft.Format(t) + } else { + debug.Dump(n) + } +} + +func (t *TrackedBuffer) join(n *List, sep string) { + if n == nil { + return + } + for i, item := range n.Items { + if _, ok := item.(*TODO); ok { + continue + } + if i > 0 { + t.WriteString(sep) + } + t.astFormat(item) + } +} + +func Format(n Node) string { + tb := NewTrackedBuffer() + if ft, ok := n.(formatter); ok { + ft.Format(tb) + } + return tb.String() +} + +func set(n Node) bool { + if n == nil { + return false + } + _, ok := n.(*TODO) + if ok { + return false + } + return true +} + +func items(n *List) bool { + if n == nil { + return false + } + return len(n.Items) > 0 +} + +func todo(n *List) bool { + for _, item := range n.Items { + if _, ok := item.(*TODO); !ok { + return false + } + } + return true +} diff --git a/internal/sql/ast/range_function.go b/internal/sql/ast/range_function.go index dd92870aa4..299078d481 100644 --- a/internal/sql/ast/range_function.go +++ b/internal/sql/ast/range_function.go @@ -12,3 +12,14 @@ type RangeFunction struct { func (n *RangeFunction) Pos() int { return 0 } + +func (n *RangeFunction) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Functions) + if n.Ordinality { + buf.WriteString(" WITH ORDINALITY ") + } + buf.astFormat(n.Alias) +} diff --git a/internal/sql/ast/range_subselect.go b/internal/sql/ast/range_subselect.go index aaf4d5adaf..1506ee7994 100644 --- a/internal/sql/ast/range_subselect.go +++ b/internal/sql/ast/range_subselect.go @@ -9,3 +9,16 @@ type RangeSubselect struct { func (n *RangeSubselect) Pos() int { return 0 } + +func (n *RangeSubselect) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("(") + buf.astFormat(n.Subquery) + buf.WriteString(")") + if n.Alias != nil { + buf.WriteString(" ") + buf.astFormat(n.Alias) + } +} diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index 3b648ff7c3..1d1656f6c0 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -13,3 +13,27 @@ type RangeVar struct { func (n *RangeVar) Pos() int { return n.Location } + +func (n *RangeVar) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Schemaname != nil { + buf.WriteString(*n.Schemaname) + buf.WriteString(".") + } + if n.Relname != nil { + // TODO: What names need to be quoted + if *n.Relname == "user" { + buf.WriteString(`"`) + buf.WriteString(*n.Relname) + buf.WriteString(`"`) + } else { + buf.WriteString(*n.Relname) + } + } + if n.Alias != nil { + buf.WriteString(" ") + buf.astFormat(n.Alias) + } +} diff --git a/internal/sql/ast/raw_stmt.go b/internal/sql/ast/raw_stmt.go index 3e9e89329f..55192d2eec 100644 --- a/internal/sql/ast/raw_stmt.go +++ b/internal/sql/ast/raw_stmt.go @@ -9,3 +9,10 @@ type RawStmt struct { func (n *RawStmt) Pos() int { return n.StmtLocation } + +func (n *RawStmt) Format(buf *TrackedBuffer) { + if n.Stmt != nil { + buf.astFormat(n.Stmt) + } + buf.WriteString(";") +} diff --git a/internal/sql/ast/refresh_mat_view_stmt.go b/internal/sql/ast/refresh_mat_view_stmt.go index 9284c343de..e9b3e26bfa 100644 --- a/internal/sql/ast/refresh_mat_view_stmt.go +++ b/internal/sql/ast/refresh_mat_view_stmt.go @@ -9,3 +9,11 @@ type RefreshMatViewStmt struct { func (n *RefreshMatViewStmt) Pos() int { return 0 } + +func (n *RefreshMatViewStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("REFRESH MATERIALIZED VIEW ") + buf.astFormat(n.Relation) +} diff --git a/internal/sql/ast/res_target.go b/internal/sql/ast/res_target.go index f9428e3885..4ee2e72112 100644 --- a/internal/sql/ast/res_target.go +++ b/internal/sql/ast/res_target.go @@ -10,3 +10,20 @@ type ResTarget struct { func (n *ResTarget) Pos() int { return n.Location } + +func (n *ResTarget) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if set(n.Val) { + buf.astFormat(n.Val) + if n.Name != nil { + buf.WriteString(" AS ") + buf.WriteString(*n.Name) + } + } else { + if n.Name != nil { + buf.WriteString(*n.Name) + } + } +} diff --git a/internal/sql/ast/row_expr.go b/internal/sql/ast/row_expr.go index 7e996b0e93..14804f5821 100644 --- a/internal/sql/ast/row_expr.go +++ b/internal/sql/ast/row_expr.go @@ -12,3 +12,18 @@ type RowExpr struct { func (n *RowExpr) Pos() int { return n.Location } + +func (n *RowExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if items(n.Args) { + buf.WriteString("args") + buf.astFormat(n.Args) + } + buf.astFormat(n.Xpr) + if items(n.Colnames) { + buf.WriteString("cols") + buf.astFormat(n.Colnames) + } +} diff --git a/internal/sql/ast/select_stmt.go b/internal/sql/ast/select_stmt.go index 75a109c931..051dd5c8c5 100644 --- a/internal/sql/ast/select_stmt.go +++ b/internal/sql/ast/select_stmt.go @@ -1,5 +1,9 @@ package ast +import ( + "fmt" +) + type SelectStmt struct { DistinctClause *List IntoClause *IntoClause @@ -24,3 +28,85 @@ type SelectStmt struct { func (n *SelectStmt) Pos() int { return 0 } + +func (n *SelectStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + + if items(n.ValuesLists) { + buf.WriteString("VALUES (") + buf.astFormat(n.ValuesLists) + buf.WriteString(")") + return + } + + if n.WithClause != nil { + buf.astFormat(n.WithClause) + buf.WriteString(" ") + } + + if n.Larg != nil && n.Rarg != nil { + buf.astFormat(n.Larg) + switch n.Op { + case Union: + buf.WriteString(" UNION ") + case Except: + buf.WriteString(" EXCEPT ") + case Intersect: + buf.WriteString(" INTERSECT ") + } + if n.All { + buf.WriteString("ALL ") + } + buf.astFormat(n.Rarg) + } else { + buf.WriteString("SELECT ") + } + + if items(n.DistinctClause) { + buf.WriteString("DISTINCT ") + if !todo(n.DistinctClause) { + fmt.Fprintf(buf, "ON (") + buf.astFormat(n.DistinctClause) + fmt.Fprintf(buf, ")") + } + } + buf.astFormat(n.TargetList) + + if items(n.FromClause) { + buf.WriteString(" FROM ") + buf.astFormat(n.FromClause) + } + + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + + if items(n.GroupClause) { + buf.WriteString(" GROUP BY ") + buf.astFormat(n.GroupClause) + } + + if items(n.SortClause) { + buf.WriteString(" ORDER BY ") + buf.astFormat(n.SortClause) + } + + if set(n.LimitCount) { + buf.WriteString(" LIMIT ") + buf.astFormat(n.LimitCount) + } + + if set(n.LimitOffset) { + buf.WriteString(" OFFSET ") + buf.astFormat(n.LimitOffset) + } + + if items(n.LockingClause) { + buf.WriteString(" ") + buf.astFormat(n.LockingClause) + } + +} diff --git a/internal/sql/ast/sort_by.go b/internal/sql/ast/sort_by.go index 49c4004f28..21a7a079aa 100644 --- a/internal/sql/ast/sort_by.go +++ b/internal/sql/ast/sort_by.go @@ -11,3 +11,16 @@ type SortBy struct { func (n *SortBy) Pos() int { return n.Location } + +func (n *SortBy) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Node) + switch n.SortbyDir { + case SortByDirAsc: + buf.WriteString(" ASC") + case SortByDirDesc: + buf.WriteString(" DESC") + } +} diff --git a/internal/sql/ast/sql_value_function.go b/internal/sql/ast/sql_value_function.go index a2e5214cb0..0bd0777374 100644 --- a/internal/sql/ast/sql_value_function.go +++ b/internal/sql/ast/sql_value_function.go @@ -11,3 +11,27 @@ type SQLValueFunction struct { func (n *SQLValueFunction) Pos() int { return n.Location } + +func (n *SQLValueFunction) Format(buf *TrackedBuffer) { + if n == nil { + return + } + switch n.Op { + case SVFOpCurrentDate: + buf.WriteString("CURRENT_DATE") + case SVFOpCurrentTime: + case SVFOpCurrentTimeN: + case SVFOpCurrentTimestamp: + case SVFOpCurrentTimestampN: + case SVFOpLocaltime: + case SVFOpLocaltimeN: + case SVFOpLocaltimestamp: + case SVFOpLocaltimestampN: + case SVFOpCurrentRole: + case SVFOpCurrentUser: + case SVFOpUser: + case SVFOpSessionUser: + case SVFOpCurrentCatalog: + case SVFOpCurrentSchema: + } +} diff --git a/internal/sql/ast/sql_value_function_op.go b/internal/sql/ast/sql_value_function_op.go index e781109c8e..5d99afa0d3 100644 --- a/internal/sql/ast/sql_value_function_op.go +++ b/internal/sql/ast/sql_value_function_op.go @@ -2,6 +2,26 @@ package ast type SQLValueFunctionOp uint +const ( + // https://github.com/pganalyze/libpg_query/blob/15-latest/protobuf/pg_query.proto#L2984C1-L3003C1 + _ SQLValueFunctionOp = iota + SVFOpCurrentDate + SVFOpCurrentTime + SVFOpCurrentTimeN + SVFOpCurrentTimestamp + SVFOpCurrentTimestampN + SVFOpLocaltime + SVFOpLocaltimeN + SVFOpLocaltimestamp + SVFOpLocaltimestampN + SVFOpCurrentRole + SVFOpCurrentUser + SVFOpUser + SVFOpSessionUser + SVFOpCurrentCatalog + SVFOpCurrentSchema +) + func (n *SQLValueFunctionOp) Pos() int { return 0 } diff --git a/internal/sql/ast/string.go b/internal/sql/ast/string.go index 619c786db9..977fc19a2f 100644 --- a/internal/sql/ast/string.go +++ b/internal/sql/ast/string.go @@ -7,3 +7,10 @@ type String struct { func (n *String) Pos() int { return 0 } + +func (n *String) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString(n.Str) +} diff --git a/internal/sql/ast/sub_link.go b/internal/sql/ast/sub_link.go index d61a629785..9463f98c54 100644 --- a/internal/sql/ast/sub_link.go +++ b/internal/sql/ast/sub_link.go @@ -26,3 +26,20 @@ type SubLink struct { func (n *SubLink) Pos() int { return n.Location } + +func (n *SubLink) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Testexpr) + switch n.SubLinkType { + case EXISTS_SUBLINK: + buf.WriteString(" EXISTS (") + case ANY_SUBLINK: + buf.WriteString(" IN (") + default: + buf.WriteString(" (") + } + buf.astFormat(n.Subselect) + buf.WriteString(")") +} diff --git a/internal/sql/ast/table_name.go b/internal/sql/ast/table_name.go index ea77308b73..a95a510c83 100644 --- a/internal/sql/ast/table_name.go +++ b/internal/sql/ast/table_name.go @@ -9,3 +9,16 @@ type TableName struct { func (n *TableName) Pos() int { return 0 } + +func (n *TableName) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.Schema != "" { + buf.WriteString(n.Schema) + buf.WriteString(".") + } + if n.Name != "" { + buf.WriteString(n.Name) + } +} diff --git a/internal/sql/ast/truncate_stmt.go b/internal/sql/ast/truncate_stmt.go index 6518fccdbc..f23a5bbcb3 100644 --- a/internal/sql/ast/truncate_stmt.go +++ b/internal/sql/ast/truncate_stmt.go @@ -9,3 +9,11 @@ type TruncateStmt struct { func (n *TruncateStmt) Pos() int { return 0 } + +func (n *TruncateStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("TRUNCATE ") + buf.astFormat(n.Relations) +} diff --git a/internal/sql/ast/type_cast.go b/internal/sql/ast/type_cast.go index 8390f5b621..0b549eb4b1 100644 --- a/internal/sql/ast/type_cast.go +++ b/internal/sql/ast/type_cast.go @@ -9,3 +9,12 @@ type TypeCast struct { func (n *TypeCast) Pos() int { return n.Location } + +func (n *TypeCast) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Arg) + buf.WriteString("::") + buf.astFormat(n.TypeName) +} diff --git a/internal/sql/ast/type_name.go b/internal/sql/ast/type_name.go index eb67dddbcc..e26404b3ba 100644 --- a/internal/sql/ast/type_name.go +++ b/internal/sql/ast/type_name.go @@ -19,3 +19,21 @@ type TypeName struct { func (n *TypeName) Pos() int { return n.Location } + +func (n *TypeName) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if items(n.Names) { + buf.join(n.Names, ".") + } else { + if n.Name == "int4" { + buf.WriteString("INTEGER") + } else { + buf.WriteString(n.Name) + } + } + if items(n.ArrayBounds) { + buf.WriteString("[]") + } +} diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index 745b91b617..efd496ad75 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -1,5 +1,7 @@ package ast +import "strings" + type UpdateStmt struct { Relations *List TargetList *List @@ -13,3 +15,98 @@ type UpdateStmt struct { func (n *UpdateStmt) Pos() int { return 0 } + +func (n *UpdateStmt) Format(buf *TrackedBuffer) { + if n == nil { + return + } + if n.WithClause != nil { + buf.astFormat(n.WithClause) + buf.WriteString(" ") + } + + buf.WriteString("UPDATE ") + if items(n.Relations) { + buf.astFormat(n.Relations) + } + + if items(n.TargetList) { + buf.WriteString(" SET ") + + multi := false + for _, item := range n.TargetList.Items { + switch nn := item.(type) { + case *ResTarget: + if _, ok := nn.Val.(*MultiAssignRef); ok { + multi = true + } + } + } + if multi { + names := []string{} + vals := &List{} + for _, item := range n.TargetList.Items { + res, ok := item.(*ResTarget) + if !ok { + continue + } + if res.Name != nil { + names = append(names, *res.Name) + } + multi, ok := res.Val.(*MultiAssignRef) + if !ok { + vals.Items = append(vals.Items, res.Val) + continue + } + row, ok := multi.Source.(*RowExpr) + if !ok { + vals.Items = append(vals.Items, res.Val) + continue + } + vals.Items = append(vals.Items, row.Args.Items[multi.Colno-1]) + } + + buf.WriteString("(") + buf.WriteString(strings.Join(names, ",")) + buf.WriteString(") = (") + buf.join(vals, ",") + buf.WriteString(")") + } else { + for i, item := range n.TargetList.Items { + if i > 0 { + buf.WriteString(", ") + } + switch nn := item.(type) { + case *ResTarget: + if nn.Name != nil { + buf.WriteString(*nn.Name) + } + buf.WriteString(" = ") + buf.astFormat(nn.Val) + default: + buf.astFormat(item) + } + } + } + } + + if items(n.FromClause) { + buf.WriteString(" FROM ") + buf.astFormat(n.FromClause) + } + + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause) + } + + if set(n.LimitCount) { + buf.WriteString(" LIMIT ") + buf.astFormat(n.LimitCount) + } + + if items(n.ReturningList) { + buf.WriteString(" RETURNING ") + buf.astFormat(n.ReturningList) + } +} diff --git a/internal/sql/ast/with_clause.go b/internal/sql/ast/with_clause.go index 6334930439..634326fa7e 100644 --- a/internal/sql/ast/with_clause.go +++ b/internal/sql/ast/with_clause.go @@ -9,3 +9,14 @@ type WithClause struct { func (n *WithClause) Pos() int { return n.Location } + +func (n *WithClause) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("WITH") + if n.Recursive { + buf.WriteString(" RECURSIVE") + } + buf.astFormat(n.Ctes) +}