Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate physical CockroachDB table config value before using it #9191

Merged
merged 3 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 70 additions & 6 deletions physical/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"strconv"
"strings"
"time"
"unicode"

metrics "github.com/armon/go-metrics"
"github.com/cockroachdb/cockroach-go/crdb"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/physical"

Expand All @@ -21,8 +23,14 @@ import (
)

// Verify CockroachDBBackend satisfies the correct interfaces
var _ physical.Backend = (*CockroachDBBackend)(nil)
var _ physical.Transactional = (*CockroachDBBackend)(nil)
var (
_ physical.Backend = (*CockroachDBBackend)(nil)
_ physical.Transactional = (*CockroachDBBackend)(nil)
)

const (
defaultTableName = "vault_kv_store"
)

// CockroachDBBackend Backend is a physical backend that stores data
// within a CockroachDB database.
Expand All @@ -44,14 +52,18 @@ func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.
return nil, fmt.Errorf("missing connection_url")
}

dbTable, ok := conf["table"]
if !ok {
dbTable = "vault_kv_store"
dbTable := conf["table"]
if dbTable == "" {
dbTable = defaultTableName
}

err := validateDBTable(dbTable)
if err != nil {
return nil, errwrap.Wrapf("invalid table: {{err}}", err)
}

maxParStr, ok := conf["max_parallel"]
var maxParInt int
var err error
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
Expand Down Expand Up @@ -239,3 +251,55 @@ func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []*physical.TxnEntry)
}
return nil
}

// validateDBTable against the CockroachDB rules for table names:
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
//
// - All values that accept an identifier must:
// - Begin with a Unicode letter or an underscore (_). Subsequent characters can be letters,
// - underscores, digits (0-9), or dollar signs ($).
// - Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example,
// name accepts Unreserved or Column Name keywords.
//
// The docs do state that we can bypass these rules with double quotes, however I think it
// is safer to just require these rules across the board.
func validateDBTable(dbTable string) (err error) {
// Check if this is 'database.table' formatted. If so, split them apart and check the two
// parts from each other
split := strings.SplitN(dbTable, ".", 2)
if len(split) == 2 {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validateDBTable(split[0])))
merr = multierror.Append(merr, wrapErr("invalid table name: %w", validateDBTable(split[1])))
return merr.ErrorOrNil()
}

// Disallow SQL keywords as the table name
if sqlKeywords[strings.ToUpper(dbTable)] {
return fmt.Errorf("name must not be a SQL keyword")
}

runes := []rune(dbTable)
for i, r := range runes {
if i == 0 && !unicode.IsLetter(r) && r != '_' {
return fmt.Errorf("must use a letter or an underscore as the first character")
}

if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) && r != '$' {
return fmt.Errorf("must only contain letters, underscores, digits, and dollar signs")
}

if r == '`' || r == '\'' || r == '"' {
return fmt.Errorf("cannot contain backticks, single quotes, or double quotes")
}
}

return nil
}

func wrapErr(message string, err error) error {
if err == nil {
return nil
}
return fmt.Errorf(message, err)
}
69 changes: 65 additions & 4 deletions physical/cockroachdb/cockroachdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
func prepareCockroachDBTestContainer(t *testing.T) (cleanup func(), retURL, tableName string) {
tableName = os.Getenv("CR_TABLE")
if tableName == "" {
tableName = "vault_kv_store"
tableName = defaultTableName
}
t.Logf("Table name: %s", tableName)
retURL = os.Getenv("CR_URL")
if retURL != "" {
return func() {}, retURL, tableName
Expand All @@ -45,8 +46,8 @@ func prepareCockroachDBTestContainer(t *testing.T) (cleanup func(), retURL, tabl
}

retURL = fmt.Sprintf("postgresql://root@localhost:%s/?sslmode=disable", resource.GetPort("26257/tcp"))
database := "database"
tableName = database + ".vault_kv"
database := "vault"
tableName = fmt.Sprintf("%s.%s", database, tableName)

// exponential backoff-retry
if err = pool.Retry(func() error {
Expand All @@ -56,7 +57,7 @@ func prepareCockroachDBTestContainer(t *testing.T) (cleanup func(), retURL, tabl
return err
}
defer db.Close()
_, err = db.Exec("CREATE DATABASE database")
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
return err
}); err != nil {
cleanup()
Expand Down Expand Up @@ -99,3 +100,63 @@ func truncate(t *testing.T, b physical.Backend) {
t.Fatalf("Failed to drop table: %v", err)
}
}

func TestValidateDBTable(t *testing.T) {
type testCase struct {
table string
expectErr bool
}

tests := map[string]testCase{
"first character is letter": {"abcdef", false},
"first character is underscore": {"_bcdef", false},
"exclamation point": {"ab!def", true},
"at symbol": {"ab@def", true},
"hash": {"ab#def", true},
"percent": {"ab%def", true},
"carrot": {"ab^def", true},
"ampersand": {"ab&def", true},
"star": {"ab*def", true},
"left paren": {"ab(def", true},
"right paren": {"ab)def", true},
"dash": {"ab-def", true},
"digit": {"a123ef", false},
"dollar end": {"abcde$", false},
"dollar middle": {"ab$def", false},
"dollar start": {"$bcdef", true},
"backtick prefix": {"`bcdef", true},
"backtick middle": {"ab`def", true},
"backtick suffix": {"abcde`", true},
"single quote prefix": {"'bcdef", true},
"single quote middle": {"ab'def", true},
"single quote suffix": {"abcde'", true},
"double quote prefix": {`"bcdef`, true},
"double quote middle": {`ab"def`, true},
"double quote suffix": {`abcde"`, true},
"underscore with all runes": {"_bcd123__a__$", false},
"all runes": {"abcd123__a__$", false},
"default table name": {defaultTableName, false},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := validateDBTable(test.table)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
t.Run(fmt.Sprintf("database: %s", name), func(t *testing.T) {
dbTable := fmt.Sprintf("%s.%s", test.table, test.table)
err := validateDBTable(dbTable)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
Loading