Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mysql): :copyfrom support via LOAD DATA INFILE #2545

Merged
merged 4 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions internal/codegen/golang/driver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package golang

type SQLDriver int
type SQLDriver string

const (
SQLPackagePGXV4 string = "pgx/v4"
Expand All @@ -9,13 +9,12 @@ const (
)

const (
SQLDriverPGXV4 SQLDriver = iota
SQLDriverPGXV5
SQLDriverLibPQ
SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4"
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
SQLDriverLibPQ = "github.com/lib/pq"
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
)

const SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"

func parseDriver(sqlPackage string) SQLDriver {
switch sqlPackage {
case SQLPackagePGXV4:
Expand All @@ -31,6 +30,10 @@ func (d SQLDriver) IsPGX() bool {
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
}

func (d SQLDriver) IsGoSQLDriverMySQL() bool {
return d == SQLDriverGoSQLDriverMySQL
}

func (d SQLDriver) Package() string {
switch d {
case SQLDriverPGXV4:
Expand Down
22 changes: 20 additions & 2 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,15 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
SqlcVersion: req.SqlcVersion,
}

if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() {
return nil, errors.New(":copyfrom is only supported by pgx")
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && golang.SqlDriver != SQLDriverGoSQLDriverMySQL {
return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql")
}

if tctx.UsesCopyFrom && golang.SqlDriver == SQLDriverGoSQLDriverMySQL {
if err := checkNoTimesForMySQLCopyFrom(queries); err != nil {
return nil, err
}
tctx.SQLDriver = SQLDriverGoSQLDriverMySQL
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a gross hack, but I don't want to change how parseDriver works right now.

}

if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() {
Expand Down Expand Up @@ -294,6 +301,17 @@ func usesBatch(queries []Query) bool {
return false
}

func checkNoTimesForMySQLCopyFrom(queries []Query) error {
for _, q := range queries {
for _, f := range q.Arg.Fields() {
if f.Type == "time.Time" {
return fmt.Errorf("values with a timezone are not yet supported")
}
}
}
return nil
}

func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
keepTypes := make(map[string]struct{})

Expand Down
7 changes: 7 additions & 0 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ func (i *importer) copyfromImports() fileImports {
})

std["context"] = struct{}{}
if i.Settings.Go.SqlDriver == SQLDriverGoSQLDriverMySQL {
std["io"] = struct{}{}
std["fmt"] = struct{}{}
std["sync/atomic"] = struct{}{}
pkg[ImportSpec{Path: "github.com/go-sql-driver/mysql"}] = struct{}{}
pkg[ImportSpec{Path: "github.com/hexon/mysqltsv"}] = struct{}{}
}

return sortedImports(std, pkg)
}
Expand Down
38 changes: 36 additions & 2 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,18 @@ func (v QueryValue) Params() string {
return "\n" + strings.Join(out, ",\n")
}

func (v QueryValue) ColumnNames() string {
func (v QueryValue) ColumnNames() []string {
if v.Struct == nil {
return []string{v.DBName}
}
names := make([]string, len(v.Struct.Fields))
for i, f := range v.Struct.Fields {
names[i] = f.DBName
}
return names
}

func (v QueryValue) ColumnNamesAsGoSlice() string {
if v.Struct == nil {
return fmt.Sprintf("[]string{%q}", v.DBName)
}
Expand Down Expand Up @@ -187,6 +198,19 @@ func (v QueryValue) Scan() string {
return "\n" + strings.Join(out, ",\n")
}

func (v QueryValue) Fields() []Field {
if v.Struct != nil {
return v.Struct.Fields
}
return []Field{
{
Name: v.Name,
DBName: v.DBName,
Type: v.Typ,
},
}
}

func (v QueryValue) VariableForField(f Field) string {
if !v.IsStruct() {
return v.Name
Expand Down Expand Up @@ -218,7 +242,7 @@ func (q Query) hasRetType() bool {
return scanned && !q.Ret.isEmpty()
}

func (q Query) TableIdentifier() string {
func (q Query) TableIdentifierAsGoSlice() string {
escapedNames := make([]string, 0, 3)
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {
if p != "" {
Expand All @@ -227,3 +251,13 @@ func (q Query) TableIdentifier() string {
}
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
}

func (q Query) TableIdentifierForMySQL() string {
escapedNames := make([]string, 0, 3)
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {
if p != "" {
escapedNames = append(escapedNames, fmt.Sprintf("`%s`", p))
}
}
return strings.Join(escapedNames, ".")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{{define "copyfromCodeGoSqlDriver"}}
{{range .GoQueries}}
{{if eq .Cmd ":copyfrom" }}
var readerHandlerSequenceFor{{.MethodName}} uint32 = 1

func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) {
e := mysqltsv.NewEncoder(w, {{ len .Arg.Fields }}, nil)
for _, row := range {{.Arg.Name}} {
{{- with $arg := .Arg }}
{{- range $arg.Fields}}
{{- if eq .Type "string"}}
e.AppendString({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
{{- else if eq .Type "[]byte"}}
e.AppendBytes({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
{{- else}}
e.AppendValue({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
{{- end}}
{{- end}}
{{- end}}
}
w.CloseWithError(e.Close())
}

{{range .Comments}}//{{.}}
{{end -}}
// {{.MethodName}} uses MySQL's LOAD DATA LOCAL INFILE and is not atomic.
//
// Errors and duplicate keys are treated as warnings and insertion will
// continue, even without an error for some cases. Use this in a transaction
// and use SHOW WARNINGS to check for any problems and roll back if you want to.
//
// Check the documentation for more information:
// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling
func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArgument}}, db DBTX{{end}}, {{.Arg.SlicePair}}) (int64, error) {
pr, pw := io.Pipe()
defer pr.Close()
rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1))
mysql.RegisterReaderHandler(rh, func() io.Reader { return pr })
defer mysql.DeregisterReaderHandler(rh)
go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}})
// The string interpolation is necessary because LOAD DATA INFILE requires
// the file name to be given as a literal string.
result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping))
if err != nil {
return 0, err
}
return result.RowsAffected()
}

{{end}}
{{end}}
{{end}}
4 changes: 2 additions & 2 deletions internal/codegen/golang/templates/pgx/copyfromCopy.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ func (r iteratorFor{{.MethodName}}) Err() error {
{{end -}}
{{- if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) {
return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {
return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
{{- end}}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/codegen/golang/templates/template.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ import (
{{define "copyfromCode"}}
{{if .SQLDriver.IsPGX }}
{{- template "copyfromCodePgx" .}}
{{else if .SQLDriver.IsGoSQLDriverMySQL }}
{{- template "copyfromCodeGoSqlDriver" .}}
{{end}}
{{end}}

Expand Down
88 changes: 88 additions & 0 deletions internal/endtoend/testdata/copyfrom/mysql/go/copyfrom.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions internal/endtoend/testdata/copyfrom/mysql/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions internal/endtoend/testdata/copyfrom/mysql/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions internal/endtoend/testdata/copyfrom/mysql/go/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions internal/endtoend/testdata/copyfrom/mysql/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CREATE TABLE foo (a text, b integer, c DATETIME, d DATE);

-- name: InsertValues :copyfrom
INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?);

-- name: InsertSingleValue :copyfrom
INSERT INTO foo (a) VALUES (?);
Loading
Loading