diff --git a/broker.go b/broker.go index dd01e4ef1..d30840b1e 100644 --- a/broker.go +++ b/broker.go @@ -941,7 +941,7 @@ func (b *Broker) authenticateViaSASL() error { case SASLTypeOAuth: return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider) case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512: - return b.sendAndReceiveSASLSCRAMv1() + return b.sendAndReceiveSASLSCRAM() case SASLTypeGSSAPI: return b.sendAndReceiveKerberos() default: @@ -1180,6 +1180,70 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) { return isChallenge, err } +func (b *Broker) sendAndReceiveSASLSCRAM() error { + if b.conf.Net.SASL.Version == SASLHandshakeV0 { + return b.sendAndReceiveSASLSCRAMv0() + } + return b.sendAndReceiveSASLSCRAMv1() +} + +func (b *Broker) sendAndReceiveSASLSCRAMv0() error { + if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV0); err != nil { + return err + } + + scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc() + if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil { + return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error()) + } + + msg, err := scramClient.Step("") + if err != nil { + return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error()) + } + + for !scramClient.Done() { + requestTime := time.Now() + // Will be decremented in updateIncomingCommunicationMetrics (except error) + b.addRequestInFlightMetrics(1) + length := len(msg) + authBytes := make([]byte, length+4) //4 byte length header + auth data + binary.BigEndian.PutUint32(authBytes, uint32(length)) + copy(authBytes[4:], []byte(msg)) + _, err := b.write(authBytes) + b.updateOutgoingCommunicationMetrics(length + 4) + if err != nil { + b.addRequestInFlightMetrics(-1) + Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error()) + return err + } + b.correlationID++ + header := make([]byte, 4) + _, err = b.readFull(header) + if err != nil { + b.addRequestInFlightMetrics(-1) + Logger.Printf("Failed to read response header while authenticating with SASL to broker %s: %s\n", b.addr, err.Error()) + return err + } + payload := make([]byte, int32(binary.BigEndian.Uint32(header))) + n, err := b.readFull(payload) + if err != nil { + b.addRequestInFlightMetrics(-1) + Logger.Printf("Failed to read response payload while authenticating with SASL to broker %s: %s\n", b.addr, err.Error()) + return err + } + b.updateIncomingCommunicationMetrics(n+4, time.Since(requestTime)) + msg, err = scramClient.Step(string(payload)) + if err != nil { + Logger.Println("SASL authentication failed", err) + return err + } + } + + Logger.Println("SASL authentication succeeded") + return nil +} + func (b *Broker) sendAndReceiveSASLSCRAMv1() error { if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil { return err diff --git a/broker_test.go b/broker_test.go index 2fa40ceb4..a91e35254 100644 --- a/broker_test.go +++ b/broker_test.go @@ -359,9 +359,11 @@ func TestSASLSCRAMSHAXXX(t *testing.T) { conf := NewTestConfig() conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 + conf.Net.SASL.Version = SASLHandshakeV1 conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient } broker.conf = conf + broker.conf.Version = V1_0_0_0 dialer := net.Dialer{ Timeout: conf.Net.DialTimeout, KeepAlive: conf.Net.KeepAlive,