Skip to content

Commit

Permalink
fix: check column references in ORDER BY (#1411) (#1915)
Browse files Browse the repository at this point in the history
* fix: check column references in ORDER BY (#1411)

* test: move test cases to endtoend tests

* feat: add validate_order_by config option #1411

* feat: expand error message #1411

Tell the uses how to switch off validation here.

* feat: add expanded error message to test #1411

* compiler: Add functions to the compiler struct

Don't pass configuration around as a parameter

---------

Co-authored-by: Kyle Conroy <kyle@conroy.org>
  • Loading branch information
akutschera and kyleconroy committed Jun 8, 2023
1 parent 9b9a2b6 commit c4e4b68
Show file tree
Hide file tree
Showing 16 changed files with 119 additions and 21 deletions.
2 changes: 1 addition & 1 deletion internal/compiler/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *Compiler) quoteIdent(ident string) string {
}

func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) {
tables, err := sourceTables(qc, node)
tables, err := c.sourceTables(qc, node)
if err != nil {
return nil, err
}
Expand Down
67 changes: 52 additions & 15 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ import (

// OutputColumns determines which columns a statement will output
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
qc, err := c.buildQueryCatalog(c.catalog, stmt, nil)
if err != nil {
return nil, err
}
cols, err := outputColumns(qc, stmt)
cols, err := c.outputColumns(qc, stmt)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -51,8 +51,8 @@ func hasStarRef(cf *ast.ColumnRef) bool {
//
// Return an error if column references are ambiguous
// Return an error if column references don't exist
func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
tables, err := sourceTables(qc, node)
func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
tables, err := c.sourceTables(qc, node)
if err != nil {
return nil, err
}
Expand All @@ -68,21 +68,50 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {

if n.GroupClause != nil {
for _, item := range n.GroupClause.Items {
ref, ok := item.(*ast.ColumnRef)
if !ok {
continue
}

if err := findColumnForRef(ref, tables, n); err != nil {
if err := findColumnForNode(item, tables, n); err != nil {
return nil, err
}
}
}
validateOrderBy := true
if c.conf.StrictOrderBy != nil {
validateOrderBy = *c.conf.StrictOrderBy
}
if validateOrderBy {
if n.SortClause != nil {
for _, item := range n.SortClause.Items {
sb, ok := item.(*ast.SortBy)
if !ok {
continue
}
if err := findColumnForNode(sb.Node, tables, n); err != nil {
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
}
}
}
if n.WindowClause != nil {
for _, item := range n.WindowClause.Items {
sb, ok := item.(*ast.List)
if !ok {
continue
}
for _, single := range sb.Items {
caseExpr, ok := single.(*ast.CaseExpr)
if !ok {
continue
}
if err := findColumnForNode(caseExpr.Xpr, tables, n); err != nil {
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
}
}
}
}
}

// For UNION queries, targets is empty and we need to look for the
// columns in Largs.
if len(targets.Items) == 0 && n.Larg != nil {
return outputColumns(qc, n.Larg)
return c.outputColumns(qc, n.Larg)
}
case *ast.CallStmt:
targets = &ast.List{}
Expand Down Expand Up @@ -303,7 +332,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
case ast.EXISTS_SUBLINK:
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
case ast.EXPR_SUBLINK:
subcols, err := outputColumns(qc, n.Subselect)
subcols, err := c.outputColumns(qc, n.Subselect)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -339,7 +368,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
cols = append(cols, col)

case *ast.SelectStmt:
subcols, err := outputColumns(qc, n)
subcols, err := c.outputColumns(qc, n)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -428,7 +457,7 @@ func isTableRequired(n ast.Node, col *Column, prior int) int {
// Return an error if column references don't exist
// Return an error if a table is referenced twice
// Return an error if an unknown column is referenced
func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
var list *ast.List
switch n := node.(type) {
case *ast.DeleteStmt:
Expand Down Expand Up @@ -483,7 +512,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
tables = append(tables, table)

case *ast.RangeSubselect:
cols, err := outputColumns(qc, n.Subquery)
cols, err := c.outputColumns(qc, n.Subquery)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -581,6 +610,14 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
return cols, nil
}

func findColumnForNode(item ast.Node, tables []*Table, n *ast.SelectStmt) error {
ref, ok := item.(*ast.ColumnRef)
if !ok {
return nil
}
return findColumnForRef(ref, tables, n)
}

func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.SelectStmt) error {
parts := stringSlice(ref.Fields)
var alias, name string
Expand Down
5 changes: 2 additions & 3 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
} else {
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
}

raw, embeds := rewrite.Embeds(raw)
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
qc, err := c.buildQueryCatalog(c.catalog, raw.Stmt, embeds)
if err != nil {
return nil, err
}
Expand All @@ -97,7 +96,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
if err != nil {
return nil, err
}
cols, err := outputColumns(qc, raw.Stmt)
cols, err := c.outputColumns(qc, raw.Stmt)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/compiler/query_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type QueryCatalog struct {
embeds rewrite.EmbedSet
}

func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
func (comp *Compiler) buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
var with *ast.WithClause
switch n := node.(type) {
case *ast.DeleteStmt:
Expand All @@ -32,7 +32,7 @@ func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSe
if with != nil {
for _, item := range with.Ctes.Items {
if cte, ok := item.(*ast.CommonTableExpr); ok {
cols, err := outputColumns(qc, cte.Ctequery)
cols, err := comp.outputColumns(qc, cte.Ctequery)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ type SQL struct {
Schema Paths `json:"schema" yaml:"schema"`
Queries Paths `json:"queries" yaml:"queries"`
StrictFunctionChecks bool `json:"strict_function_checks" yaml:"strict_function_checks"`
StrictOrderBy *bool `json:"strict_order_by" yaml:"strict_order_by"`
Gen SQLGen `json:"gen" yaml:"gen"`
Codegen []Codegen `json:"codegen" yaml:"codegen"`
}
Expand Down
6 changes: 6 additions & 0 deletions internal/config/v_one.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type v1PackageSettings struct {
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
StrictFunctionChecks bool `json:"strict_function_checks" yaml:"strict_function_checks"`
StrictOrderBy *bool `json:"strict_order_by" yaml:"strict_order_by"`
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
}

Expand Down Expand Up @@ -130,6 +131,10 @@ func (c *V1GenerateSettings) Translate() Config {
}

for _, pkg := range c.Packages {
if pkg.StrictOrderBy == nil {
defaultValue := true
pkg.StrictOrderBy = &defaultValue
}
conf.SQL = append(conf.SQL, SQL{
Engine: pkg.Engine,
Schema: pkg.Schema,
Expand Down Expand Up @@ -164,6 +169,7 @@ func (c *V1GenerateSettings) Translate() Config {
},
},
StrictFunctionChecks: pkg.StrictFunctionChecks,
StrictOrderBy: pkg.StrictOrderBy,
})
}

Expand Down
4 changes: 4 additions & 0 deletions internal/config/v_two.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ func v2ParseConfig(rd io.Reader) (Config, error) {
return conf, ErrPluginNotFound
}
}
if conf.SQL[j].StrictOrderBy == nil {
defaultValidate := true
conf.SQL[j].StrictOrderBy = &defaultValidate
}
}
return conf, nil
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Example queries for sqlc
CREATE TABLE authors (
id INT
);

-- name: ListAuthors :many
SELECT id FROM authors
ORDER BY adfadsf;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
version: 1
packages:
- path: "go"
name: "querytest"
engine: "postgresql"
schema: "query.sql"
queries: "query.sql"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Example queries for sqlc
CREATE TABLE authors (
id INT
);

-- name: ListAuthors :many
SELECT id FROM authors
ORDER BY adfadsf;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
version: 1
packages:
- path: "go"
name: "querytest"
engine: "postgresql"
schema: "query.sql"
queries: "query.sql"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Example queries for sqlc
CREATE TABLE authors (
id INT
);

-- name: ListAuthors :many
SELECT id FROM authors
ORDER BY adfadsf;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
version: 1
packages:
- path: "go"
name: "querytest"
engine: "postgresql"
schema: "query.sql"
queries: "query.sql"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false

0 comments on commit c4e4b68

Please sign in to comment.