Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend RabbitMQ scaler to support count unacked messages #700

Merged
merged 2 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 120 additions & 27 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package scalers

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"time"

"github.com/streadway/amqp"
v2beta1 "k8s.io/api/autoscaling/v2beta1"
Expand All @@ -17,6 +22,8 @@ import (
const (
rabbitQueueLengthMetricName = "queueLength"
rabbitMetricType = "External"
rabbitIncludeUnacked = "includeUnacked"
defaultIncludeUnacked = false
)

type rabbitMQScaler struct {
Expand All @@ -26,9 +33,17 @@ type rabbitMQScaler struct {
}

type rabbitMQMetadata struct {
queueName string
host string
queueLength int
queueName string
host string // connection string for AMQP protocol
apiHost string // connection string for management API requests
queueLength int
includeUnacked bool // if true uses HTTP API and requires apiHost, if false uses AMQP and requires host
}

type queueInfo struct {
Messages int `json:"messages"`
MessagesUnacknowledged int `json:"messages_unacknowledged"`
Name string `json:"name"`
}

var rabbitmqLog = logf.Log.WithName("rabbitmq_scaler")
Expand All @@ -40,33 +55,62 @@ func NewRabbitMQScaler(resolvedEnv, metadata, authParams map[string]string) (Sca
return nil, fmt.Errorf("error parsing rabbitmq metadata: %s", err)
}

conn, ch, err := getConnectionAndChannel(meta.host)
if err != nil {
return nil, fmt.Errorf("error establishing rabbitmq connection: %s", err)
}
if meta.includeUnacked {
return &rabbitMQScaler{metadata: meta}, nil
} else {
conn, ch, err := getConnectionAndChannel(meta.host)
if err != nil {
return nil, fmt.Errorf("error establishing rabbitmq connection: %s", err)
}

return &rabbitMQScaler{
metadata: meta,
connection: conn,
channel: ch,
}, nil
return &rabbitMQScaler{
metadata: meta,
connection: conn,
channel: ch,
}, nil
}
}

func parseRabbitMQMetadata(resolvedEnv, metadata, authParams map[string]string) (*rabbitMQMetadata, error) {
meta := rabbitMQMetadata{}

if val, ok := authParams["host"]; ok {
meta.host = val
} else if val, ok := metadata["host"]; ok {
hostSetting := val
meta.includeUnacked = defaultIncludeUnacked
if val, ok := metadata[rabbitIncludeUnacked]; ok {
includeUnacked, err := strconv.ParseBool(val)
if err != nil {
return nil, fmt.Errorf("includeUnacked parsing error %s", err.Error())
}
meta.includeUnacked = includeUnacked
}

if meta.includeUnacked {
if val, ok := authParams["apiHost"]; ok {
meta.apiHost = val
} else if val, ok := metadata["apiHost"]; ok {
hostSetting := val

if val, ok := resolvedEnv[hostSetting]; ok {
if val, ok := resolvedEnv[hostSetting]; ok {
meta.apiHost = val
}
}

if meta.apiHost == "" {
return nil, fmt.Errorf("no apiHost setting given")
}
} else {
if val, ok := authParams["host"]; ok {
meta.host = val
} else if val, ok := metadata["host"]; ok {
hostSetting := val

if val, ok := resolvedEnv[hostSetting]; ok {
meta.host = val
}
}
}

if meta.host == "" {
return nil, fmt.Errorf("no host setting given")
if meta.host == "" {
return nil, fmt.Errorf("no host setting given")
}
}

if val, ok := metadata["queueName"]; ok {
Expand Down Expand Up @@ -105,10 +149,12 @@ func getConnectionAndChannel(host string) (*amqp.Connection, *amqp.Channel, erro

// Close disposes of RabbitMQ connections
func (s *rabbitMQScaler) Close() error {
err := s.connection.Close()
if err != nil {
rabbitmqLog.Error(err, "Error closing rabbitmq connection")
return err
if s.connection != nil {
err := s.connection.Close()
if err != nil {
rabbitmqLog.Error(err, "Error closing rabbitmq connection")
return err
}
}
return nil
}
Expand All @@ -124,12 +170,59 @@ func (s *rabbitMQScaler) IsActive(ctx context.Context) (bool, error) {
}

func (s *rabbitMQScaler) getQueueMessages() (int, error) {
items, err := s.channel.QueueInspect(s.metadata.queueName)
if s.metadata.includeUnacked {
info, err := s.getQueueInfoViaHttp()
if err != nil {
return -1, err
} else {
return info.Messages + info.MessagesUnacknowledged, nil
}
} else {
items, err := s.channel.QueueInspect(s.metadata.queueName)
if err != nil {
return -1, err
} else {
return items.Messages, nil
}
}
}

func getJson(url string, target interface{}) error {
var client = &http.Client{Timeout: 5 * time.Second}
r, err := client.Get(url)
if err != nil {
return -1, err
return err
}
defer r.Body.Close()

return items.Messages, nil
if r.StatusCode == 200 {
return json.NewDecoder(r.Body).Decode(target)
} else {
body, _ := ioutil.ReadAll(r.Body)
return fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url)
}
}

func (s *rabbitMQScaler) getQueueInfoViaHttp() (*queueInfo, error) {
parsedUrl, err := url.Parse(s.metadata.apiHost)

if err != nil {
return nil, err
}

vhost := parsedUrl.Path
parsedUrl.Path = ""

getQueueInfoManagementURI := fmt.Sprintf("%s/%s%s/%s", parsedUrl.String(), "api/queues", vhost, s.metadata.queueName)

info := queueInfo{}
err = getJson(getQueueInfoManagementURI, &info)

if err != nil {
return nil, err
} else {
return &info, nil
}
}

// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler
Expand Down
77 changes: 75 additions & 2 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package scalers

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

const (
host = "myHostSecret"
host = "myHostSecret"
apiHost = "myApiHostSecret"
)

type parseRabbitMQMetadataTestData struct {
Expand All @@ -15,7 +21,8 @@ type parseRabbitMQMetadataTestData struct {
}

var sampleRabbitMqResolvedEnv = map[string]string{
host: "none",
host: "amqp://user:sercet@somehost.com:5236/vhost",
apiHost: "https://user:secret@somehost.com/vhost",
}

var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
Expand All @@ -31,6 +38,8 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
{map[string]string{"queueLength": "10", "host": host}, true, map[string]string{}},
// host defined in authParams
{map[string]string{"queueLength": "10"}, true, map[string]string{"host": host}},
// properly formed metadata with includeUnacked
{map[string]string{"queueLength": "10", "queueName": "sample", "apiHost": apiHost, "includeUnacked": "true"}, false, map[string]string{}},
}

func TestRabbitMQParseMetadata(t *testing.T) {
Expand All @@ -44,3 +53,67 @@ func TestRabbitMQParseMetadata(t *testing.T) {
}
}
}

type getQueueInfoTestData struct {
response string
responseStatus int
isActive bool
}

var testQueueInfoTestData = []getQueueInfoTestData{
{`{"messages": 4, "messages_unacknowledged": 1, "name": "evaluate_trials"}`, http.StatusOK, true},
{`{"messages": 0, "messages_unacknowledged": 1, "name": "evaluate_trials"}`, http.StatusOK, true},
{`{"messages": 1, "messages_unacknowledged": 0, "name": "evaluate_trials"}`, http.StatusOK, true},
{`{"messages": 0, "messages_unacknowledged": 0, "name": "evaluate_trials"}`, http.StatusOK, false},
{`Password is incorrect`, http.StatusUnauthorized, false},
}

func TestGetQueueInfo(t *testing.T) {
for _, testData := range testQueueInfoTestData {
var apiStub = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expeced_path := "/api/queues/myhost/evaluate_trials"
if r.RequestURI != expeced_path {
t.Error("Expect request path to =", expeced_path, "but it is", r.RequestURI)
}

w.WriteHeader(testData.responseStatus)
w.Write([]byte(testData.response))
}))

resolvedEnv := map[string]string{apiHost: fmt.Sprintf("%s/%s", apiStub.URL, "myhost")}

metadata := map[string]string{
"queueLength": "10",
"queueName": "evaluate_trials",
"apiHost": apiHost,
"includeUnacked": "true",
}

s, err := NewRabbitMQScaler(resolvedEnv, metadata, map[string]string{})

if err != nil {
t.Error("Expect success", err)
}

ctx := context.TODO()
active, err := s.IsActive(ctx)

if testData.responseStatus == http.StatusOK {
if err != nil {
t.Error("Expect success", err)
}

if active != testData.isActive {
if testData.isActive {
t.Error("Expect to be active")
} else {
t.Error("Expect to not be active")
}
}
} else {
if !strings.Contains(err.Error(), testData.response) {
t.Error("Expect error to be like '", testData.response, "' but it's '", err, "'")
}
}
}
}