Skip to content

Commit

Permalink
Validate physical CockroachDB table config value before using it (#9191)
Browse files Browse the repository at this point in the history
* Validate table name (and database if specified) prior to using it in SQL
  • Loading branch information
pcman312 authored and andaley committed Jul 17, 2020
1 parent f62740e commit 7769184
Show file tree
Hide file tree
Showing 3 changed files with 575 additions and 10 deletions.
76 changes: 70 additions & 6 deletions physical/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"strconv"
"strings"
"time"
"unicode"

metrics "github.com/armon/go-metrics"
"github.com/cockroachdb/cockroach-go/crdb"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/physical"

Expand All @@ -21,8 +23,14 @@ import (
)

// Verify CockroachDBBackend satisfies the correct interfaces
var _ physical.Backend = (*CockroachDBBackend)(nil)
var _ physical.Transactional = (*CockroachDBBackend)(nil)
var (
_ physical.Backend = (*CockroachDBBackend)(nil)
_ physical.Transactional = (*CockroachDBBackend)(nil)
)

const (
defaultTableName = "vault_kv_store"
)

// CockroachDBBackend Backend is a physical backend that stores data
// within a CockroachDB database.
Expand All @@ -44,14 +52,18 @@ func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.
return nil, fmt.Errorf("missing connection_url")
}

dbTable, ok := conf["table"]
if !ok {
dbTable = "vault_kv_store"
dbTable := conf["table"]
if dbTable == "" {
dbTable = defaultTableName
}

err := validateDBTable(dbTable)
if err != nil {
return nil, errwrap.Wrapf("invalid table: {{err}}", err)
}

maxParStr, ok := conf["max_parallel"]
var maxParInt int
var err error
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
Expand Down Expand Up @@ -239,3 +251,55 @@ func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []*physical.TxnEntry)
}
return nil
}

// validateDBTable against the CockroachDB rules for table names:
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
//
// - All values that accept an identifier must:
// - Begin with a Unicode letter or an underscore (_). Subsequent characters can be letters,
// - underscores, digits (0-9), or dollar signs ($).
// - Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example,
// name accepts Unreserved or Column Name keywords.
//
// The docs do state that we can bypass these rules with double quotes, however I think it
// is safer to just require these rules across the board.
func validateDBTable(dbTable string) (err error) {
// Check if this is 'database.table' formatted. If so, split them apart and check the two
// parts from each other
split := strings.SplitN(dbTable, ".", 2)
if len(split) == 2 {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validateDBTable(split[0])))
merr = multierror.Append(merr, wrapErr("invalid table name: %w", validateDBTable(split[1])))
return merr.ErrorOrNil()
}

// Disallow SQL keywords as the table name
if sqlKeywords[strings.ToUpper(dbTable)] {
return fmt.Errorf("name must not be a SQL keyword")
}

runes := []rune(dbTable)
for i, r := range runes {
if i == 0 && !unicode.IsLetter(r) && r != '_' {
return fmt.Errorf("must use a letter or an underscore as the first character")
}

if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) && r != '$' {
return fmt.Errorf("must only contain letters, underscores, digits, and dollar signs")
}

if r == '`' || r == '\'' || r == '"' {
return fmt.Errorf("cannot contain backticks, single quotes, or double quotes")
}
}

return nil
}

func wrapErr(message string, err error) error {
if err == nil {
return nil
}
return fmt.Errorf(message, err)
}
69 changes: 65 additions & 4 deletions physical/cockroachdb/cockroachdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
func prepareCockroachDBTestContainer(t *testing.T) (cleanup func(), retURL, tableName string) {
tableName = os.Getenv("CR_TABLE")
if tableName == "" {
tableName = "vault_kv_store"
tableName = defaultTableName
}
t.Logf("Table name: %s", tableName)
retURL = os.Getenv("CR_URL")
if retURL != "" {
return func() {}, retURL, tableName
Expand All @@ -45,8 +46,8 @@ func prepareCockroachDBTestContainer(t *testing.T) (cleanup func(), retURL, tabl
}

retURL = fmt.Sprintf("postgresql://root@localhost:%s/?sslmode=disable", resource.GetPort("26257/tcp"))
database := "database"
tableName = database + ".vault_kv"
database := "vault"
tableName = fmt.Sprintf("%s.%s", database, tableName)

// exponential backoff-retry
if err = pool.Retry(func() error {
Expand All @@ -56,7 +57,7 @@ func prepareCockroachDBTestContainer(t *testing.T) (cleanup func(), retURL, tabl
return err
}
defer db.Close()
_, err = db.Exec("CREATE DATABASE database")
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
return err
}); err != nil {
cleanup()
Expand Down Expand Up @@ -99,3 +100,63 @@ func truncate(t *testing.T, b physical.Backend) {
t.Fatalf("Failed to drop table: %v", err)
}
}

func TestValidateDBTable(t *testing.T) {
type testCase struct {
table string
expectErr bool
}

tests := map[string]testCase{
"first character is letter": {"abcdef", false},
"first character is underscore": {"_bcdef", false},
"exclamation point": {"ab!def", true},
"at symbol": {"ab@def", true},
"hash": {"ab#def", true},
"percent": {"ab%def", true},
"carrot": {"ab^def", true},
"ampersand": {"ab&def", true},
"star": {"ab*def", true},
"left paren": {"ab(def", true},
"right paren": {"ab)def", true},
"dash": {"ab-def", true},
"digit": {"a123ef", false},
"dollar end": {"abcde$", false},
"dollar middle": {"ab$def", false},
"dollar start": {"$bcdef", true},
"backtick prefix": {"`bcdef", true},
"backtick middle": {"ab`def", true},
"backtick suffix": {"abcde`", true},
"single quote prefix": {"'bcdef", true},
"single quote middle": {"ab'def", true},
"single quote suffix": {"abcde'", true},
"double quote prefix": {`"bcdef`, true},
"double quote middle": {`ab"def`, true},
"double quote suffix": {`abcde"`, true},
"underscore with all runes": {"_bcd123__a__$", false},
"all runes": {"abcd123__a__$", false},
"default table name": {defaultTableName, false},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := validateDBTable(test.table)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
t.Run(fmt.Sprintf("database: %s", name), func(t *testing.T) {
dbTable := fmt.Sprintf("%s.%s", test.table, test.table)
err := validateDBTable(dbTable)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
Loading

0 comments on commit 7769184

Please sign in to comment.