Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate physical MySQL database and table config values before using them #9189

Merged
merged 2 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions physical/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ import (
"strings"
"sync"
"time"
"unicode"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"

metrics "github.com/armon/go-metrics"
mysql "github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -59,15 +61,21 @@ func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
return nil, err
}

database, ok := conf["database"]
if !ok {
database := conf["database"]
if database == "" {
database = "vault"
}
table, ok := conf["table"]
if !ok {
table := conf["table"]
if table == "" {
table = "vault"
}
dbTable := "`" + database + "`.`" + table + "`"

err = validateDBTable(database, table)
if err != nil {
return nil, err
}

dbTable := fmt.Sprintf("`%s`.`%s`", database, table)

maxParStr, ok := conf["max_parallel"]
var maxParInt int
Expand Down Expand Up @@ -193,6 +201,67 @@ func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
return m, nil
}

// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid
// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of
// representing the full Unicode Basic Multilingual Plane to check against.
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
func validateDBTable(db, table string) (err error) {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
return merr.ErrorOrNil()
}

func validate(name string) (err error) {
if name == "" {
return fmt.Errorf("missing name")
}
// From: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
// - Permitted characters in quoted identifiers include the full Unicode Basic Multilingual Plane (BMP), except U+0000:
// ASCII: U+0001 .. U+007F
// Extended: U+0080 .. U+FFFF
// - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher) are not permitted in quoted or unquoted identifiers.
// - Identifiers may begin with a digit but unless quoted may not consist solely of digits.
// - Database, table, and column names cannot end with space characters.
//
// We are explicitly excluding all space characters (it's easier to deal with)
// The name will be quoted, so the all-digit requirement doesn't apply
runes := []rune(name)
validationErr := fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
for _, r := range runes {
// U+0000 Explicitly disallowed
if r == 0x0000 {
return fmt.Errorf("invalid character: cannot include 0x0000")
}
// Cannot be above 0xFFFF
if r > 0xFFFF {
return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
}
if r == '`' {
return fmt.Errorf("invalid character: cannot include '`' character")
}
if r == '\'' || r == '"' {
return fmt.Errorf("invalid character: cannot include quotes")
}
// We are excluding non-printable characters (not mentioned in the docs)
if !unicode.IsPrint(r) {
return validationErr
}
// We are excluding space characters (not mentioned in the docs)
if unicode.IsSpace(r) {
return validationErr
}
}
return nil
}

func wrapErr(message string, err error) error {
if err == nil {
return nil
}
return fmt.Errorf(message, err)
}

func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
var err error

Expand Down
121 changes: 99 additions & 22 deletions physical/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ func TestMySQLPlaintextCatch(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)

NewMySQLBackend(map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"plaintext_connection_allowed": "false",
}, logger)

Expand Down Expand Up @@ -82,11 +82,11 @@ func TestMySQLBackend(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)

b, err := NewMySQLBackend(map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"plaintext_connection_allowed": "true",
}, logger)

Expand Down Expand Up @@ -128,12 +128,12 @@ func TestMySQLHABackend(t *testing.T) {
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
config := map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"ha_enabled": "true",
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"ha_enabled": "true",
"plaintext_connection_allowed": "true",
}

Expand Down Expand Up @@ -176,12 +176,12 @@ func TestMySQLHABackend_LockFailPanic(t *testing.T) {
table := "test"
logger := logging.NewVaultLogger(log.Debug)
config := map[string]string{
"address": cfg.Addr,
"database": cfg.DBName,
"table": table,
"username": cfg.User,
"password": cfg.Passwd,
"ha_enabled": "true",
"address": cfg.Addr,
"database": cfg.DBName,
"table": table,
"username": cfg.User,
"password": cfg.Passwd,
"ha_enabled": "true",
"plaintext_connection_allowed": "true",
}

Expand Down Expand Up @@ -265,3 +265,80 @@ func TestMySQLHABackend_LockFailPanic(t *testing.T) {
t.Fatalf("expected error, got none")
}
}

func TestValidateDBTable(t *testing.T) {
type testCase struct {
database string
table string
expectErr bool
}

tests := map[string]testCase{
"empty database & table": {"", "", true},
"empty database": {"", "a", true},
"empty table": {"a", "", true},
"ascii database": {"abcde", "a", false},
"ascii table": {"a", "abcde", false},
"ascii database & table": {"abcde", "abcde", false},
"only whitespace db": {" ", "a", true},
"only whitespace table": {"a", " ", true},
"whitespace prefix db": {" bcde", "a", true},
"whitespace middle db": {"ab de", "a", true},
"whitespace suffix db": {"abcd ", "a", true},
"whitespace prefix table": {"a", " bcde", true},
"whitespace middle table": {"a", "ab de", true},
"whitespace suffix table": {"a", "abcd ", true},
"backtick prefix db": {"`bcde", "a", true},
"backtick middle db": {"ab`de", "a", true},
"backtick suffix db": {"abcd`", "a", true},
"backtick prefix table": {"a", "`bcde", true},
"backtick middle table": {"a", "ab`de", true},
"backtick suffix table": {"a", "abcd`", true},
"single quote prefix db": {"'bcde", "a", true},
"single quote middle db": {"ab'de", "a", true},
"single quote suffix db": {"abcd'", "a", true},
"single quote prefix table": {"a", "'bcde", true},
"single quote middle table": {"a", "ab'de", true},
"single quote suffix table": {"a", "abcd'", true},
"double quote prefix db": {`"bcde`, "a", true},
"double quote middle db": {`ab"de`, "a", true},
"double quote suffix db": {`abcd"`, "a", true},
"double quote prefix table": {"a", `"bcde`, true},
"double quote middle table": {"a", `ab"de`, true},
"double quote suffix table": {"a", `abcd"`, true},
"0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true},
"0x0000 middle db": {str('a', 0x0000, 'c'), "a", true},
"0x0000 suffix db": {str('a', 'b', 0x0000), "a", true},
"0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true},
"0x0000 middle table": {"a", str('a', 0x0000, 'c'), true},
"0x0000 suffix table": {"a", str('a', 'b', 0x0000), true},
"unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true},
"unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true},
"unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true},
"unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true},
"unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true},
"unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true},
"non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true},
"non-printable middle db": {str('a', 0x0001, 'c'), "a", true},
"non-printable suffix db": {str('a', 'b', 0x0001), "a", true},
"non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true},
"non-printable middle table": {"a", str('a', 0x0001, 'c'), true},
"non-printable suffix table": {"a", str('a', 'b', 0x0001), true},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := validateDBTable(test.database, test.table)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}

func str(r ...rune) string {
return string(r)
}