Skip to content

Commit

Permalink
Handle GuidConversion property
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrea Magnetto authored and Andrea Magnetto committed Sep 17, 2024
1 parent 8dbf84e commit 236ed76
Show file tree
Hide file tree
Showing 18 changed files with 194 additions and 59 deletions.
2 changes: 1 addition & 1 deletion alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnEncryptionKeyProvider, sel string, insert string, insertArgs []interface{}) {
t.Helper()
testProvider := &testKeyProvider{fallback: provider}
connector, _ := getTestConnector(t)
connector, _ := getTestConnector(t, false /*guidConversion*/)
connector.RegisterCekProvider(name, testProvider)
conn := sql.OpenDB(connector)
defer conn.Close()
Expand Down
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (b *Bulk) createColMetadata() []byte {
}
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

writeTypeInfo(buf, &b.bulkColumns[i].ti, false)
writeTypeInfo(buf, &b.bulkColumns[i].ti, false, b.cn.sess.encoding)

if col.ti.TypeId == typeNText ||
col.ti.TypeId == typeText ||
Expand Down
12 changes: 10 additions & 2 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
}
}

func TestBulkcopy(t *testing.T) {
func testBulkcopy(t *testing.T, guidConversion bool) {
// TDS level Bulk Insert is not supported on Azure SQL Server.
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
if dsn := makeConnStrSettingGuidConversion(t, guidConversion); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
t.Skip("TDS level bulk copy is not supported on Azure SQL Server")
}
type testValue struct {
Expand Down Expand Up @@ -300,6 +300,14 @@ func TestBulkcopy(t *testing.T) {
}
}

func TestBulkcopyWithGuidConversion(t *testing.T) {
testBulkcopy(t, true /*guidConversion*/)
}

func TestBulkcopyWithoutGuidConversion(t *testing.T) {
testBulkcopy(t, false /*guidConversion*/)
}

func compareValue(a interface{}, expected interface{}) bool {
if got, ok := a.([]uint8); ok {
if _, ok := expected.([]uint8); !ok {
Expand Down
25 changes: 25 additions & 0 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,14 @@ const (
DialTimeout = "dial timeout"
Pipe = "pipe"
MultiSubnetFailover = "multisubnetfailover"
GuidConversion = "guid conversion"
)

type EncodeParameters struct {
// Properly convert GUIDs, using correct byte endianness
GuidConversion bool
}

type Config struct {
Port uint64
Host string
Expand Down Expand Up @@ -131,6 +137,8 @@ type Config struct {
ColumnEncryption bool
// Attempt to connect to all IPs in parallel when MultiSubnetFailover is true
MultiSubnetFailover bool
// Parameters related to type encoding
Encoding EncodeParameters
}

func readDERFile(filename string) ([]byte, error) {
Expand Down Expand Up @@ -504,6 +512,20 @@ func Parse(dsn string) (Config, error) {
// Defaulting to true to prevent breaking change although other client libraries default to false
p.MultiSubnetFailover = true
}

guidConversion, ok := params[GuidConversion]
if ok {
var err error
p.Encoding.GuidConversion, err = strconv.ParseBool(guidConversion)
if err != nil {
f := "invalid guid conversion '%s': %s"
return p, fmt.Errorf(f, guidConversion, err.Error())
}
} else {
// set to false for backward compatibility
p.Encoding.GuidConversion = false
}

return p, nil
}

Expand Down Expand Up @@ -564,6 +586,9 @@ func (p Config) URL() *url.URL {
if p.ColumnEncryption {
q.Add("columnencryption", "true")
}

q.Add(GuidConversion, strconv.FormatBool(p.Encoding.GuidConversion))

if len(q) > 0 {
res.RawQuery = q.Encode()
}
Expand Down
5 changes: 4 additions & 1 deletion msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ func TestValidConnectionString(t *testing.T) {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption
}},
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool {
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && !p.Encoding.GuidConversion
}},
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool {
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion
}},
}
for _, ts := range connStrings {
Expand Down
2 changes: 1 addition & 1 deletion mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
params[0] = makeStrParam(s.query)
params[1] = makeStrParam(strings.Join(decls, ","))
}
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset, conn.sess.encoding); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send Rpc with %v", err))
}
Expand Down
2 changes: 1 addition & 1 deletion mssql_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
err = errCalTypes
return
}
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes)
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes, s.c.sess.encoding)
if err != nil {
return
}
Expand Down
22 changes: 16 additions & 6 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func driverWithProcess(t *testing.T, tl Logger) *Driver {
}
}

func TestSelect(t *testing.T) {
conn, logger := open(t)
func testSelect(t *testing.T, guidConversion bool) {
conn, logger := openSettingGuidConversion(t, guidConversion)
defer conn.Close()
defer logger.StopLogging()

Expand All @@ -39,6 +39,10 @@ func TestSelect(t *testing.T) {
}

longstr := strings.Repeat("x", 10000)
expectedGuid := []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}
if guidConversion {
expectedGuid = []byte{0xFF, 0x19, 0x96, 0x6F, 0x86, 0x8B, 0x11, 0xD0, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}
}

values := []testStruct{
{"1", int64(1)},
Expand Down Expand Up @@ -83,8 +87,7 @@ func TestSelect(t *testing.T) {
{"cast('2079-06-06T23:59:00' as smalldatetime)",
time.Date(2079, 6, 6, 23, 59, 0, 0, time.UTC)},
{"cast(NULL as smalldatetime)", nil},
{"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)",
[]byte{0xFF, 0x19, 0x96, 0x6F, 0x86, 0x8B, 0x11, 0xD0, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
{"cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier)", expectedGuid},
{"cast(NULL as uniqueidentifier)", nil},
{"cast(0x1234 as varbinary(2))", []byte{0x12, 0x34}},
{"cast(N'abc' as nvarchar(max))", "abc"},
Expand Down Expand Up @@ -114,8 +117,7 @@ func TestSelect(t *testing.T) {
{"cast(cast(N'chào' as nvarchar(max)) collate Vietnamese_CI_AI as varchar(max))", "chào"}, // cp1258
{fmt.Sprintf("cast(N'%s' as nvarchar(max))", longstr), longstr},
{"cast(NULL as sql_variant)", nil},
{"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)",
[]byte{0xFF, 0x19, 0x96, 0x6F, 0x86, 0x8B, 0x11, 0xD0, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
{"cast(cast(0x6F9619FF8B86D011B42D00C04FC964FF as uniqueidentifier) as sql_variant)", expectedGuid},
{"cast(cast(1 as bit) as sql_variant)", true},
{"cast(cast(10 as tinyint) as sql_variant)", int64(10)},
{"cast(cast(-10 as smallint) as sql_variant)", int64(-10)},
Expand Down Expand Up @@ -213,6 +215,14 @@ func TestSelect(t *testing.T) {
})
}

func TestSelectWithGuidConversion(t *testing.T) {
testSelect(t, true /*guidConversion*/)
}

func TestSelectWithoutGuidConversion(t *testing.T) {
testSelect(t, false /*guidConversion*/)
}

func TestSelectDateTimeOffset(t *testing.T) {
type testStruct struct {
sql string
Expand Down
8 changes: 5 additions & 3 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mssql

import (
"encoding/binary"

"github.com/microsoft/go-mssqldb/msdsn"
)

type procId struct {
Expand Down Expand Up @@ -43,7 +45,7 @@ var (
)

// http://msdn.microsoft.com/en-us/library/dd357576.aspx
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) {
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool, encoding msdsn.EncodeParameters) (err error) {
buf.BeginPacket(packRPCRequest, resetSession)
writeAllHeaders(buf, headers)
if len(proc.name) == 0 {
Expand Down Expand Up @@ -73,7 +75,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err = binary.Write(buf, binary.LittleEndian, param.Flags); err != nil {
return
}
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0)
err = writeTypeInfo(buf, &param.ti, (param.Flags&fByRevValue) != 0, encoding)
if err != nil {
return
}
Expand All @@ -82,7 +84,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
return
}
if (param.Flags & fEncrypted) == fEncrypted {
err = writeTypeInfo(buf, &param.tiOriginal, false)
err = writeTypeInfo(buf, &param.tiOriginal, false, encoding)
if err != nil {
return
}
Expand Down
2 changes: 2 additions & 0 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ type tdsSession struct {
routedPort uint16
alwaysEncrypted bool
aeSettings *alwaysEncryptedSettings
encoding msdsn.EncodeParameters
}

type alwaysEncryptedSettings struct {
Expand Down Expand Up @@ -1209,6 +1210,7 @@ initiate_connection:
logger: logger,
logFlags: uint64(p.LogFlags),
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
encoding: p.Encoding,
}

for i, p := range c.keyProviders {
Expand Down
14 changes: 10 additions & 4 deletions tds_go110_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@ import (
"testing"
)

func open(t testing.TB) (*sql.DB, *testLogger) {
connector, logger := getTestConnector(t)
func openSettingGuidConversion(t testing.TB, guidConversion bool) (*sql.DB, *testLogger) {
connector, logger := getTestConnector(t, guidConversion)
conn := sql.OpenDB(connector)
return conn, logger
}

func getTestConnector(t testing.TB) (*Connector, *testLogger) {
func open(t testing.TB) (*sql.DB, *testLogger) {
return openSettingGuidConversion(t, false /*guidConversion*/)
}

func getTestConnector(t testing.TB, guidConversion bool) (*Connector, *testLogger) {
tl := testLogger{t: t}
SetLogger(&tl)
connector, err := NewConnector(makeConnStr(t).String())

connectionString := makeConnStrSettingGuidConversion(t, guidConversion).String()
connector, err := NewConnector(connectionString)
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil, &tl
Expand Down
9 changes: 7 additions & 2 deletions tds_go110pre_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build !go1.10
// +build !go1.10

package mssql
Expand All @@ -7,14 +8,18 @@ import (
"testing"
)

func open(t *testing.T) (*sql.DB, *testLogger) {
func openSettingGuidConversion(t *testing.T, guidConversion bool) (*sql.DB, *testLogger) {
tl := testLogger{t: t}
SetLogger(&tl)
checkConnStr(t)
conn, err := sql.Open("sqlserver", makeConnStr(t).String())
conn, err := sql.Open("sqlserver", makeConnStrSettingGuidConversion(t, guidConversion).String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil, &tl
}
return conn, &tl
}

func open(t *testing.T) (*sql.DB, *testLogger) {
return openSettingGuidConversion(t, false /*guidConversion*/)
}
6 changes: 6 additions & 0 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,12 @@ func makeConnStr(t testing.TB) *url.URL {
return testConnParams(t).URL()
}

func makeConnStrSettingGuidConversion(t testing.TB, guidConversion bool) *url.URL {
config := testConnParams(t)
config.Encoding.GuidConversion = guidConversion
return config.URL()
}

type testLogger struct {
t testing.TB
mu sync.Mutex
Expand Down
12 changes: 6 additions & 6 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {
for i := range columns {
column := &columns[i]
baseTi := getBaseTypeInfo(r, true)
typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta)
typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta, s.encoding)
typeInfo.UserType = baseTi.UserType
typeInfo.Flags = baseTi.Flags
typeInfo.TypeId = baseTi.TypeId
Expand All @@ -627,7 +627,7 @@ func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {

if column.isEncrypted() && s.alwaysEncrypted {
// Read Crypto Metadata
cryptoMeta := parseCryptoMetadata(r, cekTable)
cryptoMeta := parseCryptoMetadata(r, cekTable, s.encoding)
cryptoMeta.typeInfo.Flags = baseTi.Flags
column.cryptoMeta = &cryptoMeta
} else {
Expand Down Expand Up @@ -663,14 +663,14 @@ type cryptoMetadata struct {
typeInfo typeInfo
}

func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata {
func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable, encoding msdsn.EncodeParameters) cryptoMetadata {
ordinal := uint16(0)
if cekTable != nil {
ordinal = r.uint16()
}

typeInfo := getBaseTypeInfo(r, false)
ti := readTypeInfo(r, typeInfo.TypeId, nil)
ti := readTypeInfo(r, typeInfo.TypeId, nil, encoding)
ti.UserType = typeInfo.UserType
ti.Flags = typeInfo.Flags
ti.TypeId = typeInfo.TypeId
Expand Down Expand Up @@ -935,11 +935,11 @@ func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) {

var cryptoMetadata *cryptoMetadata = nil
if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted {
cm := parseCryptoMetadata(r, nil) // CryptoMetadata
cm := parseCryptoMetadata(r, nil, s.encoding) // CryptoMetadata
cryptoMetadata = &cm
}

ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata)
ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata, s.encoding)
nv.Value = ti2.Reader(&ti2, r, cryptoMetadata)

return
Expand Down
6 changes: 4 additions & 2 deletions tvp_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"reflect"
"strings"
"time"

"github.com/microsoft/go-mssqldb/msdsn"
)

const (
Expand Down Expand Up @@ -62,7 +64,7 @@ func (tvp TVP) check() error {
return nil
}

func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) {
func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int, encoding msdsn.EncodeParameters) ([]byte, error) {
if len(columnStr) != len(tvpFieldIndexes) {
return nil, ErrorWrongTyping
}
Expand All @@ -80,7 +82,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for i, column := range columnStr {
binary.Write(buf, binary.LittleEndian, column.UserType)
binary.Write(buf, binary.LittleEndian, column.Flags)
writeTypeInfo(buf, &columnStr[i].ti, false)
writeTypeInfo(buf, &columnStr[i].ti, false, encoding)
writeBVarChar(buf, "")
}
// The returned error is always nil
Expand Down
Loading

0 comments on commit 236ed76

Please sign in to comment.