Skip to content

Commit

Permalink
Allow sqlc.arg('argname') form for named params (#351)
Browse files Browse the repository at this point in the history
* Allow sqlc.arg('argname') form for named params

This pacifies static SQL analizers in IDEs.

* Allow sqlc to be imported as a package for "go run github.com/kyleconroy/sqlc"

Co-authored-by: Kyle Conroy <kyle@conroy.org>
  • Loading branch information
Cyberax and kyleconroy committed Feb 24, 2020
1 parent b6636de commit 4b756bb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
22 changes: 16 additions & 6 deletions internal/dinosql/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@ import (
)

// Given an AST node, return the string representation of names
func flatten(root nodes.Node) string {
func flatten(root nodes.Node) (string, bool) {
sw := &stringWalker{}
ast.Walk(sw, root)
return sw.String
return sw.String, sw.IsConst
}

type stringWalker struct {
String string
IsConst bool
}

func (s *stringWalker) Visit(node nodes.Node) ast.Visitor {
if _, ok := node.(nodes.A_Const); ok {
s.IsConst = true
}
if n, ok := node.(nodes.String); ok {
s.String += n.Str
}
Expand Down Expand Up @@ -61,7 +65,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [

case isNamedParamFunc(node):
fun := node.(nodes.FuncCall)
param := flatten(fun.Args)
param, isConst := flatten(fun.Args)
if num, ok := args[param]; ok {
cr.Replace(nodes.ParamRef{
Number: num,
Expand All @@ -76,17 +80,23 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
})
}
// TODO: This code assumes that sqlc.arg(name) is on a single line
var old string
if isConst {
old = fmt.Sprintf("sqlc.arg('%s')", param)
} else {
old = fmt.Sprintf("sqlc.arg(%s)", param)
}
edits = append(edits, edit{
Location: fun.Location - raw.StmtLocation,
Old: fmt.Sprintf("sqlc.arg(%s)", param),
Old: old,
New: fmt.Sprintf("$%d", args[param]),
})
return false

case isNamedParamSignCast(node):
expr := node.(nodes.A_Expr)
cast := expr.Rexpr.(nodes.TypeCast)
param := flatten(cast.Arg)
param, _ := flatten(cast.Arg)
if num, ok := args[param]; ok {
cast.Arg = nodes.ParamRef{
Number: num,
Expand All @@ -112,7 +122,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [

case isNamedParamSign(node):
expr := node.(nodes.A_Expr)
param := flatten(expr.Rexpr)
param, _ := flatten(expr.Rexpr)
if num, ok := args[param]; ok {
cr.Replace(nodes.ParamRef{
Number: num,
Expand Down
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/named_param/query.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
CREATE TABLE foo (name text not null, bio text not null);

-- name: FuncParams :many
SELECT name FROM foo WHERE name = sqlc.arg(slug) AND sqlc.arg(filter)::bool;
SELECT name FROM foo WHERE name = sqlc.arg('slug') AND sqlc.arg(filter)::bool;

-- name: AtParams :many
SELECT name FROM foo WHERE name = @slug AND @filter::bool;
Expand Down
5 changes: 5 additions & 0 deletions placeholder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sqlc
// This is a dummy file that allows SQLC to be "installed" as a module and locked using
// go.mod and then run using "go run github.com/kyleconroy/sqlc"

type Placeholder struct{}

0 comments on commit 4b756bb

Please sign in to comment.