Skip to content

Commit

Permalink
Ensure mysql username and passwords aren't url encoded
Browse files Browse the repository at this point in the history
  • Loading branch information
Lauren Voswinkel committed Jun 10, 2020
1 parent 787ab27 commit 051fe58
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions plugins/database/mysql/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql"
"fmt"
"net/url"
"strings"
"sync"
"time"

Expand All @@ -31,7 +32,7 @@ type mySQLConnectionProducer struct {
Password string `json:"password" mapstructure:"password" structs:"password"`

TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`

// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
tlsConfigName string
Expand Down Expand Up @@ -64,6 +65,9 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
return nil, fmt.Errorf("connection_url cannot be empty")
}

c.Type = "mysql"

// Don't escape special characters for MySQL password
password := c.Password

// QueryHelper doesn't do any SQL escaping, but if it starts to do so
Expand All @@ -73,8 +77,6 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
"password": password,
})

c.Type = "mysql"

if c.MaxOpenConnections == 0 {
c.MaxOpenConnections = 4
}
Expand Down Expand Up @@ -155,7 +157,20 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
}
uri.RawQuery = vals.Encode()

c.db, err = sql.Open(dbType, uri.String())
// This convoluted piece is to ensure we're not url encoding any username
// or password information
urlPieces := strings.Split(c.ConnectionURL, "?")
connURL := ""
for i, urlFragment := range urlPieces {
if len(urlPieces) == 1 || i != len(urlPieces) - 1 {
connURL = connURL + urlFragment
}
}
if len(vals.Encode()) > 0 {
connURL = connURL + "?" + vals.Encode()
}

c.db, err = sql.Open(dbType, connURL)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 051fe58

Please sign in to comment.