Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(endtoend): Re-use databases when possible #3315

Merged
merged 3 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,14 @@ func TestReplay(t *testing.T) {
"managed-db": {
Mutate: func(t *testing.T, path string) func(*config.Config) {
return func(c *config.Config) {
c.Cloud.Project = "01HAQMMECEYQYKFJN8MP16QC41" // TODO: Read from environment
for i := range c.SQL {
files := []string{}
for _, s := range c.SQL[i].Schema {
files = append(files, filepath.Join(path, s))
}
switch c.SQL[i].Engine {
case config.EnginePostgreSQL:
uri := local.PostgreSQL(t, files)
uri := local.ReadOnlyPostgreSQL(t, files)
c.SQL[i].Database = &config.Database{
URI: uri,
}
Expand All @@ -138,18 +137,12 @@ func TestReplay(t *testing.T) {
// URI: uri,
// }
default:
c.SQL[i].Database = &config.Database{
Managed: true,
}
// pass
}
}
}
},
Enabled: func() bool {
// Return false if no auth token exists
if len(os.Getenv("SQLC_AUTH_TOKEN")) == 0 {
return false
}
if len(os.Getenv("POSTGRESQL_SERVER_URI")) == 0 {
return false
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CREATE TABLE foo(
CREATE TABLE foo (
bar_id text,
site_url text
);
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
CREATE TABLE foo(
CREATE TABLE foo (
bar text
);
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
CREATE SCHEMA foo;
CREATE TABLE foo.bar (id serial not null);

32 changes: 32 additions & 0 deletions internal/pgx/poolcache/poolcache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package poolcache

import (
"context"
"sync"

"github.com/jackc/pgx/v5/pgxpool"
)

var lock sync.RWMutex
var pools = map[string]*pgxpool.Pool{}

func New(ctx context.Context, uri string) (*pgxpool.Pool, error) {
lock.RLock()
existing, found := pools[uri]
lock.RUnlock()

if found {
return existing, nil
}

pool, err := pgxpool.New(ctx, uri)
if err != nil {
return nil, err
}

lock.Lock()
pools[uri] = pool
lock.Unlock()

return pool, nil
}
100 changes: 63 additions & 37 deletions internal/sqltest/local/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,31 @@ package local
import (
"context"
"fmt"
"hash/fnv"
"net/url"
"os"
"strings"
"sync"
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/sync/singleflight"

migrate "github.com/sqlc-dev/sqlc/internal/migrations"
"github.com/sqlc-dev/sqlc/internal/pgx/poolcache"
"github.com/sqlc-dev/sqlc/internal/sql/sqlpath"
)

var postgresPool *pgxpool.Pool
var postgresSync sync.Once
var flight singleflight.Group

func PostgreSQL(t *testing.T, migrations []string) string {
return postgreSQL(t, migrations, true)
}

func ReadOnlyPostgreSQL(t *testing.T, migrations []string) string {
return postgreSQL(t, migrations, false)
}

func postgreSQL(t *testing.T, migrations []string, rw bool) string {
ctx := context.Background()
t.Helper()

Expand All @@ -28,65 +36,83 @@ func PostgreSQL(t *testing.T, migrations []string) string {
t.Skip("POSTGRESQL_SERVER_URI is empty")
}

postgresSync.Do(func() {
pool, err := pgxpool.New(ctx, dburi)
if err != nil {
t.Fatal(err)
}
postgresPool = pool
})

if postgresPool == nil {
t.Fatalf("PostgreSQL pool creation failed")
postgresPool, err := poolcache.New(ctx, dburi)
if err != nil {
t.Fatalf("PostgreSQL pool creation failed: %s", err)
}

var seed []string
files, err := sqlpath.Glob(migrations)
if err != nil {
t.Fatal(err)
}

h := fnv.New64()
for _, f := range files {
blob, err := os.ReadFile(f)
if err != nil {
t.Fatal(err)
}
h.Write(blob)
seed = append(seed, migrate.RemoveRollbackStatements(string(blob)))
}

uri, err := url.Parse(dburi)
if err != nil {
t.Fatal(err)
var name string
if rw {
name = fmt.Sprintf("sqlc_test_%s", id())
} else {
name = fmt.Sprintf("sqlc_test_%x", h.Sum(nil))
}

name := fmt.Sprintf("sqlc_test_%s", id())

if _, err := postgresPool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil {
uri, err := url.Parse(dburi)
if err != nil {
t.Fatal(err)
}

uri.Path = name
dropQuery := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name)

t.Cleanup(func() {
if _, err := postgresPool.Exec(ctx, dropQuery); err != nil {
t.Fatal(err)
key := uri.String()

_, err, _ = flight.Do(key, func() (interface{}, error) {
row := postgresPool.QueryRow(ctx,
fmt.Sprintf(`SELECT datname FROM pg_database WHERE datname = '%s'`, name))

var datname string
if err := row.Scan(&datname); err == nil {
t.Logf("database exists: %s", name)
return nil, nil
}
})

conn, err := pgx.Connect(ctx, uri.String())
if err != nil {
t.Fatalf("connect %s: %s", name, err)
}
defer conn.Close(ctx)
t.Logf("creating database: %s", name)
if _, err := postgresPool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil {
return nil, err
}

for _, q := range seed {
if len(strings.TrimSpace(q)) == 0 {
continue
conn, err := pgx.Connect(ctx, uri.String())
if err != nil {
return nil, fmt.Errorf("connect %s: %s", name, err)
}
if _, err := conn.Exec(ctx, q); err != nil {
t.Fatalf("%s: %s", q, err)
defer conn.Close(ctx)

for _, q := range seed {
if len(strings.TrimSpace(q)) == 0 {
continue
}
if _, err := conn.Exec(ctx, q); err != nil {
return nil, fmt.Errorf("%s: %s", q, err)
}
}
return nil, nil
})
if rw || err != nil {
t.Cleanup(func() {
if _, err := postgresPool.Exec(ctx, dropQuery); err != nil {
t.Fatalf("failed cleaning up: %s", err)
}
})
}

return uri.String()
if err != nil {
t.Fatalf("create db: %s", err)
}
return key
}
55 changes: 55 additions & 0 deletions scripts/cleanup-test-dbs/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

func main() {
if err := run(); err != nil {
log.Fatal(err)
}
}

const query = `
SELECT datname
FROM pg_database
WHERE datname LIKE 'sqlc_test_%'
`

func run() error {
ctx := context.Background()
dburi := os.Getenv("POSTGRESQL_SERVER_URI")
if dburi == "" {
return fmt.Errorf("POSTGRESQL_SERVER_URI is empty")
}
pool, err := pgxpool.New(ctx, dburi)
if err != nil {
return err
}

rows, err := pool.Query(ctx, query)
if err != nil {
return err
}

names, err := pgx.CollectRows(rows, pgx.RowTo[string])
if err != nil {
return err
}

for _, name := range names {
drop := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name)
if _, err := pool.Exec(ctx, drop); err != nil {
return err
}
log.Println("dropping database", name)
}

return nil
}
Loading