diff --git a/fix/anko.go b/fix/anko.go index d379bad3..deeb1426 100644 --- a/fix/anko.go +++ b/fix/anko.go @@ -6,16 +6,12 @@ import ( ) func Anko(content string) (string, error) { - if !strings.Contains(content, "create_table") { - return content, nil - } - if !strings.Contains(content, "func(t) {") { - return content, nil - } - bb := &bytes.Buffer{} - for _, line := range strings.Split(content, "\n") { + lines := strings.Split(content, "\n") + + // fix create_table + for i, line := range lines { tl := strings.TrimSpace(line) if strings.HasPrefix(tl, "create_table") { line = strings.Replace(line, ", func(t) {", ") {", -1) @@ -23,11 +19,25 @@ func Anko(content string) (string, error) { if strings.HasPrefix(tl, "})") { line = "}" } - if tl == "" { - continue + lines[i] = line + } + + // fix (` && `) + for i, line := range lines { + lines[i] = strings.Replace(line, "(`", `("`, -1) + lines[i] = strings.Replace(lines[i], "`)", `")`, -1) + } + + // fix raw + for i, line := range lines { + tl := strings.TrimSpace(line) + if strings.HasPrefix(tl, "raw(") { + line = strings.Replace(line, "raw(", "sql(", -1) } - bb.WriteString(line + "\n") + lines[i] = line } + bb.WriteString(strings.Join(lines, "\n")) + return bb.String(), nil } diff --git a/fix/anko_test.go b/fix/anko_test.go new file mode 100644 index 00000000..af7ef067 --- /dev/null +++ b/fix/anko_test.go @@ -0,0 +1,53 @@ +package fix + +import ( + "io/ioutil" + "strings" + "testing" + + "github.com/gobuffalo/packr" + "github.com/stretchr/testify/require" +) + +func Test_Anko(t *testing.T) { + r := require.New(t) + box := packr.NewBox("./fixtures") + err := box.Walk(func(path string, info packr.File) error { + if strings.HasPrefix(path, "pass") { + t.Run(path, testPass(path, info)) + return nil + } + t.Run(path, testFail(path, info)) + return nil + }) + r.NoError(err) +} + +func testPass(path string, info packr.File) func(*testing.T) { + return func(t *testing.T) { + r := require.New(t) + b, err := ioutil.ReadAll(info) + r.NoError(err) + + body := string(b) + fixed, err := Anko(body) + r.NoError(err) + if strings.Contains(path, "anko") { + r.NotEqual(body, fixed) + } else { + r.Equal(body, fixed) + } + } +} + +func testFail(path string, info packr.File) func(*testing.T) { + return func(t *testing.T) { + r := require.New(t) + b, err := ioutil.ReadAll(info) + r.NoError(err) + + body := string(b) + _, err = Anko(body) + r.Error(err) + } +} diff --git a/fix/fixtures/pass/0001_with_raw_backticks.anko.fizz b/fix/fixtures/pass/0001_with_raw_backticks.anko.fizz new file mode 100644 index 00000000..42e54134 --- /dev/null +++ b/fix/fixtures/pass/0001_with_raw_backticks.anko.fizz @@ -0,0 +1,14 @@ +create_table("users", func(t) { + t.Column("email", "string", {}) + t.Column("twitter_handle", "string", {"size": 50}) + t.Column("age", "integer", {"default": 0}) + t.Column("admin", "bool", {"default": false}) + t.Column("company_id", "uuid", {"default_raw": "uuid_generate_v1()"}) + t.Column("bio", "text", {"null": true}) + t.Column("joined_at", "timestamp", {}) +}) + +raw(` + INSERT INTO users (email, twitter_handle, joined_at, created_at, updated_at) + VALUES ('foo@example.com', 'Soman1994', now(), now(), now()); +`) diff --git a/fix/fixtures/pass/0002_happy.plush.fizz b/fix/fixtures/pass/0002_happy.plush.fizz new file mode 100644 index 00000000..ea535869 --- /dev/null +++ b/fix/fixtures/pass/0002_happy.plush.fizz @@ -0,0 +1,14 @@ +create_table("users") { + t.Column("email", "string", {}) + t.Column("twitter_handle", "string", {"size": 50}) + t.Column("age", "integer", {"default": 0}) + t.Column("admin", "bool", {"default": false}) + t.Column("company_id", "uuid", {"default_raw": "uuid_generate_v1()"}) + t.Column("bio", "text", {"null": true}) + t.Column("joined_at", "timestamp", {}) +} + +sql(" + INSERT INTO users (email, twitter_handle, joined_at, created_at, updated_at) + VALUES ('foo@example.com', 'Soman1994', now(), now(), now()); +")