diff --git a/broker.go b/broker.go index 4b3ae84d2..2d75a8cd3 100644 --- a/broker.go +++ b/broker.go @@ -1167,7 +1167,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int return res.Err } - DebugLogger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms) + DebugLogger.Print("Completed pre-auth SASL handshake. Available mechanisms: ", res.EnabledMechanisms) return nil } @@ -1268,7 +1268,9 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { // With v1 sasl we get an error message set in the response we can return if err != nil { - Logger.Printf("Error returned from broker during SASL flow %s: %s\n", b.addr, err.Error()) + Logger.Printf( + "Error returned from broker %s during SASL authentication: %v\n", + b.addr, err.Error()) return err } @@ -1579,7 +1581,11 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl } if !errors.Is(res.Err, ErrNoError) { - return bytesRead, res.Err + var err error = res.Err + if res.ErrorMessage != nil { + err = Wrap(res.Err, errors.New(*res.ErrorMessage)) + } + return bytesRead, err } return bytesRead, nil diff --git a/errors.go b/errors.go index ba27d38b3..507002bfa 100644 --- a/errors.go +++ b/errors.go @@ -3,6 +3,7 @@ package sarama import ( "errors" "fmt" + "strings" "github.com/hashicorp/go-multierror" ) @@ -63,8 +64,23 @@ var ErrReassignPartitions = errors.New("failed to reassign partitions for topic" // ErrDeleteRecords is the type of error returned when fail to delete the required records var ErrDeleteRecords = errors.New("kafka server: failed to delete records") -// The formatter used to format multierrors -var MultiErrorFormat multierror.ErrorFormatFunc +// MultiErrorFormat specifies the formatter applied to format multierrors. The +// default implementation is a consensed version of the hashicorp/go-multierror +// default one +var MultiErrorFormat multierror.ErrorFormatFunc = func(es []error) string { + if len(es) == 1 { + return es[0].Error() + } + + points := make([]string, len(es)) + for i, err := range es { + points[i] = fmt.Sprintf("* %s", err) + } + + return fmt.Sprintf( + "%d errors occurred:\n\t%s\n", + len(es), strings.Join(points, "\n\t")) +} type sentinelError struct { sentinel error diff --git a/errors_test.go b/errors_test.go index d10649b80..4efdd9d4e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -7,12 +7,12 @@ import ( "testing" ) -func TestSentinelWithWrappedError(t *testing.T) { +func TestSentinelWithSingleWrappedError(t *testing.T) { t.Parallel() myNetError := &net.OpError{Op: "mock", Err: errors.New("op error")} error := Wrap(ErrOutOfBrokers, myNetError) - expected := fmt.Sprintf("%s: 1 error occurred:\n\t* %s\n\n", ErrOutOfBrokers, myNetError) + expected := fmt.Sprintf("%s: %s", ErrOutOfBrokers, myNetError) actual := error.Error() if actual != expected { t.Errorf("unexpected value '%s' vs '%v'", expected, actual)