diff --git a/pkg/scalers/rabbitmq_scaler.go b/pkg/scalers/rabbitmq_scaler.go index 5b43aa67049..32a059bf41d 100644 --- a/pkg/scalers/rabbitmq_scaler.go +++ b/pkg/scalers/rabbitmq_scaler.go @@ -2,8 +2,13 @@ package scalers import ( "context" + "encoding/json" "fmt" + "io/ioutil" + "net/http" + "net/url" "strconv" + "time" "github.com/streadway/amqp" v2beta2 "k8s.io/api/autoscaling/v2beta2" @@ -17,6 +22,8 @@ import ( const ( rabbitQueueLengthMetricName = "queueLength" rabbitMetricType = "External" + rabbitIncludeUnacked = "includeUnacked" + defaultIncludeUnacked = false ) type rabbitMQScaler struct { @@ -26,9 +33,17 @@ type rabbitMQScaler struct { } type rabbitMQMetadata struct { - queueName string - host string - queueLength int + queueName string + host string // connection string for AMQP protocol + apiHost string // connection string for management API requests + queueLength int + includeUnacked bool // if true uses HTTP API and requires apiHost, if false uses AMQP and requires host +} + +type queueInfo struct { + Messages int `json:"messages"` + MessagesUnacknowledged int `json:"messages_unacknowledged"` + Name string `json:"name"` } var rabbitmqLog = logf.Log.WithName("rabbitmq_scaler") @@ -40,33 +55,62 @@ func NewRabbitMQScaler(resolvedEnv, metadata, authParams map[string]string) (Sca return nil, fmt.Errorf("error parsing rabbitmq metadata: %s", err) } - conn, ch, err := getConnectionAndChannel(meta.host) - if err != nil { - return nil, fmt.Errorf("error establishing rabbitmq connection: %s", err) - } + if meta.includeUnacked { + return &rabbitMQScaler{metadata: meta}, nil + } else { + conn, ch, err := getConnectionAndChannel(meta.host) + if err != nil { + return nil, fmt.Errorf("error establishing rabbitmq connection: %s", err) + } - return &rabbitMQScaler{ - metadata: meta, - connection: conn, - channel: ch, - }, nil + return &rabbitMQScaler{ + metadata: meta, + connection: conn, + channel: ch, + }, nil + } } func parseRabbitMQMetadata(resolvedEnv, metadata, authParams map[string]string) (*rabbitMQMetadata, error) { meta := rabbitMQMetadata{} - if val, ok := authParams["host"]; ok { - meta.host = val - } else if val, ok := metadata["host"]; ok { - hostSetting := val + meta.includeUnacked = defaultIncludeUnacked + if val, ok := metadata[rabbitIncludeUnacked]; ok { + includeUnacked, err := strconv.ParseBool(val) + if err != nil { + return nil, fmt.Errorf("includeUnacked parsing error %s", err.Error()) + } + meta.includeUnacked = includeUnacked + } + + if meta.includeUnacked { + if val, ok := authParams["apiHost"]; ok { + meta.apiHost = val + } else if val, ok := metadata["apiHost"]; ok { + hostSetting := val - if val, ok := resolvedEnv[hostSetting]; ok { + if val, ok := resolvedEnv[hostSetting]; ok { + meta.apiHost = val + } + } + + if meta.apiHost == "" { + return nil, fmt.Errorf("no apiHost setting given") + } + } else { + if val, ok := authParams["host"]; ok { meta.host = val + } else if val, ok := metadata["host"]; ok { + hostSetting := val + + if val, ok := resolvedEnv[hostSetting]; ok { + meta.host = val + } } - } - if meta.host == "" { - return nil, fmt.Errorf("no host setting given") + if meta.host == "" { + return nil, fmt.Errorf("no host setting given") + } } if val, ok := metadata["queueName"]; ok { @@ -105,10 +149,12 @@ func getConnectionAndChannel(host string) (*amqp.Connection, *amqp.Channel, erro // Close disposes of RabbitMQ connections func (s *rabbitMQScaler) Close() error { - err := s.connection.Close() - if err != nil { - rabbitmqLog.Error(err, "Error closing rabbitmq connection") - return err + if s.connection != nil { + err := s.connection.Close() + if err != nil { + rabbitmqLog.Error(err, "Error closing rabbitmq connection") + return err + } } return nil } @@ -124,12 +170,59 @@ func (s *rabbitMQScaler) IsActive(ctx context.Context) (bool, error) { } func (s *rabbitMQScaler) getQueueMessages() (int, error) { - items, err := s.channel.QueueInspect(s.metadata.queueName) + if s.metadata.includeUnacked { + info, err := s.getQueueInfoViaHttp() + if err != nil { + return -1, err + } else { + return info.Messages + info.MessagesUnacknowledged, nil + } + } else { + items, err := s.channel.QueueInspect(s.metadata.queueName) + if err != nil { + return -1, err + } else { + return items.Messages, nil + } + } +} + +func getJson(url string, target interface{}) error { + var client = &http.Client{Timeout: 5 * time.Second} + r, err := client.Get(url) if err != nil { - return -1, err + return err } + defer r.Body.Close() - return items.Messages, nil + if r.StatusCode == 200 { + return json.NewDecoder(r.Body).Decode(target) + } else { + body, _ := ioutil.ReadAll(r.Body) + return fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url) + } +} + +func (s *rabbitMQScaler) getQueueInfoViaHttp() (*queueInfo, error) { + parsedUrl, err := url.Parse(s.metadata.apiHost) + + if err != nil { + return nil, err + } + + vhost := parsedUrl.Path + parsedUrl.Path = "" + + getQueueInfoManagementURI := fmt.Sprintf("%s/%s%s/%s", parsedUrl.String(), "api/queues", vhost, s.metadata.queueName) + + info := queueInfo{} + err = getJson(getQueueInfoManagementURI, &info) + + if err != nil { + return nil, err + } else { + return &info, nil + } } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler diff --git a/pkg/scalers/rabbitmq_scaler_test.go b/pkg/scalers/rabbitmq_scaler_test.go index 157f9c9cecb..0a05a6751a7 100644 --- a/pkg/scalers/rabbitmq_scaler_test.go +++ b/pkg/scalers/rabbitmq_scaler_test.go @@ -1,11 +1,17 @@ package scalers import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" "testing" ) const ( - host = "myHostSecret" + host = "myHostSecret" + apiHost = "myApiHostSecret" ) type parseRabbitMQMetadataTestData struct { @@ -15,7 +21,8 @@ type parseRabbitMQMetadataTestData struct { } var sampleRabbitMqResolvedEnv = map[string]string{ - host: "none", + host: "amqp://user:sercet@somehost.com:5236/vhost", + apiHost: "https://user:secret@somehost.com/vhost", } var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{ @@ -31,6 +38,8 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{ {map[string]string{"queueLength": "10", "host": host}, true, map[string]string{}}, // host defined in authParams {map[string]string{"queueLength": "10"}, true, map[string]string{"host": host}}, + // properly formed metadata with includeUnacked + {map[string]string{"queueLength": "10", "queueName": "sample", "apiHost": apiHost, "includeUnacked": "true"}, false, map[string]string{}}, } func TestRabbitMQParseMetadata(t *testing.T) { @@ -44,3 +53,67 @@ func TestRabbitMQParseMetadata(t *testing.T) { } } } + +type getQueueInfoTestData struct { + response string + responseStatus int + isActive bool +} + +var testQueueInfoTestData = []getQueueInfoTestData{ + {`{"messages": 4, "messages_unacknowledged": 1, "name": "evaluate_trials"}`, http.StatusOK, true}, + {`{"messages": 0, "messages_unacknowledged": 1, "name": "evaluate_trials"}`, http.StatusOK, true}, + {`{"messages": 1, "messages_unacknowledged": 0, "name": "evaluate_trials"}`, http.StatusOK, true}, + {`{"messages": 0, "messages_unacknowledged": 0, "name": "evaluate_trials"}`, http.StatusOK, false}, + {`Password is incorrect`, http.StatusUnauthorized, false}, +} + +func TestGetQueueInfo(t *testing.T) { + for _, testData := range testQueueInfoTestData { + var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expeced_path := "/api/queues/myhost/evaluate_trials" + if r.RequestURI != expeced_path { + t.Error("Expect request path to =", expeced_path, "but it is", r.RequestURI) + } + + w.WriteHeader(testData.responseStatus) + w.Write([]byte(testData.response)) + })) + + resolvedEnv := map[string]string{apiHost: fmt.Sprintf("%s/%s", apiStub.URL, "myhost")} + + metadata := map[string]string{ + "queueLength": "10", + "queueName": "evaluate_trials", + "apiHost": apiHost, + "includeUnacked": "true", + } + + s, err := NewRabbitMQScaler(resolvedEnv, metadata, map[string]string{}) + + if err != nil { + t.Error("Expect success", err) + } + + ctx := context.TODO() + active, err := s.IsActive(ctx) + + if testData.responseStatus == http.StatusOK { + if err != nil { + t.Error("Expect success", err) + } + + if active != testData.isActive { + if testData.isActive { + t.Error("Expect to be active") + } else { + t.Error("Expect to not be active") + } + } + } else { + if !strings.Contains(err.Error(), testData.response) { + t.Error("Expect error to be like '", testData.response, "' but it's '", err, "'") + } + } + } +}