Skip to content

Commit

Permalink
Merge pull request #371 from nirrattner/auth-credential-function
Browse files Browse the repository at this point in the history
Add option for an AUTH credential function
  • Loading branch information
rueian committed Sep 22, 2023
2 parents 2f2f4da + eb45194 commit 0608d8b
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 8 deletions.
31 changes: 23 additions & 8 deletions pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,26 @@ func _newPipe(connFn func() (net.Conn, error), option *ClientOption, r2ps bool)
p.pshks.Store(emptypshks)
p.clhks.Store(emptyclhks)

username := option.Username
password := option.Password
if option.AuthCredentialsFn != nil {
authCredentialsContext := AuthCredentialsContext{
Address: conn.RemoteAddr(),
}
authCredentials, err := option.AuthCredentialsFn(authCredentialsContext)
if err != nil {
p.Close()
return nil, err
}
username = authCredentials.Username
password = authCredentials.Password
}

helloCmd := []string{"HELLO", "3"}
if option.Password != "" && option.Username == "" {
helloCmd = append(helloCmd, "AUTH", "default", option.Password)
} else if option.Username != "" {
helloCmd = append(helloCmd, "AUTH", option.Username, option.Password)
if password != "" && username == "" {
helloCmd = append(helloCmd, "AUTH", "default", password)
} else if username != "" {
helloCmd = append(helloCmd, "AUTH", username, password)
}
if option.ClientName != "" {
helloCmd = append(helloCmd, "SETNAME", option.ClientName)
Expand Down Expand Up @@ -244,10 +259,10 @@ func _newPipe(connFn func() (net.Conn, error), option *ClientOption, r2ps bool)
return nil, ErrNoCache
}
init = init[:0]
if option.Password != "" && option.Username == "" {
init = append(init, []string{"AUTH", option.Password})
} else if option.Username != "" {
init = append(init, []string{"AUTH", option.Username, option.Password})
if password != "" && username == "" {
init = append(init, []string{"AUTH", password})
} else if username != "" {
init = append(init, []string{"AUTH", username, password})
}
if option.ClientName != "" {
init = append(init, []string{"CLIENT", "SETNAME", option.ClientName})
Expand Down
57 changes: 57 additions & 0 deletions pipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"runtime"
Expand Down Expand Up @@ -294,6 +295,44 @@ func TestNewPipe(t *testing.T) {
n1.Close()
n2.Close()
})
t.Run("Auth with Credentials Function", func(t *testing.T) {
n1, n2 := net.Pipe()
mock := &redisMock{buf: bufio.NewReader(n2), conn: n2, t: t}
go func() {
mock.Expect("HELLO", "3", "AUTH", "ua", "pa", "SETNAME", "cn").
Reply(RedisMessage{
typ: '%',
values: []RedisMessage{
{typ: '+', string: "proto"},
{typ: ':', integer: 3},
},
})
mock.Expect("CLIENT", "TRACKING", "ON", "OPTIN").
ReplyString("OK")
mock.Expect("SELECT", "1").
ReplyString("OK")
mock.Expect("CLIENT", "SETINFO", "LIB-NAME", LIB_NAME, "LIB-VER", LIB_VER).
ReplyError("UNKNOWN COMMAND")
}()
p, err := newPipe(func() (net.Conn, error) { return n1, nil }, &ClientOption{
SelectDB: 1,
AuthCredentialsFn: func(context AuthCredentialsContext) (AuthCredentials, error) {
return AuthCredentials{
Username: "ua",
Password: "pa",
}, nil
},
ClientName: "cn",
})
if err != nil {
t.Fatalf("pipe setup failed: %v", err)
}
go func() { mock.Expect("QUIT").ReplyString("OK") }()
p.Close()
mock.Close()
n1.Close()
n2.Close()
})
t.Run("With ClientSideTrackingOptions", func(t *testing.T) {
n1, n2 := net.Pipe()
mock := &redisMock{buf: bufio.NewReader(n2), conn: n2, t: t}
Expand Down Expand Up @@ -405,6 +444,24 @@ func TestNewPipe(t *testing.T) {
t.Fatalf("pipe setup should failed with io.ErrClosedPipe, but got %v", err)
}
})
t.Run("Auth Credentials Function Error", func(t *testing.T) {
n1, n2 := net.Pipe()
mock := &redisMock{buf: bufio.NewReader(n2), conn: n2, t: t}
go func() { mock.Expect("QUIT").ReplyString("OK") }()
_, err := newPipe(func() (net.Conn, error) { return n1, nil }, &ClientOption{
SelectDB: 1,
AuthCredentialsFn: func(context AuthCredentialsContext) (AuthCredentials, error) {
return AuthCredentials{}, fmt.Errorf("auth credential failure")
},
ClientName: "cn",
})
if err.Error() != "auth credential failure" {
t.Fatalf("pipe setup failed: %v", err)
}
mock.Close()
n1.Close()
n2.Close()
})
}

func TestNewRESP2Pipe(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions rueidis.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ type ClientOption struct {
Password string
ClientName string

// AuthCredentialsFn allows for setting the AUTH username and password dynamically on each connection attempt to
// support rotating credentials
AuthCredentialsFn func(AuthCredentialsContext) (AuthCredentials, error)

// ClientSetInfo will assign various info attributes to the current connection
ClientSetInfo []string

Expand Down Expand Up @@ -264,6 +268,17 @@ type CacheableTTL struct {
TTL time.Duration
}

// AuthCredentialsContext is the parameter container of AuthCredentialsFn
type AuthCredentialsContext struct {
Address net.Addr
}

// AuthCredentials is the output of AuthCredentialsFn
type AuthCredentials struct {
Username string
Password string
}

// NewClient uses ClientOption to initialize the Client for both cluster client and single client.
// It will first try to connect as cluster client. If the len(ClientOption.InitAddress) == 1 and
// the address does not enable cluster mode, the NewClient() will use single client instead.
Expand Down

0 comments on commit 0608d8b

Please sign in to comment.