diff --git a/physical/mysql/mysql.go b/physical/mysql/mysql.go index fce3f026488f..eb83923dd97f 100644 --- a/physical/mysql/mysql.go +++ b/physical/mysql/mysql.go @@ -15,8 +15,10 @@ import ( "strings" "sync" "time" + "unicode" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-multierror" metrics "github.com/armon/go-metrics" mysql "github.com/go-sql-driver/mysql" @@ -59,15 +61,21 @@ func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backen return nil, err } - database, ok := conf["database"] - if !ok { + database := conf["database"] + if database == "" { database = "vault" } - table, ok := conf["table"] - if !ok { + table := conf["table"] + if table == "" { table = "vault" } - dbTable := "`" + database + "`.`" + table + "`" + + err = validateDBTable(database, table) + if err != nil { + return nil, err + } + + dbTable := fmt.Sprintf("`%s`.`%s`", database, table) maxParStr, ok := conf["max_parallel"] var maxParInt int @@ -193,6 +201,67 @@ func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backen return m, nil } +// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid +// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of +// representing the full Unicode Basic Multilingual Plane to check against. +// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html +func validateDBTable(db, table string) (err error) { + merr := &multierror.Error{} + merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db))) + merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table))) + return merr.ErrorOrNil() +} + +func validate(name string) (err error) { + if name == "" { + return fmt.Errorf("missing name") + } + // From: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html + // - Permitted characters in quoted identifiers include the full Unicode Basic Multilingual Plane (BMP), except U+0000: + // ASCII: U+0001 .. U+007F + // Extended: U+0080 .. U+FFFF + // - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher) are not permitted in quoted or unquoted identifiers. + // - Identifiers may begin with a digit but unless quoted may not consist solely of digits. + // - Database, table, and column names cannot end with space characters. + // + // We are explicitly excluding all space characters (it's easier to deal with) + // The name will be quoted, so the all-digit requirement doesn't apply + runes := []rune(name) + validationErr := fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]") + for _, r := range runes { + // U+0000 Explicitly disallowed + if r == 0x0000 { + return fmt.Errorf("invalid character: cannot include 0x0000") + } + // Cannot be above 0xFFFF + if r > 0xFFFF { + return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF") + } + if r == '`' { + return fmt.Errorf("invalid character: cannot include '`' character") + } + if r == '\'' || r == '"' { + return fmt.Errorf("invalid character: cannot include quotes") + } + // We are excluding non-printable characters (not mentioned in the docs) + if !unicode.IsPrint(r) { + return validationErr + } + // We are excluding space characters (not mentioned in the docs) + if unicode.IsSpace(r) { + return validationErr + } + } + return nil +} + +func wrapErr(message string, err error) error { + if err == nil { + return nil + } + return fmt.Errorf(message, err) +} + func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) { var err error diff --git a/physical/mysql/mysql_test.go b/physical/mysql/mysql_test.go index 51222639a24d..75d220b9ae4d 100644 --- a/physical/mysql/mysql_test.go +++ b/physical/mysql/mysql_test.go @@ -43,11 +43,11 @@ func TestMySQLPlaintextCatch(t *testing.T) { logger := logging.NewVaultLogger(log.Debug) NewMySQLBackend(map[string]string{ - "address": address, - "database": database, - "table": table, - "username": username, - "password": password, + "address": address, + "database": database, + "table": table, + "username": username, + "password": password, "plaintext_connection_allowed": "false", }, logger) @@ -82,11 +82,11 @@ func TestMySQLBackend(t *testing.T) { logger := logging.NewVaultLogger(log.Debug) b, err := NewMySQLBackend(map[string]string{ - "address": address, - "database": database, - "table": table, - "username": username, - "password": password, + "address": address, + "database": database, + "table": table, + "username": username, + "password": password, "plaintext_connection_allowed": "true", }, logger) @@ -128,12 +128,12 @@ func TestMySQLHABackend(t *testing.T) { // Run vault tests logger := logging.NewVaultLogger(log.Debug) config := map[string]string{ - "address": address, - "database": database, - "table": table, - "username": username, - "password": password, - "ha_enabled": "true", + "address": address, + "database": database, + "table": table, + "username": username, + "password": password, + "ha_enabled": "true", "plaintext_connection_allowed": "true", } @@ -176,12 +176,12 @@ func TestMySQLHABackend_LockFailPanic(t *testing.T) { table := "test" logger := logging.NewVaultLogger(log.Debug) config := map[string]string{ - "address": cfg.Addr, - "database": cfg.DBName, - "table": table, - "username": cfg.User, - "password": cfg.Passwd, - "ha_enabled": "true", + "address": cfg.Addr, + "database": cfg.DBName, + "table": table, + "username": cfg.User, + "password": cfg.Passwd, + "ha_enabled": "true", "plaintext_connection_allowed": "true", } @@ -265,3 +265,80 @@ func TestMySQLHABackend_LockFailPanic(t *testing.T) { t.Fatalf("expected error, got none") } } + +func TestValidateDBTable(t *testing.T) { + type testCase struct { + database string + table string + expectErr bool + } + + tests := map[string]testCase{ + "empty database & table": {"", "", true}, + "empty database": {"", "a", true}, + "empty table": {"a", "", true}, + "ascii database": {"abcde", "a", false}, + "ascii table": {"a", "abcde", false}, + "ascii database & table": {"abcde", "abcde", false}, + "only whitespace db": {" ", "a", true}, + "only whitespace table": {"a", " ", true}, + "whitespace prefix db": {" bcde", "a", true}, + "whitespace middle db": {"ab de", "a", true}, + "whitespace suffix db": {"abcd ", "a", true}, + "whitespace prefix table": {"a", " bcde", true}, + "whitespace middle table": {"a", "ab de", true}, + "whitespace suffix table": {"a", "abcd ", true}, + "backtick prefix db": {"`bcde", "a", true}, + "backtick middle db": {"ab`de", "a", true}, + "backtick suffix db": {"abcd`", "a", true}, + "backtick prefix table": {"a", "`bcde", true}, + "backtick middle table": {"a", "ab`de", true}, + "backtick suffix table": {"a", "abcd`", true}, + "single quote prefix db": {"'bcde", "a", true}, + "single quote middle db": {"ab'de", "a", true}, + "single quote suffix db": {"abcd'", "a", true}, + "single quote prefix table": {"a", "'bcde", true}, + "single quote middle table": {"a", "ab'de", true}, + "single quote suffix table": {"a", "abcd'", true}, + "double quote prefix db": {`"bcde`, "a", true}, + "double quote middle db": {`ab"de`, "a", true}, + "double quote suffix db": {`abcd"`, "a", true}, + "double quote prefix table": {"a", `"bcde`, true}, + "double quote middle table": {"a", `ab"de`, true}, + "double quote suffix table": {"a", `abcd"`, true}, + "0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true}, + "0x0000 middle db": {str('a', 0x0000, 'c'), "a", true}, + "0x0000 suffix db": {str('a', 'b', 0x0000), "a", true}, + "0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true}, + "0x0000 middle table": {"a", str('a', 0x0000, 'c'), true}, + "0x0000 suffix table": {"a", str('a', 'b', 0x0000), true}, + "unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true}, + "unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true}, + "unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true}, + "unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true}, + "unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true}, + "unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true}, + "non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true}, + "non-printable middle db": {str('a', 0x0001, 'c'), "a", true}, + "non-printable suffix db": {str('a', 'b', 0x0001), "a", true}, + "non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true}, + "non-printable middle table": {"a", str('a', 0x0001, 'c'), true}, + "non-printable suffix table": {"a", str('a', 'b', 0x0001), true}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + err := validateDBTable(test.database, test.table) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + }) + } +} + +func str(r ...rune) string { + return string(r) +}