Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KIP-368 : Allow SASL Connections to Periodically Re-Authenticate #2197

Merged
merged 1 commit into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"sort"
"strconv"
Expand Down Expand Up @@ -52,7 +53,8 @@ type Broker struct {
brokerRequestsInFlight metrics.Counter
brokerThrottleTime metrics.Histogram

kerberosAuthenticator GSSAPIKerberosAuth
kerberosAuthenticator GSSAPIKerberosAuth
clientSessionReauthenticationTimeMs int64
}

// SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker
Expand Down Expand Up @@ -923,6 +925,13 @@ func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) erro
return ErrNotConnected
}

if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs {
err := b.authenticateViaSASL()
if err != nil {
return err
}
}

if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
return ErrUnsupportedVersion
}
Expand Down Expand Up @@ -1263,7 +1272,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

// Will be decremented in updateIncomingCommunicationMetrics (except error)
b.addRequestInFlightMetrics(1)
bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID)
b.updateOutgoingCommunicationMetrics(bytesWritten)

if err != nil {
Expand All @@ -1274,7 +1283,8 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))

// With v1 sasl we get an error message set in the response we can return
Expand All @@ -1288,6 +1298,10 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
return nil
}

func currentUnixMilli() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The project aims at supporting the last couple of Golang version, so we could use time.Now().UnixMilli() that was introduced in 1.17. That said, it may be preferable to keep the current code for now to support older environments.

Copy link
Contributor Author

@k-wall k-wall Apr 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The project is still on go 1.16, so I cannot adopt this API yet. Also, if I understand the project's "two releases and two months" right it'd be too soon to bump Go right now (1.18 only came out last month (March 2022), so it is too soon for the project to adopt 1.17).

}

// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
Expand Down Expand Up @@ -1327,7 +1341,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
b.addRequestInFlightMetrics(1)
correlationID := b.correlationID

bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
b.updateOutgoingCommunicationMetrics(bytesWritten)
if err != nil {
b.addRequestInFlightMetrics(-1)
Expand All @@ -1337,7 +1351,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
b.correlationID++

res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID)
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)

requestLatency := time.Since(requestTime)
b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
Expand Down Expand Up @@ -1464,7 +1478,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
}

func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
rb := &SaslAuthenticateRequest{msg}
rb := b.createSaslAuthenticateRequest(msg)
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
Expand All @@ -1474,6 +1488,15 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i
return b.write(buf)
}

func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest {
authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg}
if b.conf.Version.IsAtLeast(V2_2_0_0) {
authenticateRequest.Version = 1
}

return &authenticateRequest
}

func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
_, err := b.readFull(buf)
Expand Down Expand Up @@ -1538,32 +1561,34 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
return strings.Join(buf, elemSep)
}

func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) {
authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
rb := &SaslAuthenticateRequest{authBytes}
rb := b.createSaslAuthenticateRequest(authBytes)
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
return 0, rb.Version, err
}

return b.write(buf)
write, err := b.write(buf)
return write, rb.Version, err
}

func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
rb := &SaslAuthenticateRequest{initialResp}
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) {
rb := b.createSaslAuthenticateRequest(initialResp)

req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}

buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
return 0, rb.version(), err
}

return b.write(buf)
write, err := b.write(buf)
return write, rb.version(), err
}

func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
bytesRead, err := b.readFull(buf)
if err != nil {
Expand All @@ -1587,7 +1612,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
return bytesRead, err
}

if err := versionedDecode(buf, res, 0); err != nil {
if err := versionedDecode(buf, res, resVersion); err != nil {
return bytesRead, err
}

Expand All @@ -1599,6 +1624,21 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
return bytesRead, err
}

if res.SessionLifetimeMs > 0 {
// Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes
// pick a random percentage between 85% and 95% for session re-authentication
positiveSessionLifetimeMs := res.SessionLifetimeMs
authenticationEndMs := currentUnixMilli()
pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85
pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10
pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously
sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse)
DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse)
b.clientSessionReauthenticationTimeMs = authenticationEndMs + sessionLifetimeMsToUse
} else {
b.clientSessionReauthenticationTimeMs = 0
}

return bytesRead, nil
}

Expand Down
159 changes: 159 additions & 0 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,165 @@ func TestBuildClientFirstMessage(t *testing.T) {
}
}

func TestKip368ReAuthenticationSuccess(t *testing.T) {
sessionLifetimeMs := int64(100)

mockBroker := NewMockBroker(t, 0)

countSaslAuthRequests := func() (count int) {
for _, rr := range mockBroker.History() {
switch rr.Request.(type) {
case *SaslAuthenticateRequest:
count++
}
}
return
}

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`)).
SetSessionLifetimeMs(sessionLifetimeMs)

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

mockApiVersions := NewMockApiVersionsResponse(t)

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

broker := NewBroker(mockBroker.Addr())

conf := NewTestConfig()
conf.Net.SASL.Enable = true
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.AuthIdentity = "authid"
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V2_2_0_0

err := broker.Open(conf)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = broker.Close() })

connected, err := broker.Connected()
if err != nil || !connected {
t.Fatal(err)
}

actualSaslAuthRequests := countSaslAuthRequests()
if actualSaslAuthRequests != 1 {
t.Fatalf("unexpected number of SaslAuthRequests during initial authentication: %d", actualSaslAuthRequests)
}

timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)

loop:
for actualSaslAuthRequests < 2 {
select {
case <-timeout:
break loop
default:
time.Sleep(10 * time.Millisecond)
// put some traffic on the wire
_, err = broker.ApiVersions(&ApiVersionsRequest{})
if err != nil {
t.Fatal(err)
}
actualSaslAuthRequests = countSaslAuthRequests()
}
}

if actualSaslAuthRequests < 2 {
t.Fatalf("sasl reauth has not occurred within expected timeframe")
}

mockBroker.Close()
}

func TestKip368ReAuthenticationFailure(t *testing.T) {
sessionLifetimeMs := int64(100)

mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`)).
SetSessionLifetimeMs(sessionLifetimeMs)

mockSASLAuthErrorResponse := NewMockSaslAuthenticateResponse(t).
SetError(ErrSASLAuthenticationFailed)

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

mockApiVersions := NewMockApiVersionsResponse(t)

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

broker := NewBroker(mockBroker.Addr())

conf := NewTestConfig()
conf.Net.SASL.Enable = true
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.AuthIdentity = "authid"
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V2_2_0_0

err := broker.Open(conf)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = broker.Close() })

connected, err := broker.Connected()
if err != nil || !connected {
t.Fatal(err)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthErrorResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)

var apiVersionError error
loop:
for apiVersionError == nil {
select {
case <-timeout:
break loop
default:
time.Sleep(10 * time.Millisecond)
// put some traffic on the wire
_, apiVersionError = broker.ApiVersions(&ApiVersionsRequest{})
}
}

if !errors.Is(apiVersionError, ErrSASLAuthenticationFailed) {
t.Fatalf("sasl reauth has not failed in the expected way %v", apiVersionError)
}

mockBroker.Close()
}

// We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake
var brokerTestTable = []struct {
version KafkaVersion
Expand Down
15 changes: 12 additions & 3 deletions mockresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,19 +1057,23 @@ func (mr *MockListAclsResponse) For(reqBody versionedDecoder) encoderWithHeader
}

type MockSaslAuthenticateResponse struct {
t TestReporter
kerror KError
saslAuthBytes []byte
t TestReporter
kerror KError
saslAuthBytes []byte
sessionLifetimeMs int64
}

func NewMockSaslAuthenticateResponse(t TestReporter) *MockSaslAuthenticateResponse {
return &MockSaslAuthenticateResponse{t: t}
}

func (msar *MockSaslAuthenticateResponse) For(reqBody versionedDecoder) encoderWithHeader {
req := reqBody.(*SaslAuthenticateRequest)
res := &SaslAuthenticateResponse{}
res.Version = req.Version
res.Err = msar.kerror
res.SaslAuthBytes = msar.saslAuthBytes
res.SessionLifetimeMs = msar.sessionLifetimeMs
return res
}

Expand All @@ -1083,6 +1087,11 @@ func (msar *MockSaslAuthenticateResponse) SetAuthBytes(saslAuthBytes []byte) *Mo
return msar
}

func (msar *MockSaslAuthenticateResponse) SetSessionLifetimeMs(sessionLifetimeMs int64) *MockSaslAuthenticateResponse {
msar.sessionLifetimeMs = sessionLifetimeMs
return msar
}

type MockDeleteAclsResponse struct {
t TestReporter
}
Expand Down
Loading