Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SASL SCRAM-SHA-512 and SCRAM-SHA-256 mechanismes #1295

Merged
merged 1 commit into from
Mar 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 118 additions & 5 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ const (
SASLTypeOAuth = "OAUTHBEARER"
// SASLTypePlaintext represents the SASL/PLAIN mechanism
SASLTypePlaintext = "PLAIN"
// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
// server negotiate SASL auth using opaque packets.
SASLHandshakeV0 = int16(0)
Expand Down Expand Up @@ -92,6 +96,20 @@ type AccessTokenProvider interface {
Token() (*AccessToken, error)
}

// SCRAMClient is a an interface to a SCRAM
// client implementation.
type SCRAMClient interface {
// Begin prepares the client for the SCRAM exchange
// with the server with a user name and a password
Begin(userName, password, authzID string) error
// Step steps client through the SCRAM exchange. It is
// called repeatedly until it errors or `Done` returns true.
Step(challenge string) (response string, err error)
// Done should return true when the SCRAM conversation
// is over.
Done() bool
}

type responsePromise struct {
requestTime time.Time
correlationID int32
Expand Down Expand Up @@ -793,14 +811,19 @@ func (b *Broker) responseReceiver() {
}

func (b *Broker) authenticateViaSASL() error {
if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
switch b.conf.Net.SASL.Mechanism {
case SASLTypeOAuth:
return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
return b.sendAndReceiveSASLSCRAMv1()
default:
return b.sendAndReceiveSASLPlainAuth()
}
return b.sendAndReceiveSASLPlainAuth()

}

func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}

req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
Expand Down Expand Up @@ -846,7 +869,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) err
Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
return res.Err
}
Logger.Print("Successful SASL handshake")
Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
return nil
}

Expand Down Expand Up @@ -949,6 +972,96 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
return nil
}

func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
return err
}

scramClient := b.conf.Net.SASL.SCRAMClient
if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
}

msg, err := scramClient.Step("")
if err != nil {
return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())

}

for !scramClient.Done() {
requestTime := time.Now()
correlationID := b.correlationID
bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
if err != nil {
Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
return err
}

b.updateOutgoingCommunicationMetrics(bytesWritten)
b.correlationID++
challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
if err != nil {
Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
return err
}

b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime))
msg, err = scramClient.Step(string(challenge))
if err != nil {
Logger.Println("SASL authentication failed", err)
return err
}
}
Logger.Println("SASL authentication succeeded")
return nil
}

func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
rb := &SaslAuthenticateRequest{msg}
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
}
if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
return 0, err
}
return b.conn.Write(buf)
}

func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
bytesRead, err := io.ReadFull(b.conn, buf)
if err != nil {
return nil, err
}
header := responseHeader{}
err = decode(buf, &header)
if err != nil {
return nil, err
}
if header.correlationID != correlationID {
return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
}
buf = make([]byte, header.length-correlationIDSize)
c, err := io.ReadFull(b.conn, buf)
bytesRead += c
if err != nil {
return nil, err
}
res := &SaslAuthenticateResponse{}
if err := versionedDecode(buf, res, 0); err != nil {
return nil, err
}
if err != nil {
return nil, err
}
if res.Err != ErrNoError {
return nil, res.Err
}
return res.SaslAuthBytes, nil
}

// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
// https://tools.ietf.org/html/rfc7628
func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
Expand Down
141 changes: 135 additions & 6 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,12 @@ func TestSASLOAuthBearer(t *testing.T) {
// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`))

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
if test.mockAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
}

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth})

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}
Expand Down Expand Up @@ -248,6 +244,139 @@ func TestSASLOAuthBearer(t *testing.T) {
}
}

// A mock scram client.
type MockSCRAMClient struct {
done bool
}

func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) {
return nil
}

func (m *MockSCRAMClient) Step(challenge string) (response string, err error) {
if challenge == "" {
return "ping", nil
}
if challenge == "pong" {
m.done = true
return "", nil
}
return "", errors.New("failed to authenticate :(")
}

func (m *MockSCRAMClient) Done() bool {
return m.done
}

var _ SCRAMClient = &MockSCRAMClient{}

func TestSASLSCRAMSHAXXX(t *testing.T) {
testTable := []struct {
name string
mockHandshakeErr KError
mockSASLAuthErr KError
expectClientErr bool
scramClient *MockSCRAMClient
scramChallengeResp string
}{
{
name: "SASL/SCRAMSHAXXX successfull authentication",
mockHandshakeErr: ErrNoError,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "pong",
},
{
name: "SASL/SCRAMSHAXXX SCRAM client step error client",
mockHandshakeErr: ErrNoError,
mockSASLAuthErr: ErrNoError,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "gong",
expectClientErr: true,
},
{
name: "SASL/SCRAMSHAXXX server authentication error",
mockHandshakeErr: ErrNoError,
mockSASLAuthErr: ErrSASLAuthenticationFailed,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "pong",
},
{
name: "SASL/SCRAMSHAXXX unsupported SCRAM mechanism",
mockHandshakeErr: ErrUnsupportedSASLMechanism,
mockSASLAuthErr: ErrNoError,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "pong",
},
}

for i, test := range testTable {

// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)
broker := NewBroker(mockBroker.Addr())
// broker executes SASL requests against mockBroker
broker.requestRate = metrics.NilMeter{}
broker.outgoingByteRate = metrics.NilMeter{}
broker.incomingByteRate = metrics.NilMeter{}
broker.requestSize = metrics.NilHistogram{}
broker.responseSize = metrics.NilHistogram{}
broker.responseRate = metrics.NilMeter{}
broker.requestLatency = metrics.NilHistogram{}

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp))
mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512})

if test.mockSASLAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr)
}
if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
})

conf := NewConfig()
conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
conf.Net.SASL.SCRAMClient = test.scramClient

broker.conf = conf
dialer := net.Dialer{
Timeout: conf.Net.DialTimeout,
KeepAlive: conf.Net.KeepAlive,
LocalAddr: conf.Net.LocalAddr,
}

conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())

if err != nil {
t.Fatal(err)
}

broker.conn = conn

err = broker.authenticateViaSASL()

if test.mockSASLAuthErr != ErrNoError {
if test.mockSASLAuthErr != err {
t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err)
}
} else if test.mockHandshakeErr != ErrNoError {
if test.mockHandshakeErr != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
}
} else if test.expectClientErr && err == nil {
t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
} else if !test.expectClientErr && err != nil {
t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
}

mockBroker.Close()
}
}

func TestBuildClientInitialResponse(t *testing.T) {

testTable := []struct {
Expand Down
36 changes: 27 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,14 @@ type Config struct {
// (defaults to true). You should only set this to false if you're using
// a non-Kafka SASL proxy.
Handshake bool
//username and password for SASL/PLAIN authentication
//username and password for SASL/PLAIN or SASL/SCRAM authentication
User string
Password string
// authz id used for SASL/SCRAM authentication
SCRAMAuthzID string
// SCRAMClient is a user provided implementation of a SCRAM
// client used to perform the SCRAM exchange with the server.
SCRAMClient SCRAMClient
// TokenProvider is a user-defined callback for generating
// access tokens for SASL/OAUTHBEARER auth. See the
// AccessTokenProvider interface docs for proper implementation
Expand Down Expand Up @@ -475,22 +480,35 @@ func (c *Config) Validate() error {
case c.Net.KeepAlive < 0:
return ConfigurationError("Net.KeepAlive must be >= 0")
case c.Net.SASL.Enable:
// For backwards compatibility, empty mechanism value defaults to PLAIN
isSASLPlain := len(c.Net.SASL.Mechanism) == 0 || c.Net.SASL.Mechanism == SASLTypePlaintext
if isSASLPlain {
if c.Net.SASL.Mechanism == "" {
c.Net.SASL.Mechanism = SASLTypePlaintext
}

switch c.Net.SASL.Mechanism {
case SASLTypePlaintext:
if c.Net.SASL.User == "" {
return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
}
if c.Net.SASL.Password == "" {
return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
}
} else if c.Net.SASL.Mechanism == SASLTypeOAuth {
case SASLTypeOAuth:
if c.Net.SASL.TokenProvider == nil {
return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider")
return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider")
}
case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
if c.Net.SASL.User == "" {
return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
}
if c.Net.SASL.Password == "" {
return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
}
if c.Net.SASL.SCRAMClient == nil {
return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient")
}
} else {
msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s` and `%s`",
SASLTypeOAuth, SASLTypePlaintext)
default:
msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`",
SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512)
return ConfigurationError(msg)
}
}
Expand Down
Loading