Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] add raw explain #3581

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
rootCmd.AddCommand(verifyCmd)
rootCmd.AddCommand(pushCmd)
rootCmd.AddCommand(NewCmdVet())
rootCmd.AddCommand(NewCmdExplain())

rootCmd.SetArgs(args)
rootCmd.SetIn(stdin)
Expand Down
280 changes: 280 additions & 0 deletions internal/cmd/explain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
package cmd

import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime/trace"

"github.com/jackc/pgx/v5"
"github.com/spf13/cobra"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/dbmanager"
"github.com/sqlc-dev/sqlc/internal/debug"
"github.com/sqlc-dev/sqlc/internal/opts"
"github.com/sqlc-dev/sqlc/internal/plugin"
"github.com/sqlc-dev/sqlc/internal/shfmt"
"gopkg.in/yaml.v3"
)

func NewCmdExplain() *cobra.Command {
return &cobra.Command{
Use: "explain",
Short: "Explain queries",
RunE: func(cmd *cobra.Command, args []string) error {
defer trace.StartRegion(cmd.Context(), "vet").End()
stderr := cmd.ErrOrStderr()
opts := &Options{
Env: ParseEnv(cmd),
Stderr: stderr,
}
dir, name := getConfigPath(stderr, cmd.Flag("file"))
if err := Explain(cmd.Context(), dir, name, opts); err != nil {
if !errors.Is(err, ErrFailedChecks) {
fmt.Fprintf(stderr, "%s\n", err)
}
os.Exit(1)
}
return nil
},
}
}

func Explain(ctx context.Context, dir, filename string, opts *Options) error {
e := opts.Env
stderr := opts.Stderr
configPath, conf, err := readConfig(stderr, dir, filename)
if err != nil {
return err
}

base := filepath.Base(configPath)
if err := config.Validate(conf); err != nil {
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
return err
}

if err := e.Validate(conf); err != nil {
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
return err
}

c := rawExplainer{
Conf: conf,
Dir: dir,
Stderr: stderr,
OnlyManagedDB: e.Debug.OnlyManagedDatabases,
Replacer: shfmt.NewReplacer(nil),
}
var errs error
for _, sql := range conf.SQL {
if err := c.explainSQL(ctx, sql); err != nil {
fmt.Fprintf(stderr, "%s\n", err)
errs = errors.Join(errs, err)
}
}

return errs
}

type rawExplainer struct {
Conf *config.Config
Dir string
Stderr io.Writer
OnlyManagedDB bool
Client dbmanager.Client
Replacer *shfmt.Replacer
}

func (c *rawExplainer) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, func() error, error) {
return (&checker{
Conf: c.Conf,
Dir: c.Dir,
Stderr: c.Stderr,
OnlyManagedDB: c.OnlyManagedDB,
Client: c.Client,
Replacer: c.Replacer,
}).fetchDatabaseUri(ctx, s)
}

func (c *rawExplainer) explainSQL(ctx context.Context, s config.SQL) error {
// TODO: Create a separate function for this logic so we can
combo := config.Combine(*c.Conf, s)

// TODO: This feels like a hack that will bite us later
joined := make([]string, 0, len(s.Schema))
for _, s := range s.Schema {
joined = append(joined, filepath.Join(c.Dir, s))
}
s.Schema = joined

joined = make([]string, 0, len(s.Queries))
for _, q := range s.Queries {
joined = append(joined, filepath.Join(c.Dir, q))
}
s.Queries = joined

var name string
parseOpts := opts.Parser{
Debug: debug.Debug,
}

result, failed := parse(ctx, name, c.Dir, s, combo, parseOpts, c.Stderr)
if failed {
return ErrFailedChecks
}

var expl rawDBExplainer
if s.Database != nil { // TODO only set up a database connection if a rule evaluation requires it
if s.Database.URI != "" && c.OnlyManagedDB {
return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
}
dburl, cleanup, err := c.fetchDatabaseUri(ctx, s)
if err != nil {
return err
}
defer func() {
if err := cleanup(); err != nil {
fmt.Fprintf(c.Stderr, "error cleaning up: %s\n", err)
}
}()

switch s.Engine {
case config.EnginePostgreSQL:
conn, err := pgx.Connect(ctx, dburl)
if err != nil {
return fmt.Errorf("database: connection error: %s", err)
}
if err := conn.Ping(ctx); err != nil {
return fmt.Errorf("database: connection error: %s", err)
}
defer conn.Close(ctx)

expl = &rawPostgresExplainer{c: conn}
case config.EngineMySQL:
db, err := sql.Open("mysql", dburl)
if err != nil {
return fmt.Errorf("database: connection error: %s", err)
}
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("database: connection error: %s", err)
}
defer db.Close()
expl = &rawMySQLExplainer{db}
case config.EngineSQLite:
db, err := sql.Open("sqlite", dburl)
if err != nil {
return fmt.Errorf("database: connection error: %s", err)
}
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("database: connection error: %s", err)
}
defer db.Close()
// SQLite really doesn't want us to depend on the output of EXPLAIN
// QUERY PLAN: https://www.sqlite.org/eqp.html
expl = nil
default:
return fmt.Errorf("unsupported database uri: %s", s.Engine)
}

req := codeGenRequest(result, combo)
for _, query := range req.Queries {
if expl == nil {
fmt.Fprintf(c.Stderr, "%s: %s: %s: error explaining query: database connection required\n", query.Filename, query.Name, name)
continue
}
results, err := expl.Explain(ctx, query.Text, query.Params...)
if err != nil {
fmt.Fprintf(c.Stderr, "%s: %s: %s: error explaining query: %s\n", query.Filename, query.Name, name, err)
}

err = yaml.NewEncoder(os.Stdout).Encode([]struct {
Name string
Query string `yaml:",flow"` // https://github.com/go-yaml/yaml/issues/789
Arguments []any
Output interface{}
}{{
Name: query.Name,
Query: query.Text,
Arguments: expl.DefaultValues(query.Params),
Output: string(results),
}})
if err != nil {
fmt.Fprintf(c.Stderr, "%s: %s: %s: fail marshal results: %s\n", query.Filename, query.Name, name, err)
}

}
}
return nil
}

type rawDBExplainer interface {
DefaultValues([]*plugin.Parameter) []any
Explain(context.Context, string, ...*plugin.Parameter) ([]byte, error)
}

type rawPostgresExplainer struct {
c *pgx.Conn
}

func (p *rawPostgresExplainer) DefaultValues(args []*plugin.Parameter) []any {
eArgs := make([]any, len(args))
for i, a := range args {
eArgs[i] = pgDefaultValue(a.Column)
}
return eArgs
}

func (p *rawPostgresExplainer) Explain(ctx context.Context, query string, args ...*plugin.Parameter) ([]byte, error) {
eQuery := "EXPLAIN " + query
eArgs := p.DefaultValues(args)
var results []byte

rows, err := p.c.Query(ctx, eQuery, eArgs...)
if err != nil {
return nil, err
}
for rows.Next() {
var result []byte
err := rows.Scan(&result)
if err != nil {
return nil, err
}
results = append(results, append(result, '\n')...)
}
return results, nil
}

type rawMySQLExplainer struct {
*sql.DB
}

func (me *rawMySQLExplainer) DefaultValues(args []*plugin.Parameter) []any {
eArgs := make([]any, len(args))
for i, a := range args {
eArgs[i] = mysqlDefaultValue(a.Column)
}
return eArgs
}
func (me *rawMySQLExplainer) Explain(ctx context.Context, query string, args ...*plugin.Parameter) ([]byte, error) {
eQuery := "EXPLAIN " + query
eArgs := me.DefaultValues(args)
var results []byte
rows, err := me.QueryContext(ctx, eQuery, eArgs...)
if err != nil {
return nil, err
}
for rows.Next() {
var result []byte
err := rows.Scan(&result)
if err != nil {
return nil, err
}
results = append(results, append(result, '\n')...)
}
return results, nil
}