diff --git a/CHANGELOG.md b/CHANGELOG.md index e36f4934..ba3dceae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ * Added `ActiveDirectoryAzCli` and `ActiveDirectoryDeviceCode` authentication types to `azuread` package * Always Encrypted encryption and decryption with 2 hour key cache (#116) * 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers +* TDS8 can now be used for connections by setting encrypt="strict" ## 1.5.0 diff --git a/README.md b/README.md index 9360a50c..b4fbd55a 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,10 @@ Other supported formats are listed below. * `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts. * `dial timeout` - in seconds (default is 15 times the number of registered protocols), set to 0 for no timeout. * `encrypt` + * `strict` - Data sent between client and server is encrypted E2E using [TDS8](https://learn.microsoft.com/en-us/sql/relational-databases/security/networking/tds-8?view=sql-server-ver16). * `disable` - Data send between client and server is not encrypted. - * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) - * `true` - Data sent between client and server is encrypted. + * `false`/`optional`/`no`/`0`/`f` - Data sent between client and server is not encrypted beyond the login packet. (Default) + * `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted. * `app name` - The application name (default is go-mssqldb) * `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux)) @@ -56,7 +57,7 @@ Other supported formats are listed below. * `TrustServerCertificate` * false - Server certificate is checked. Default is false if encrypt is specified. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. -* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. +* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. Currently, certificates of PEM type are supported. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. * `tlsmin` - Specifies the minimum TLS version for negotiating encryption with the server. Recognized values are `1.0`, `1.1`, `1.2`, `1.3`. If not set to a recognized value the default value for the `tls` package will be used. The default is currently `1.2`. * `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. @@ -470,6 +471,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host * Always Encrypted - `MSSQL_CERTIFICATE_STORE` provider on Windows - `pfx` provider on Linux and Windows + ## Tests `go test` is used for testing. A running instance of MSSQL server is required. diff --git a/azuread/azuread_test.go b/azuread/azuread_test.go index c8faf81e..bcdedfcb 100644 --- a/azuread/azuread_test.go +++ b/azuread/azuread_test.go @@ -6,15 +6,17 @@ package azuread import ( "bufio" "database/sql" + "encoding/hex" "io" "os" "testing" mssql "github.com/microsoft/go-mssqldb" + "github.com/stretchr/testify/assert" ) func TestAzureSqlAuth(t *testing.T) { - mssqlConfig := testConnParams(t) + mssqlConfig := testConnParams(t, "") conn, err := newConnectorConfig(mssqlConfig) if err != nil { @@ -35,9 +37,31 @@ func TestAzureSqlAuth(t *testing.T) { } +func TestTDS8ConnWithAzureSqlAuth(t *testing.T) { + mssqlConfig := testConnParams(t, ";encrypt=strict;TrustServerCertificate=false;tlsmin=1.2") + conn, err := newConnectorConfig(mssqlConfig) + if err != nil { + t.Fatalf("Unable to get a connector: %v", err) + } + db := sql.OpenDB(conn) + row := db.QueryRow("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID") + if err != nil { + t.Fatal("Prepare failed:", err.Error()) + } + var protocolName string + var tdsver []byte + var clientAddress string + err = row.Scan(&protocolName, &tdsver, &clientAddress) + if err != nil { + t.Fatal("Scan failed:", err.Error()) + } + assert.Equal(t, "TSQL", protocolName, "Protocol name does not match") + assert.Equal(t, "08000000", hex.EncodeToString(tdsver)) +} + // returns parsed connection parameters derived from // environment variables -func testConnParams(t testing.TB) *azureFedAuthConfig { +func testConnParams(t testing.TB, dsnParams string) *azureFedAuthConfig { dsn := os.Getenv("AZURESERVER_DSN") const logFlags = 127 if dsn == "" { @@ -54,7 +78,7 @@ func testConnParams(t testing.TB) *azureFedAuthConfig { if dsn == "" { t.Skip("no azure database connection string. set AZURESERVER_DSN environment variable or create .azureconnstr file") } - config, err := parse(dsn) + config, err := parse(dsn + dsnParams) if err != nil { t.Skip("error parsing connection string ") } diff --git a/internal/akvkeys/utils.go b/internal/akvkeys/utils.go index bd8e2c9a..4c4da456 100644 --- a/internal/akvkeys/utils.go +++ b/internal/akvkeys/utils.go @@ -41,9 +41,13 @@ func CreateRSAKey(client *azkeys.Client) (name string, err error) { Kty: &kt, KeySize: &ks, } - i, _ := rand.Int(rand.Reader, big.NewInt(1000)) + + i, _ := rand.Int(rand.Reader, big.NewInt(1000000)) name = fmt.Sprintf("go-mssqlkey%d", i) _, err = client.CreateKey(context.TODO(), name, rsaKeyParams, nil) + if err != nil { + _, err = client.RecoverDeletedKey(context.TODO(), name, &azkeys.RecoverDeletedKeyOptions{}) + } return } diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 544cc4db..549b3963 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -21,10 +21,17 @@ type ( BrowserMsg byte ) +const ( + DsnTypeURL = 1 + DsnTypeOdbc = 2 + DsnTypeAdo = 3 +) + const ( EncryptionOff = 0 EncryptionRequired = 1 EncryptionDisabled = 3 + EncryptionStrict = 4 ) const ( @@ -162,17 +169,19 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e var encryption Encryption = EncryptionOff encrypt, ok := params[Encrypt] if ok { - if strings.EqualFold(encrypt, "DISABLE") { + encrypt = strings.ToLower(encrypt) + switch encrypt { + case "mandatory", "yes", "1", "t", "true": + encryption = EncryptionRequired + case "disable": encryption = EncryptionDisabled - } else { - e, err := strconv.ParseBool(encrypt) - if err != nil { - f := "invalid encrypt '%s': %s" - return encryption, nil, fmt.Errorf(f, encrypt, err.Error()) - } - if e { - encryption = EncryptionRequired - } + case "strict": + encryption = EncryptionStrict + case "optional", "no", "0", "f", "false": + encryption = EncryptionOff + default: + f := "invalid encrypt '%s'" + return encryption, nil, fmt.Errorf(f, encrypt) } } else { trustServerCert = true @@ -189,6 +198,9 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e certificate := params[Certificate] if encryption != EncryptionDisabled { tlsMin := params[TLSMin] + if encrypt == "strict" { + trustServerCert = false + } tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin) if err != nil { return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err) @@ -200,28 +212,51 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e var skipSetup = errors.New("skip setting up TLS") -func Parse(dsn string) (Config, error) { - p := Config{ - ProtocolParameters: map[string]interface{}{}, - Protocols: []string{}, +func getDsnType(dsn string) int { + if strings.HasPrefix(dsn, "sqlserver://") { + return DsnTypeURL + } + if strings.HasPrefix(dsn, "odbc:") { + return DsnTypeOdbc } + return DsnTypeAdo +} + +func getDsnParams(dsn string) (map[string]string, error) { var params map[string]string var err error - if strings.HasPrefix(dsn, "odbc:") { + + switch getDsnType(dsn) { + case DsnTypeOdbc: params, err = splitConnectionStringOdbc(dsn[len("odbc:"):]) if err != nil { - return p, err + return params, err } - } else if strings.HasPrefix(dsn, "sqlserver://") { + case DsnTypeURL: params, err = splitConnectionStringURL(dsn) if err != nil { - return p, err + return params, err } - } else { + default: params = splitConnectionString(dsn) } + return params, nil +} +func Parse(dsn string) (Config, error) { + p := Config{ + ProtocolParameters: map[string]interface{}{}, + Protocols: []string{}, + } + + var params map[string]string + var err error + + params, err = getDsnParams(dsn) + if err != nil { + return p, err + } p.Parameters = params strlog, ok := params[LogParam] diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 1e001385..3aab7dd1 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -62,6 +62,7 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=disable", func(p Config) bool { return p.Encryption == EncryptionDisabled }}, {"encrypt=disable;tlsmin=1.1", func(p Config) bool { return p.Encryption == EncryptionDisabled && p.TLSConfig == nil }}, {"encrypt=true", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, + {"encrypt=mandatory", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, {"encrypt=true;tlsmin=1.0", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS10 }}, @@ -74,10 +75,15 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=true;tlsmin=1.2", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12 }}, + {"encrypt=true;tlsmin=1.3", func(p Config) bool { + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS13 + }}, {"encrypt=true;tlsmin=1.4", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, {"encrypt=false", func(p Config) bool { return p.Encryption == EncryptionOff }}, + {"encrypt=optional", func(p Config) bool { return p.Encryption == EncryptionOff }}, + {"encrypt=strict", func(p Config) bool { return p.Encryption == EncryptionStrict }}, {"connection timeout=3;dial timeout=4;keepalive=5", func(p Config) bool { return p.ConnTimeout == 3*time.Second && p.DialTimeout == 4*time.Second && p.KeepAlive == 5*time.Second }}, diff --git a/tds.go b/tds.go index db6eda14..5142df10 100644 --- a/tds.go +++ b/tds.go @@ -103,6 +103,7 @@ const ( verTDS73 = verTDS73A verTDS73B = 0x730B0003 verTDS74 = 0x74000004 + verTDS80 = 0x08000000 ) // packet types @@ -144,6 +145,7 @@ const ( encryptOn = 1 // Encryption is available and on. encryptNotSup = 2 // Encryption is not available. encryptReq = 3 // Encryption is required. + encryptStrict = 4 ) const ( @@ -1004,6 +1006,8 @@ func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]by encrypt = encryptOn case msdsn.EncryptionOff: encrypt = encryptOff + case msdsn.EncryptionStrict: + encrypt = encryptStrict } v := getDriverVersion(driverVersion) fields := map[uint8][]byte{ @@ -1050,6 +1054,12 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map } func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { + var TDSVersion uint32 + if p.Encryption == msdsn.EncryptionStrict { + TDSVersion = verTDS80 + } else { + TDSVersion = verTDS74 + } var typeFlags uint8 if p.ReadOnlyIntent { typeFlags |= fReadOnlyIntent @@ -1062,7 +1072,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont serverName = p.Host } l = &login{ - TDSVersion: verTDS74, + TDSVersion: TDSVersion, PacketSize: packetSize, Database: p.Database, OptionFlags2: fODBC, // to get unlimited TEXTSIZE @@ -1123,8 +1133,29 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont return l, nil } -func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { +func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls.Conn, err error) { + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + } + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host, "") + if err != nil { + return nil, err + } + } + //Set ALPN Sequence + config.NextProtos = []string{alpnSeq} + tlsConn = tls.Client(conn.c, config) + err = tlsConn.Handshake() + if err != nil { + return nil, fmt.Errorf("TLS Handshake failed: %w", err) + } + return tlsConn, nil +} +func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { + isTransportEncrypted := false // if instance is specified use instance resolution service if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 { // both instance name and port specified @@ -1164,8 +1195,15 @@ initiate_connection: } toconn := newTimeoutConn(conn, p.ConnTimeout) - outbuf := newTdsBuffer(packetSize, toconn) + + if p.Encryption == msdsn.EncryptionStrict { + outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0") + if err != nil { + return nil, err + } + isTransportEncrypted = true + } sess := tdsSession{ buf: outbuf, logger: logger, @@ -1201,42 +1239,47 @@ initiate_connection: return nil, err } - if encrypt != encryptNotSup { - var config *tls.Config - if pc := p.TLSConfig; pc != nil { - config = pc - if config.DynamicRecordSizingDisabled == false { - config = config.Clone() + //We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict) + if !isTransportEncrypted { + if encrypt != encryptNotSup { + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + if !config.DynamicRecordSizingDisabled { + config = config.Clone() + + // fix for https://github.com/microsoft/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. + // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true + } + } + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host, "") + if err != nil { + return nil, err + } - // fix for https://github.com/microsoft/go-mssqldb/issues/166 - // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, - // while SQL Server seems to expect one TCP segment per encrypted TDS package. - // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package - config.DynamicRecordSizingDisabled = true } - } - if config == nil { - config, err = msdsn.SetupTLS("", false, p.Host, "") + + // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream + handshakeConn := tlsHandshakeConn{buf: outbuf} + passthrough := passthroughConn{c: &handshakeConn} + tlsConn := tls.Client(&passthrough, config) + err = tlsConn.Handshake() + passthrough.c = toconn + outbuf.transport = tlsConn if err != nil { - return nil, err + return nil, fmt.Errorf("TLS Handshake failed: %v", err) } - } - - // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream - handshakeConn := tlsHandshakeConn{buf: outbuf} - passthrough := passthroughConn{c: &handshakeConn} - tlsConn := tls.Client(&passthrough, config) - err = tlsConn.Handshake() - passthrough.c = toconn - outbuf.transport = tlsConn - if err != nil { - return nil, fmt.Errorf("TLS Handshake failed: %v", err) - } - if encrypt == encryptOff { - outbuf.afterFirst = func() { - outbuf.transport = toconn + if encrypt == encryptOff { + outbuf.afterFirst = func() { + outbuf.transport = toconn + } } } + } auth, err := integratedauth.GetIntegratedAuthenticator(p) diff --git a/tds_test.go b/tds_test.go index dca94788..a656231e 100644 --- a/tds_test.go +++ b/tds_test.go @@ -665,6 +665,28 @@ func TestSecureConnection(t *testing.T) { } } +func TestTDS8ConnFailure(t *testing.T) { + checkConnStr(t) + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + config := testConnParams(t) + dsn := config.URL() + if !strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { + t.Skip() + } + dsnParams := dsn.Query() + dsnParams.Set(msdsn.TrustServerCertificate, "true") + dsnParams.Set(msdsn.Encrypt, "strict") + dsnParams.Set(msdsn.TLSMin, "1.2") + dsn.RawQuery = dsnParams.Encode() + + _, err := sql.Open("mssql", dsn.String()) + if err == nil { + t.Fatal("Connection did not fail for unknown CA certificate with encrypt=strict") + } +} + func TestBadCredentials(t *testing.T) { params := testConnParams(t) params.Password = "padpwd"