From 3cc2463cd7ba0de8a2a252246ca0fc477548d52e Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 6 Jul 2023 19:55:30 +0530 Subject: [PATCH 01/12] Accept additional values for encrypt --- msdsn/conn_str.go | 24 ++++++++++++++---------- msdsn/conn_str_test.go | 3 +++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 4f71453d..bd06ac47 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -25,6 +25,7 @@ const ( EncryptionOff = 0 EncryptionRequired = 1 EncryptionDisabled = 3 + EncryptionStrict = 4 ) const ( @@ -130,21 +131,24 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e var encryption Encryption = EncryptionOff encrypt, ok := params["encrypt"] if ok { - if strings.EqualFold(encrypt, "DISABLE") { + encrypt = strings.ToLower(encrypt) + switch encrypt { + case "mandatory", "yes", "1", "t", "true": + encryption = EncryptionRequired + case "disable": encryption = EncryptionDisabled - } else { - e, err := strconv.ParseBool(encrypt) - if err != nil { - f := "invalid encrypt '%s': %s" - return encryption, nil, fmt.Errorf(f, encrypt, err.Error()) - } - if e { - encryption = EncryptionRequired - } + case "strict": + encryption = EncryptionStrict + case "optional", "no", "0", "f", "false": + encryption = EncryptionOff + default: + f := "invalid encrypt '%s'" + return encryption, nil, fmt.Errorf(f, encrypt) } } else { trustServerCert = true } + trust, ok := params["trustservercertificate"] if ok { var err error diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 5fa1a0ed..31a5f3c1 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -62,6 +62,7 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=disable", func(p Config) bool { return p.Encryption == EncryptionDisabled }}, {"encrypt=disable;tlsmin=1.1", func(p Config) bool { return p.Encryption == EncryptionDisabled && p.TLSConfig == nil }}, {"encrypt=true", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, + {"encrypt=mandatory", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, {"encrypt=true;tlsmin=1.0", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS10 }}, @@ -78,6 +79,8 @@ func TestValidConnectionString(t *testing.T) { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, {"encrypt=false", func(p Config) bool { return p.Encryption == EncryptionOff }}, + {"encrypt=optional", func(p Config) bool { return p.Encryption == EncryptionOff }}, + {"encrypt=strict", func(p Config) bool { return p.Encryption == EncryptionStrict }}, {"connection timeout=3;dial timeout=4;keepalive=5", func(p Config) bool { return p.ConnTimeout == 3*time.Second && p.DialTimeout == 4*time.Second && p.KeepAlive == 5*time.Second }}, From 386a3fc74dc319289ba3759cd2367a3982db4aed Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Fri, 14 Jul 2023 15:26:04 +0530 Subject: [PATCH 02/12] Enable TDS 8.0 is encrypt is strict --- tds.go | 101 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/tds.go b/tds.go index d10f9c6b..2fae4b18 100644 --- a/tds.go +++ b/tds.go @@ -143,6 +143,7 @@ const ( encryptOn = 1 // Encryption is available and on. encryptNotSup = 2 // Encryption is not available. encryptReq = 3 // Encryption is required. + encryptStrict = 4 ) const ( @@ -977,6 +978,8 @@ func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]by encrypt = encryptOn case msdsn.EncryptionOff: encrypt = encryptOff + case msdsn.EncryptionStrict: + encrypt = encryptStrict } v := getDriverVersion(driverVersion) fields := map[uint8][]byte{ @@ -1133,8 +1136,38 @@ initiate_connection: } toconn := newTimeoutConn(conn, p.ConnTimeout) - outbuf := newTdsBuffer(packetSize, toconn) + + if p.Encryption == msdsn.EncryptionStrict { + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + if config.DynamicRecordSizingDisabled == false { + config = config.Clone() + + // fix for https://github.com/microsoft/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. + // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true + } + } + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host, "") + if err != nil { + return nil, err + } + } + //Set ALPN Sequence + config.NextProtos = []string{"tds/8.0"} + + tlsConn := tls.Client(toconn.c, config) + err = tlsConn.Handshake() + outbuf.transport = tlsConn + if err != nil { + return nil, fmt.Errorf("TLS Handshake failed: %v", err) + } + } sess := tdsSession{ buf: outbuf, logger: logger, @@ -1166,43 +1199,47 @@ initiate_connection: return nil, err } - if encrypt != encryptNotSup { - var config *tls.Config - if pc := p.TLSConfig; pc != nil { - config = pc - if config.DynamicRecordSizingDisabled == false { - config = config.Clone() + if p.Encryption != msdsn.EncryptionStrict { + if encrypt != encryptNotSup { + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + if config.DynamicRecordSizingDisabled == false { + config = config.Clone() - // fix for https://github.com/microsoft/go-mssqldb/issues/166 - // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, - // while SQL Server seems to expect one TCP segment per encrypted TDS package. - // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package - config.DynamicRecordSizingDisabled = true + // fix for https://github.com/microsoft/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. + // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true + } } - } - if config == nil { - config, err = msdsn.SetupTLS("", false, p.Host, "") - if err != nil { - return nil, err + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host, "") + if err != nil { + return nil, err + } + } - } - // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream - handshakeConn := tlsHandshakeConn{buf: outbuf} - passthrough := passthroughConn{c: &handshakeConn} - tlsConn := tls.Client(&passthrough, config) - err = tlsConn.Handshake() - passthrough.c = toconn - outbuf.transport = tlsConn - if err != nil { - return nil, fmt.Errorf("TLS Handshake failed: %v", err) - } - if encrypt == encryptOff { - outbuf.afterFirst = func() { - outbuf.transport = toconn + // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream + handshakeConn := tlsHandshakeConn{buf: outbuf} + passthrough := passthroughConn{c: &handshakeConn} + tlsConn := tls.Client(&passthrough, config) + err = tlsConn.Handshake() + passthrough.c = toconn + outbuf.transport = tlsConn + if err != nil { + return nil, fmt.Errorf("TLS Handshake failed: %v", err) + } + if encrypt == encryptOff { + outbuf.afterFirst = func() { + outbuf.transport = toconn + } } } - } + + } //p.Encryption != msdsn.EncryptionStrict auth, err := integratedauth.GetIntegratedAuthenticator(p) if err != nil { From c789659d8770f30ee5d06168b22d7d61d68b40b5 Mon Sep 17 00:00:00 2001 From: Srdjan Bozovic Date: Sat, 15 Jul 2023 20:59:04 +0200 Subject: [PATCH 03/12] fix: protocol version --- tds.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tds.go b/tds.go index 2fae4b18..2f04e90d 100644 --- a/tds.go +++ b/tds.go @@ -102,6 +102,7 @@ const ( verTDS73 = verTDS73A verTDS73B = 0x730B0003 verTDS74 = 0x74000004 + verTDS80 = 0x08000000 ) // packet types @@ -1026,6 +1027,12 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map } func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { + var TDSVersion uint32 + if(p.Encryption == msdsn.EncryptionStrict) { + TDSVersion = verTDS80 + } else { + TDSVersion = verTDS74 + } var typeFlags uint8 if p.ReadOnlyIntent { typeFlags |= fReadOnlyIntent @@ -1038,7 +1045,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont serverName = p.Host } l = &login{ - TDSVersion: verTDS74, + TDSVersion: TDSVersion, PacketSize: packetSize, Database: p.Database, OptionFlags2: fODBC, // to get unlimited TEXTSIZE From adee951529846a6e3c881407bcadff0887a3a86d Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 27 Jul 2023 19:40:53 +0530 Subject: [PATCH 04/12] Keep certificate in PEM format explicitly to support other formats --- msdsn/conn_str.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index bd06ac47..84c603ca 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -3,6 +3,7 @@ package msdsn import ( "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "io/ioutil" @@ -91,6 +92,19 @@ type Config struct { BrowserMessage BrowserMsg } +// GetPEMCertificate returns PEM formatted certificate +func GetPEMCertificate(certificate string) ([]byte, error) { + cerData, ok := ioutil.ReadFile(certificate) + if ok != nil { + return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, ok) + } + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cerData, + }) + return pemData, nil +} + // Build a tls.Config object from the supplied certificate. func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, minTLSVersion string) (*tls.Config, error) { config := tls.Config{ @@ -108,7 +122,7 @@ func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate str if len(certificate) == 0 { return &config, nil } - pem, err := ioutil.ReadFile(certificate) + pem, err := GetPEMCertificate(certificate) if err != nil { return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, err) } From 1e53bf7ea7eb354ebabbb4abddd35ba8c882d26b Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Tue, 22 Aug 2023 19:33:40 +0530 Subject: [PATCH 05/12] Add TDS8 testcase --- msdsn/conn_str.go | 16 +-------------- tds.go | 51 ++++++++++++++++++++++------------------------- tds_test.go | 44 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 42 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index f1079a4c..c957657d 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -3,7 +3,6 @@ package msdsn import ( "crypto/tls" "crypto/x509" - "encoding/pem" "errors" "fmt" "io/ioutil" @@ -94,19 +93,6 @@ type Config struct { ColumnEncryption bool } -// GetPEMCertificate returns PEM formatted certificate -func GetPEMCertificate(certificate string) ([]byte, error) { - cerData, ok := ioutil.ReadFile(certificate) - if ok != nil { - return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, ok) - } - pemData := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cerData, - }) - return pemData, nil -} - // Build a tls.Config object from the supplied certificate. func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, minTLSVersion string) (*tls.Config, error) { config := tls.Config{ @@ -124,7 +110,7 @@ func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate str if len(certificate) == 0 { return &config, nil } - pem, err := GetPEMCertificate(certificate) + pem, err := ioutil.ReadFile(certificate) if err != nil { return nil, fmt.Errorf("cannot read certificate %q: %w", certificate, err) } diff --git a/tds.go b/tds.go index 3af87ae8..705aec78 100644 --- a/tds.go +++ b/tds.go @@ -1052,7 +1052,7 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { var TDSVersion uint32 - if(p.Encryption == msdsn.EncryptionStrict) { + if p.Encryption == msdsn.EncryptionStrict { TDSVersion = verTDS80 } else { TDSVersion = verTDS74 @@ -1129,6 +1129,27 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont return l, nil } +func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error) { + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + } + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host, "") + if err != nil { + return nil, err + } + } + //Set ALPN Sequence + config.NextProtos = []string{"tds/8.0"} + tlsConn = tls.Client(conn.c, config) + err = tlsConn.Handshake() + if err != nil { + return nil, fmt.Errorf("TLS Handshake failed: %v", err) + } + return tlsConn, nil +} + func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { // if instance is specified use instance resolution service @@ -1173,33 +1194,9 @@ initiate_connection: outbuf := newTdsBuffer(packetSize, toconn) if p.Encryption == msdsn.EncryptionStrict { - var config *tls.Config - if pc := p.TLSConfig; pc != nil { - config = pc - if config.DynamicRecordSizingDisabled == false { - config = config.Clone() - - // fix for https://github.com/microsoft/go-mssqldb/issues/166 - // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, - // while SQL Server seems to expect one TCP segment per encrypted TDS package. - // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package - config.DynamicRecordSizingDisabled = true - } - } - if config == nil { - config, err = msdsn.SetupTLS("", false, p.Host, "") - if err != nil { - return nil, err - } - } - //Set ALPN Sequence - config.NextProtos = []string{"tds/8.0"} - - tlsConn := tls.Client(toconn.c, config) - err = tlsConn.Handshake() - outbuf.transport = tlsConn + outbuf.transport, err = getTLSConn(toconn, p) if err != nil { - return nil, fmt.Errorf("TLS Handshake failed: %v", err) + return nil, err } } sess := tdsSession{ diff --git a/tds_test.go b/tds_test.go index daabd714..6acdc81e 100644 --- a/tds_test.go +++ b/tds_test.go @@ -662,6 +662,50 @@ func TestSecureConnection(t *testing.T) { } } +func TestTDS8Connection(t *testing.T) { + checkConnStr(t) + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + + dsn := makeConnStr(t) + if !strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { + t.Skip() + } + dsnParams := dsn.Query() + dsnParams.Set("encrypt", "strict") + dsnParams.Set("TrustServerCertificate", "false") + dsnParams.Set("tlsmin", "1.2") + dsn.RawQuery = dsnParams.Encode() + + conn, err := sql.Open("mssql", dsn.String()) + if err != nil { + t.Fatal("Open connection failed:", err.Error()) + } + defer conn.Close() + stmt, err := conn.Prepare("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID") + if err != nil { + t.Fatal("Prepare failed:", err.Error()) + } + defer stmt.Close() + row := stmt.QueryRow() + var protocolName string + var tdsver []byte + var clientAddress string + err = row.Scan(&protocolName, &tdsver, &clientAddress) + if err != nil { + t.Fatal("Scan failed:", err.Error()) + } + assertEqual(t, "TSQL", protocolName) + assertEqual(t, "0x08000000", hex.EncodeToString(tdsver)) +} + +func assertEqual(t *testing.T, expected interface{}, actual interface{}) { + if expected != actual { + t.Fatalf("Expected %v, got %v", expected, actual) + } +} + func TestBadCredentials(t *testing.T) { params := testConnParams(t) params.Password = "padpwd" From 291c8eb6815b53456b091f8fabc8ed9d26b45f98 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Tue, 22 Aug 2023 19:44:59 +0530 Subject: [PATCH 06/12] Address review suggestions --- tds.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tds.go b/tds.go index 79a9e9d5..cb5c62f7 100644 --- a/tds.go +++ b/tds.go @@ -1243,7 +1243,7 @@ initiate_connection: var config *tls.Config if pc := p.TLSConfig; pc != nil { config = pc - if config.DynamicRecordSizingDisabled == false { + if !config.DynamicRecordSizingDisabled { config = config.Clone() // fix for https://github.com/microsoft/go-mssqldb/issues/166 From 18ea4dc00dfc13de384c12fbd7985db0a8d0edda Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 31 Aug 2023 18:47:16 +0530 Subject: [PATCH 07/12] Modify testcases and update README --- CHANGELOG.md | 1 + README.md | 8 ++++--- azuread/azuread_test.go | 35 ++++++++++++++++++++++++++--- msdsn/conn_str.go | 50 +++++++++++++++++++++++++++++++++-------- tds.go | 16 +++++++------ tds_test.go | 40 ++++++++------------------------- 6 files changed, 97 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ceabdbb8..f2e78459 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ * Always Encrypted encryption and decryption with 2 hour key cache (#116) * 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers +* TDS8 can now be used for connections by setting encrypt="strict" ## 1.5.0 diff --git a/README.md b/README.md index 5b5f363d..69fc3710 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,10 @@ Other supported formats are listed below. * `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts. * `dial timeout` - in seconds (default is 15 times the number of registered protocols), set to 0 for no timeout. * `encrypt` + * `strict` - Data sent between client and server is encrypted E2E using [TDS8](https://learn.microsoft.com/en-us/sql/relational-databases/security/networking/tds-8?view=sql-server-ver16). * `disable` - Data send between client and server is not encrypted. - * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) - * `true` - Data sent between client and server is encrypted. + * `false`/`optional`/`no`/`0`/`f` - Data sent between client and server is not encrypted beyond the login packet. (Default) + * `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted. * `app name` - The application name (default is go-mssqldb) * `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux)) @@ -56,7 +57,7 @@ Other supported formats are listed below. * `TrustServerCertificate` * false - Server certificate is checked. Default is false if encrypt is specified. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. -* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. +* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. Currently, certificates of PEM type are supported. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. * `tlsmin` - Specifies the minimum TLS version for negotiating encryption with the server. Recognized values are `1.0`, `1.1`, `1.2`, `1.3`. If not set to a recognized value the default value for the `tls` package will be used. The default is currently `1.2`. * `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. @@ -468,6 +469,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host * Always Encrypted - `MSSQL_CERTIFICATE_STORE` provider on Windows - `pfx` provider on Linux and Windows + ## Tests `go test` is used for testing. A running instance of MSSQL server is required. diff --git a/azuread/azuread_test.go b/azuread/azuread_test.go index c8faf81e..75590305 100644 --- a/azuread/azuread_test.go +++ b/azuread/azuread_test.go @@ -6,6 +6,7 @@ package azuread import ( "bufio" "database/sql" + "encoding/hex" "io" "os" "testing" @@ -14,7 +15,7 @@ import ( ) func TestAzureSqlAuth(t *testing.T) { - mssqlConfig := testConnParams(t) + mssqlConfig := testConnParams(t, "") conn, err := newConnectorConfig(mssqlConfig) if err != nil { @@ -35,9 +36,31 @@ func TestAzureSqlAuth(t *testing.T) { } +func TestTDS8ConnWithAzureSqlAuth(t *testing.T) { + mssqlConfig := testConnParams(t, ";encrypt=strict;TrustServerCertificate=false;tlsmin=1.2") + conn, err := newConnectorConfig(mssqlConfig) + if err != nil { + t.Fatalf("Unable to get a connector: %v", err) + } + db := sql.OpenDB(conn) + row := db.QueryRow("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID") + if err != nil { + t.Fatal("Prepare failed:", err.Error()) + } + var protocolName string + var tdsver []byte + var clientAddress string + err = row.Scan(&protocolName, &tdsver, &clientAddress) + if err != nil { + t.Fatal("Scan failed:", err.Error()) + } + assertEqual(t, "TSQL", protocolName) + assertEqual(t, "0x08000000", hex.EncodeToString(tdsver)) +} + // returns parsed connection parameters derived from // environment variables -func testConnParams(t testing.TB) *azureFedAuthConfig { +func testConnParams(t testing.TB, dsnParams string) *azureFedAuthConfig { dsn := os.Getenv("AZURESERVER_DSN") const logFlags = 127 if dsn == "" { @@ -54,7 +77,7 @@ func testConnParams(t testing.TB) *azureFedAuthConfig { if dsn == "" { t.Skip("no azure database connection string. set AZURESERVER_DSN environment variable or create .azureconnstr file") } - config, err := parse(dsn) + config, err := parse(dsn + dsnParams) if err != nil { t.Skip("error parsing connection string ") } @@ -64,3 +87,9 @@ func testConnParams(t testing.TB) *azureFedAuthConfig { config.mssqlConfig.LogFlags = logFlags return config } + +func assertEqual(t *testing.T, expected interface{}, actual interface{}) { + if expected != actual { + t.Fatalf("Expected %v, got %v", expected, actual) + } +} diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 39027bed..02e29ab1 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -21,6 +21,12 @@ type ( BrowserMsg byte ) +const ( + DsnTypeUrl = 1 + DsnTypeOdbc = 2 + DsnTypeAdo = 3 +) + const ( EncryptionOff = 0 EncryptionRequired = 1 @@ -192,6 +198,9 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e certificate := params[Certificate] if encryption != EncryptionDisabled { tlsMin := params[TLSMin] + if encrypt == "strict" { + trustServerCert = false + } tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin) if err != nil { return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err) @@ -203,28 +212,51 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e var skipSetup = errors.New("skip setting up TLS") -func Parse(dsn string) (Config, error) { - p := Config{ - ProtocolParameters: map[string]interface{}{}, - Protocols: []string{}, +func getDsnType(dsn string) int { + if strings.HasPrefix(dsn, "sqlserver://") { + return DsnTypeUrl + } + if strings.HasPrefix(dsn, "odbc:") { + return DsnTypeOdbc } + return DsnTypeAdo +} + +func getDsnParams(dsn string) (map[string]string, error) { var params map[string]string var err error - if strings.HasPrefix(dsn, "odbc:") { + + switch getDsnType(dsn) { + case DsnTypeOdbc: params, err = splitConnectionStringOdbc(dsn[len("odbc:"):]) if err != nil { - return p, err + return params, err } - } else if strings.HasPrefix(dsn, "sqlserver://") { + case DsnTypeUrl: params, err = splitConnectionStringURL(dsn) if err != nil { - return p, err + return params, err } - } else { + default: params = splitConnectionString(dsn) } + return params, nil +} + +func Parse(dsn string) (Config, error) { + p := Config{ + ProtocolParameters: map[string]interface{}{}, + Protocols: []string{}, + } + + var params map[string]string + var err error + params, err = getDsnParams(dsn) + if err != nil { + return p, err + } p.Parameters = params strlog, ok := params[LogParam] diff --git a/tds.go b/tds.go index cb5c62f7..dd6be6b5 100644 --- a/tds.go +++ b/tds.go @@ -1133,7 +1133,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont return l, nil } -func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error) { +func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls.Conn, err error) { var config *tls.Config if pc := p.TLSConfig; pc != nil { config = pc @@ -1145,17 +1145,17 @@ func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error } } //Set ALPN Sequence - config.NextProtos = []string{"tds/8.0"} + config.NextProtos = []string{alpnSeq} tlsConn = tls.Client(conn.c, config) err = tlsConn.Handshake() if err != nil { - return nil, fmt.Errorf("TLS Handshake failed: %v", err) + return nil, fmt.Errorf("TLS Handshake failed: %w", err) } return tlsConn, nil } func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { - + isTransportEncrypted := false // if instance is specified use instance resolution service if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 { // both instance name and port specified @@ -1198,10 +1198,11 @@ initiate_connection: outbuf := newTdsBuffer(packetSize, toconn) if p.Encryption == msdsn.EncryptionStrict { - outbuf.transport, err = getTLSConn(toconn, p) + outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0") if err != nil { return nil, err } + isTransportEncrypted = true } sess := tdsSession{ buf: outbuf, @@ -1238,7 +1239,8 @@ initiate_connection: return nil, err } - if p.Encryption != msdsn.EncryptionStrict { + //We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict) + if isTransportEncrypted { if encrypt != encryptNotSup { var config *tls.Config if pc := p.TLSConfig; pc != nil { @@ -1278,7 +1280,7 @@ initiate_connection: } } - } //p.Encryption != msdsn.EncryptionStrict + } auth, err := integratedauth.GetIntegratedAuthenticator(p) if err != nil { diff --git a/tds_test.go b/tds_test.go index 97933e66..a656231e 100644 --- a/tds_test.go +++ b/tds_test.go @@ -665,47 +665,25 @@ func TestSecureConnection(t *testing.T) { } } -func TestTDS8Connection(t *testing.T) { +func TestTDS8ConnFailure(t *testing.T) { checkConnStr(t) tl := testLogger{t: t} defer tl.StopLogging() SetLogger(&tl) - - dsn := makeConnStr(t) + config := testConnParams(t) + dsn := config.URL() if !strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") { t.Skip() } dsnParams := dsn.Query() - dsnParams.Set("encrypt", "strict") - dsnParams.Set("TrustServerCertificate", "false") - dsnParams.Set("tlsmin", "1.2") + dsnParams.Set(msdsn.TrustServerCertificate, "true") + dsnParams.Set(msdsn.Encrypt, "strict") + dsnParams.Set(msdsn.TLSMin, "1.2") dsn.RawQuery = dsnParams.Encode() - conn, err := sql.Open("mssql", dsn.String()) - if err != nil { - t.Fatal("Open connection failed:", err.Error()) - } - defer conn.Close() - stmt, err := conn.Prepare("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID") - if err != nil { - t.Fatal("Prepare failed:", err.Error()) - } - defer stmt.Close() - row := stmt.QueryRow() - var protocolName string - var tdsver []byte - var clientAddress string - err = row.Scan(&protocolName, &tdsver, &clientAddress) - if err != nil { - t.Fatal("Scan failed:", err.Error()) - } - assertEqual(t, "TSQL", protocolName) - assertEqual(t, "0x08000000", hex.EncodeToString(tdsver)) -} - -func assertEqual(t *testing.T, expected interface{}, actual interface{}) { - if expected != actual { - t.Fatalf("Expected %v, got %v", expected, actual) + _, err := sql.Open("mssql", dsn.String()) + if err == nil { + t.Fatal("Connection did not fail for unknown CA certificate with encrypt=strict") } } From e91b8c12ca08cee45ec8764cd50c30d252f616f5 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 31 Aug 2023 19:34:04 +0530 Subject: [PATCH 08/12] Change const --- msdsn/conn_str.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 02e29ab1..549b3963 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -22,7 +22,7 @@ type ( ) const ( - DsnTypeUrl = 1 + DsnTypeURL = 1 DsnTypeOdbc = 2 DsnTypeAdo = 3 ) @@ -214,7 +214,7 @@ var skipSetup = errors.New("skip setting up TLS") func getDsnType(dsn string) int { if strings.HasPrefix(dsn, "sqlserver://") { - return DsnTypeUrl + return DsnTypeURL } if strings.HasPrefix(dsn, "odbc:") { return DsnTypeOdbc @@ -233,7 +233,7 @@ func getDsnParams(dsn string) (map[string]string, error) { if err != nil { return params, err } - case DsnTypeUrl: + case DsnTypeURL: params, err = splitConnectionStringURL(dsn) if err != nil { return params, err From 7da34add1f31c1221afd1e58f2a0cda45654d437 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 31 Aug 2023 20:12:15 +0530 Subject: [PATCH 09/12] Correct transport check --- tds.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tds.go b/tds.go index dd6be6b5..5142df10 100644 --- a/tds.go +++ b/tds.go @@ -1240,7 +1240,7 @@ initiate_connection: } //We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict) - if isTransportEncrypted { + if !isTransportEncrypted { if encrypt != encryptNotSup { var config *tls.Config if pc := p.TLSConfig; pc != nil { From d090602eae7fb3bd4cf20348f289431b96797ea7 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 31 Aug 2023 21:12:28 +0530 Subject: [PATCH 10/12] Use testify for asserts --- azuread/azuread_test.go | 11 +++-------- msdsn/conn_str_test.go | 3 +++ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/azuread/azuread_test.go b/azuread/azuread_test.go index 75590305..0b5a42e0 100644 --- a/azuread/azuread_test.go +++ b/azuread/azuread_test.go @@ -12,6 +12,7 @@ import ( "testing" mssql "github.com/microsoft/go-mssqldb" + "github.com/stretchr/testify/assert" ) func TestAzureSqlAuth(t *testing.T) { @@ -54,8 +55,8 @@ func TestTDS8ConnWithAzureSqlAuth(t *testing.T) { if err != nil { t.Fatal("Scan failed:", err.Error()) } - assertEqual(t, "TSQL", protocolName) - assertEqual(t, "0x08000000", hex.EncodeToString(tdsver)) + assert.Equal(t, "TSQL", protocolName, "Protocol name does not match") + assert.Equal(t, "0x08000000", hex.EncodeToString(tdsver)) } // returns parsed connection parameters derived from @@ -87,9 +88,3 @@ func testConnParams(t testing.TB, dsnParams string) *azureFedAuthConfig { config.mssqlConfig.LogFlags = logFlags return config } - -func assertEqual(t *testing.T, expected interface{}, actual interface{}) { - if expected != actual { - t.Fatalf("Expected %v, got %v", expected, actual) - } -} diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 329caf8f..3aab7dd1 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -75,6 +75,9 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=true;tlsmin=1.2", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12 }}, + {"encrypt=true;tlsmin=1.3", func(p Config) bool { + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS13 + }}, {"encrypt=true;tlsmin=1.4", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }}, From 36bf0bc770a4a07e744ca50d7955fbfa4a417da9 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Thu, 31 Aug 2023 21:50:54 +0530 Subject: [PATCH 11/12] Correct test expectation --- azuread/azuread_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azuread/azuread_test.go b/azuread/azuread_test.go index 0b5a42e0..bcdedfcb 100644 --- a/azuread/azuread_test.go +++ b/azuread/azuread_test.go @@ -56,7 +56,7 @@ func TestTDS8ConnWithAzureSqlAuth(t *testing.T) { t.Fatal("Scan failed:", err.Error()) } assert.Equal(t, "TSQL", protocolName, "Protocol name does not match") - assert.Equal(t, "0x08000000", hex.EncodeToString(tdsver)) + assert.Equal(t, "08000000", hex.EncodeToString(tdsver)) } // returns parsed connection parameters derived from From 7a47b912d222eca34a95faba29939bc71a5e664e Mon Sep 17 00:00:00 2001 From: davidshi Date: Thu, 31 Aug 2023 11:36:08 -0500 Subject: [PATCH 12/12] handle AKV key name collision --- internal/akvkeys/utils.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/akvkeys/utils.go b/internal/akvkeys/utils.go index bd8e2c9a..4c4da456 100644 --- a/internal/akvkeys/utils.go +++ b/internal/akvkeys/utils.go @@ -41,9 +41,13 @@ func CreateRSAKey(client *azkeys.Client) (name string, err error) { Kty: &kt, KeySize: &ks, } - i, _ := rand.Int(rand.Reader, big.NewInt(1000)) + + i, _ := rand.Int(rand.Reader, big.NewInt(1000000)) name = fmt.Sprintf("go-mssqlkey%d", i) _, err = client.CreateKey(context.TODO(), name, rsaKeyParams, nil) + if err != nil { + _, err = client.RecoverDeletedKey(context.TODO(), name, &azkeys.RecoverDeletedKeyOptions{}) + } return }