Skip to content

Commit

Permalink
postgres/rdspostgres: fix code and tests so tests pass; add URL tests (
Browse files Browse the repository at this point in the history
  • Loading branch information
vangent committed May 29, 2019
1 parent 7c95872 commit 7ab9ebe
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 11 deletions.
1 change: 0 additions & 1 deletion internal/testing/prereleasechecks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ case "$op" in
terraform init && terraform apply -var region="us-west-1" -auto-approve
;;
run)
# TODO: The TestOpenBadValues test fails.
go test -mod=readonly
;;
cleanup)
Expand Down
10 changes: 5 additions & 5 deletions postgres/rdspostgres/rdspostgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"net"
"net/url"
"strings"
"time"

"contrib.go.opencensus.io/integrations/ocsql"
Expand Down Expand Up @@ -72,15 +73,14 @@ func init() {
// OpenPostgresURL opens a new RDS database connection wrapped with OpenCensus instrumentation.
func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
cf := new(rds.CertFetcher)
vals := u.Query()
u.RawQuery = vals.Encode()

database := strings.TrimPrefix(u.EscapedPath(), "/")
password, _ := u.User.Password()
params := Params{
Endpoint: u.Host,
User: u.User.Username(),
Password: password,
Database: u.RawPath,
Database: database,
Values: u.Query(),
TraceOpts: uo.TraceOpts,
}
Expand All @@ -96,9 +96,9 @@ func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB,
func Open(ctx context.Context, provider rds.CertPoolProvider, params *Params) (*sql.DB, func(), error) {
vals := make(url.Values)
for k, v := range params.Values {
// Only permit parameters that do not conflict with other behavior.
// Forbid SSL-related parameters.
if k == "sslmode" || k == "sslcert" || k == "sslkey" || k == "sslrootcert" {
return nil, nil, fmt.Errorf("rdspostgres: open: extra parameter %s not allowed; use Params fields instead", k)
return nil, nil, fmt.Errorf("rdspostgres: open: parameter %q not allowed; sslmode must be disabled because the underlying dialer is already providing TLS", k)
}
vals[k] = v
}
Expand Down
72 changes: 67 additions & 5 deletions postgres/rdspostgres/rdspostgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ package rdspostgres

import (
"context"
"fmt"
"net/url"
"strings"
"testing"

"gocloud.dev/aws/rds"
"gocloud.dev/internal/testing/terraform"
"gocloud.dev/postgres"
)

func TestOpen(t *testing.T) {
Expand Down Expand Up @@ -62,6 +64,68 @@ func TestOpen(t *testing.T) {
}
}

func TestOpenWithURL(t *testing.T) {
// This test will be skipped unless the project is set up with Terraform.
// Before running go test, run in this directory:
//
// terraform init
// terraform apply

tfOut, err := terraform.ReadOutput(".")
if err != nil {
t.Skipf("Could not obtain harness info: %v", err)
}
endpoint, _ := tfOut["endpoint"].Value.(string)
username, _ := tfOut["username"].Value.(string)
password, _ := tfOut["password"].Value.(string)
databaseName, _ := tfOut["database"].Value.(string)
if endpoint == "" || username == "" || databaseName == "" {
t.Fatalf("Missing one or more required Terraform outputs; got endpoint=%q username=%q database=%q", endpoint, username, databaseName)
}

ctx := context.Background()

tests := []struct {
urlstr string
wantErr bool
wantPingErr bool
}{
// OK.
{fmt.Sprintf("rdspostgres://%s:%s@%s/%s", username, password, endpoint, databaseName), false, false},
// Invalid URL parameters: db creation fails.
{fmt.Sprintf("rdspostgres://%s:%s@%s/%s?sslcert=foo", username, password, endpoint, databaseName), true, false},
{fmt.Sprintf("rdspostgres://%s:%s@%s/%s?sslkey=foo", username, password, endpoint, databaseName), true, false},
{fmt.Sprintf("rdspostgres://%s:%s@%s/%s?sslrootcert=foo", username, password, endpoint, databaseName), true, false},
{fmt.Sprintf("rdspostgres://%s:%s@%s/%s?sslmode=require", username, password, endpoint, databaseName), true, false},
// Invalid connection info: db is created, but Ping fails.
{fmt.Sprintf("rdspostgres://%s:badpwd@%s/%s", username, endpoint, databaseName), false, true},
{fmt.Sprintf("rdspostgres://badusername:%s@%s/%s", password, endpoint, databaseName), false, true},
{fmt.Sprintf("rdspostgres://%s:%s@localhost:9999/%s", username, password, databaseName), false, true},
{fmt.Sprintf("rdspostgres://%s:%s@%s/wrongdbname", username, password, endpoint), false, true},
{fmt.Sprintf("rdspostgres://%s:%s@%s/%s?foo=bar", username, password, endpoint, databaseName), false, true},
}
for _, test := range tests {
t.Run(test.urlstr, func(t *testing.T) {
db, err := postgres.Open(ctx, test.urlstr)
if err != nil != test.wantErr {
t.Fatalf("got err %v, wanted error? %v", err, test.wantErr)
}
if err != nil {
return
}
defer func() {
if err := db.Close(); err != nil {
t.Error("Close:", err)
}
}()
err = db.Ping()
if err != nil != test.wantPingErr {
t.Errorf("ping got err %v, wanted error? %v", err, test.wantPingErr)
}
})
}
}

func TestOpenBadValues(t *testing.T) {
// This test will be skipped unless the project is set up with Terraform.

Expand All @@ -82,12 +146,10 @@ func TestOpenBadValues(t *testing.T) {
tests := []struct {
name, value string
}{
{"user", "foo"},
{"password", "foo"},
{"dbname", "foo"},
{"host", "localhost"},
{"port", "1234"},
{"sslmode", "require"},
{"sslcert", "foo"},
{"sslkey", "foo"},
{"sslrootcert", "foo"},
}
for _, test := range tests {
t.Run(test.name+"="+test.value, func(t *testing.T) {
Expand Down

0 comments on commit 7ab9ebe

Please sign in to comment.