diff --git a/CHANGELOG.md b/CHANGELOG.md index c27d049acf8..6451ee44b6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,7 @@ - Improve metric name creation to be unique using scaler index inside the scaler ([#2161](https://github.com/kedacore/keda/pull/2161)) - Improve error message if `IdleReplicaCount` are equal to `MinReplicaCount` to be the same as the check ([#2212](https://github.com/kedacore/keda/pull/2212)) - Improve Cloudwatch Scaler metric exporting logic ([#2243](https://github.com/kedacore/keda/pull/2243)) +- Refactor aws related scalers to reuse the aws clients instead of creating a new one for every GetMetrics call([#2255](https://github.com/kedacore/keda/pull/2255)) ### Breaking Changes diff --git a/pkg/scalers/aws_cloudwatch_scaler.go b/pkg/scalers/aws_cloudwatch_scaler.go index d8b3100d784..50eff5a9c30 100644 --- a/pkg/scalers/aws_cloudwatch_scaler.go +++ b/pkg/scalers/aws_cloudwatch_scaler.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/cloudwatch" + "github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface" "k8s.io/api/autoscaling/v2beta2" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -32,6 +33,7 @@ const ( type awsCloudwatchScaler struct { metadata *awsCloudwatchMetadata + cwClient cloudwatchiface.CloudWatchAPI } type awsCloudwatchMetadata struct { @@ -67,6 +69,7 @@ func NewAwsCloudwatchScaler(config *ScalerConfig) (Scaler, error) { return &awsCloudwatchScaler{ metadata: meta, + cwClient: createCloudwatchClient(meta), }, nil } @@ -102,6 +105,32 @@ func getFloatMetadataValue(metadata map[string]string, key string, required bool return defaultValue, nil } +func createCloudwatchClient(metadata *awsCloudwatchMetadata) *cloudwatch.CloudWatch { + sess := session.Must(session.NewSession(&aws.Config{ + Region: aws.String(metadata.awsRegion), + })) + + var cloudwatchClient *cloudwatch.CloudWatch + if metadata.awsAuthorization.podIdentityOwner { + creds := credentials.NewStaticCredentials(metadata.awsAuthorization.awsAccessKeyID, metadata.awsAuthorization.awsSecretAccessKey, "") + + if metadata.awsAuthorization.awsRoleArn != "" { + creds = stscreds.NewCredentials(sess, metadata.awsAuthorization.awsRoleArn) + } + + cloudwatchClient = cloudwatch.New(sess, &aws.Config{ + Region: aws.String(metadata.awsRegion), + Credentials: creds, + }) + } else { + cloudwatchClient = cloudwatch.New(sess, &aws.Config{ + Region: aws.String(metadata.awsRegion), + }) + } + + return cloudwatchClient +} + func parseAwsCloudwatchMetadata(config *ScalerConfig) (*awsCloudwatchMetadata, error) { var err error meta := awsCloudwatchMetadata{} @@ -236,6 +265,12 @@ func checkMetricStatPeriod(period int64) error { return nil } +func computeQueryWindow(current time.Time, metricPeriodSec, metricEndTimeOffsetSec, metricCollectionTimeSec int64) (startTime, endTime time.Time) { + endTime = current.Add(time.Second * -1 * time.Duration(metricEndTimeOffsetSec)).Truncate(time.Duration(metricPeriodSec) * time.Second) + startTime = endTime.Add(time.Second * -1 * time.Duration(metricCollectionTimeSec)) + return +} + func (c *awsCloudwatchScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { metricValue, err := c.GetCloudwatchMetrics() @@ -283,28 +318,6 @@ func (c *awsCloudwatchScaler) Close(context.Context) error { } func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) { - sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(c.metadata.awsRegion), - })) - - var cloudwatchClient *cloudwatch.CloudWatch - if c.metadata.awsAuthorization.podIdentityOwner { - creds := credentials.NewStaticCredentials(c.metadata.awsAuthorization.awsAccessKeyID, c.metadata.awsAuthorization.awsSecretAccessKey, "") - - if c.metadata.awsAuthorization.awsRoleArn != "" { - creds = stscreds.NewCredentials(sess, c.metadata.awsAuthorization.awsRoleArn) - } - - cloudwatchClient = cloudwatch.New(sess, &aws.Config{ - Region: aws.String(c.metadata.awsRegion), - Credentials: creds, - }) - } else { - cloudwatchClient = cloudwatch.New(sess, &aws.Config{ - Region: aws.String(c.metadata.awsRegion), - }) - } - dimensions := []*cloudwatch.Dimension{} for i := range c.metadata.dimensionName { dimensions = append(dimensions, &cloudwatch.Dimension{ @@ -313,8 +326,7 @@ func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) { }) } - endTime := time.Now().Add(time.Second * -1 * time.Duration(c.metadata.metricEndTimeOffset)).Truncate(time.Duration(c.metadata.metricStatPeriod) * time.Second) - startTime := endTime.Add(time.Second * -1 * time.Duration(c.metadata.metricCollectionTime)) + startTime, endTime := computeQueryWindow(time.Now(), c.metadata.metricStatPeriod, c.metadata.metricEndTimeOffset, c.metadata.metricCollectionTime) var metricUnit *string if c.metadata.metricUnit != "" { @@ -343,7 +355,7 @@ func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) { }, } - output, err := cloudwatchClient.GetMetricData(&input) + output, err := c.cwClient.GetMetricData(&input) if err != nil { cloudwatchLog.Error(err, "Failed to get output") @@ -352,7 +364,7 @@ func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) { cloudwatchLog.V(1).Info("Received Metric Data", "data", output) var metricValue float64 - if output.MetricDataResults[0].Values != nil { + if len(output.MetricDataResults) > 0 && len(output.MetricDataResults[0].Values) > 0 { metricValue = *output.MetricDataResults[0].Values[0] } else { return -1, fmt.Errorf("metric data not received") diff --git a/pkg/scalers/aws_cloudwatch_test.go b/pkg/scalers/aws_cloudwatch_test.go index 81f40a369d1..ff77b0623d7 100644 --- a/pkg/scalers/aws_cloudwatch_test.go +++ b/pkg/scalers/aws_cloudwatch_test.go @@ -2,13 +2,24 @@ package scalers import ( "context" + "errors" "testing" -) + "time" -var testAWSCloudwatchRoleArn = "none" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/cloudwatch" + "github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/labels" +) -var testAWSCloudwatchAccessKeyID = "none" -var testAWSCloudwatchSecretAccessKey = "none" +const ( + testAWSCloudwatchRoleArn = "none" + testAWSCloudwatchAccessKeyID = "none" + testAWSCloudwatchSecretAccessKey = "none" + testAWSCloudwatchErrorMetric = "Error" + testAWSCloudwatchNoValueMetric = "NoValue" +) var testAWSCloudwatchResolvedEnv = map[string]string{ "AWS_ACCESS_KEY": "none", @@ -314,6 +325,79 @@ var awsCloudwatchMetricIdentifiers = []awsCloudwatchMetricIdentifier{ {&testAWSCloudwatchMetadata[1], 3, "s3-aws-cloudwatch-AWS-SQS-QueueName-keda"}, } +var awsCloudwatchGetMetricTestData = []awsCloudwatchMetadata{ + { + namespace: "Custom", + metricsName: "HasData", + dimensionName: []string{"DIM"}, + dimensionValue: []string{"DIM_VALUE"}, + targetMetricValue: 100, + minMetricValue: 0, + metricCollectionTime: 60, + metricStat: "Average", + metricUnit: "SampleCount", + metricStatPeriod: 60, + metricEndTimeOffset: 60, + awsRegion: "us-west-2", + awsAuthorization: awsAuthorizationMetadata{podIdentityOwner: false}, + scalerIndex: 0, + }, + { + namespace: "Custom", + metricsName: "HasDataNoUnit", + dimensionName: []string{"DIM"}, + dimensionValue: []string{"DIM_VALUE"}, + targetMetricValue: 100, + minMetricValue: 0, + metricCollectionTime: 60, + metricStat: "Average", + metricUnit: "", + metricStatPeriod: 60, + metricEndTimeOffset: 60, + awsRegion: "us-west-2", + awsAuthorization: awsAuthorizationMetadata{podIdentityOwner: false}, + scalerIndex: 0, + }, + { + namespace: "Custom", + metricsName: "Error", + dimensionName: []string{"DIM"}, + dimensionValue: []string{"DIM_VALUE"}, + targetMetricValue: 100, + minMetricValue: 0, + metricCollectionTime: 60, + metricStat: "Average", + metricUnit: "", + metricStatPeriod: 60, + metricEndTimeOffset: 60, + awsRegion: "us-west-2", + awsAuthorization: awsAuthorizationMetadata{podIdentityOwner: false}, + scalerIndex: 0, + }, +} + +type mockCloudwatch struct { + cloudwatchiface.CloudWatchAPI +} + +func (m *mockCloudwatch) GetMetricData(input *cloudwatch.GetMetricDataInput) (*cloudwatch.GetMetricDataOutput, error) { + switch *input.MetricDataQueries[0].MetricStat.Metric.MetricName { + case testAWSCloudwatchErrorMetric: + return nil, errors.New("error") + case testAWSCloudwatchNoValueMetric: + return &cloudwatch.GetMetricDataOutput{ + MetricDataResults: []*cloudwatch.MetricDataResult{}, + }, nil + } + return &cloudwatch.GetMetricDataOutput{ + MetricDataResults: []*cloudwatch.MetricDataResult{ + { + Values: []*float64{aws.Float64(10)}, + }, + }, + }, nil +} + func TestCloudwatchParseMetadata(t *testing.T) { for _, testData := range testAWSCloudwatchMetadata { _, err := parseAwsCloudwatchMetadata(&ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAWSCloudwatchResolvedEnv, AuthParams: testData.authParams}) @@ -333,7 +417,7 @@ func TestAWSCloudwatchGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockAWSCloudwatchScaler := awsCloudwatchScaler{meta} + mockAWSCloudwatchScaler := awsCloudwatchScaler{meta, &mockCloudwatch{}} metricSpec := mockAWSCloudwatchScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name @@ -342,3 +426,62 @@ func TestAWSCloudwatchGetMetricSpecForScaling(t *testing.T) { } } } + +func TestAWSCloudwatchScalerGetMetrics(t *testing.T) { + var selector labels.Selector + for _, meta := range awsCloudwatchGetMetricTestData { + mockAWSCloudwatchScaler := awsCloudwatchScaler{&meta, &mockCloudwatch{}} + value, err := mockAWSCloudwatchScaler.GetMetrics(context.Background(), meta.metricsName, selector) + switch meta.metricsName { + case testAWSCloudwatchErrorMetric: + assert.Error(t, err, "expect error because of cloudwatch api error") + case testAWSCloudwatchNoValueMetric: + assert.Error(t, err, "expect error because of no data return from cloudwatch") + default: + assert.EqualValues(t, int64(10.0), value[0].Value.Value()) + } + } +} + +type computeQueryWindowTestArgs struct { + name string + current string + metricPeriodSec int64 + metricEndTimeOffsetSec int64 + metricCollectionTimeSec int64 + expectedStartTime string + expectedEndTime string +} + +var awsCloudwatchComputeQueryWindowTestData = []computeQueryWindowTestArgs{ + { + name: "normal", + current: "2021-11-07T15:04:05.999Z", + metricPeriodSec: 60, + metricEndTimeOffsetSec: 0, + metricCollectionTimeSec: 60, + expectedStartTime: "2021-11-07T15:03:00Z", + expectedEndTime: "2021-11-07T15:04:00Z", + }, + { + name: "normal with offset", + current: "2021-11-07T15:04:05.999Z", + metricPeriodSec: 60, + metricEndTimeOffsetSec: 30, + metricCollectionTimeSec: 60, + expectedStartTime: "2021-11-07T15:02:00Z", + expectedEndTime: "2021-11-07T15:03:00Z", + }, +} + +func TestComputeQueryWindow(t *testing.T) { + for _, testData := range awsCloudwatchComputeQueryWindowTestData { + current, err := time.Parse(time.RFC3339Nano, testData.current) + if err != nil { + t.Errorf("unexpected input datetime format: %v", err) + } + startTime, endTime := computeQueryWindow(current, testData.metricPeriodSec, testData.metricEndTimeOffsetSec, testData.metricCollectionTimeSec) + assert.Equal(t, testData.expectedStartTime, startTime.UTC().Format(time.RFC3339Nano), "unexpected startTime", "name", testData.name) + assert.Equal(t, testData.expectedEndTime, endTime.UTC().Format(time.RFC3339Nano), "unexpected endTime", "name", testData.name) + } +} diff --git a/pkg/scalers/aws_kinesis_stream_scaler.go b/pkg/scalers/aws_kinesis_stream_scaler.go index 3aa82a20f21..448c70c170a 100644 --- a/pkg/scalers/aws_kinesis_stream_scaler.go +++ b/pkg/scalers/aws_kinesis_stream_scaler.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" v2beta2 "k8s.io/api/autoscaling/v2beta2" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -26,7 +27,8 @@ const ( ) type awsKinesisStreamScaler struct { - metadata *awsKinesisStreamMetadata + metadata *awsKinesisStreamMetadata + kinesisClient kinesisiface.KinesisAPI } type awsKinesisStreamMetadata struct { @@ -47,7 +49,8 @@ func NewAwsKinesisStreamScaler(config *ScalerConfig) (Scaler, error) { } return &awsKinesisStreamScaler{ - metadata: meta, + metadata: meta, + kinesisClient: createKinesisClient(meta), }, nil } @@ -89,6 +92,31 @@ func parseAwsKinesisStreamMetadata(config *ScalerConfig) (*awsKinesisStreamMetad return &meta, nil } +func createKinesisClient(metadata *awsKinesisStreamMetadata) *kinesis.Kinesis { + sess := session.Must(session.NewSession(&aws.Config{ + Region: aws.String(metadata.awsRegion), + })) + + var kinesisClinent *kinesis.Kinesis + if metadata.awsAuthorization.podIdentityOwner { + creds := credentials.NewStaticCredentials(metadata.awsAuthorization.awsAccessKeyID, metadata.awsAuthorization.awsSecretAccessKey, "") + + if metadata.awsAuthorization.awsRoleArn != "" { + creds = stscreds.NewCredentials(sess, metadata.awsAuthorization.awsRoleArn) + } + + kinesisClinent = kinesis.New(sess, &aws.Config{ + Region: aws.String(metadata.awsRegion), + Credentials: creds, + }) + } else { + kinesisClinent = kinesis.New(sess, &aws.Config{ + Region: aws.String(metadata.awsRegion), + }) + } + return kinesisClinent +} + // IsActive determines if we need to scale from zero func (s *awsKinesisStreamScaler) IsActive(ctx context.Context) (bool, error) { count, err := s.GetAwsKinesisOpenShardCount() @@ -143,29 +171,7 @@ func (s *awsKinesisStreamScaler) GetAwsKinesisOpenShardCount() (int64, error) { StreamName: &s.metadata.streamName, } - sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(s.metadata.awsRegion), - })) - - var kinesisClinent *kinesis.Kinesis - if s.metadata.awsAuthorization.podIdentityOwner { - creds := credentials.NewStaticCredentials(s.metadata.awsAuthorization.awsAccessKeyID, s.metadata.awsAuthorization.awsSecretAccessKey, "") - - if s.metadata.awsAuthorization.awsRoleArn != "" { - creds = stscreds.NewCredentials(sess, s.metadata.awsAuthorization.awsRoleArn) - } - - kinesisClinent = kinesis.New(sess, &aws.Config{ - Region: aws.String(s.metadata.awsRegion), - Credentials: creds, - }) - } else { - kinesisClinent = kinesis.New(sess, &aws.Config{ - Region: aws.String(s.metadata.awsRegion), - }) - } - - output, err := kinesisClinent.DescribeStreamSummary(input) + output, err := s.kinesisClient.DescribeStreamSummary(input) if err != nil { return -1, err } diff --git a/pkg/scalers/aws_kinesis_stream_test.go b/pkg/scalers/aws_kinesis_stream_test.go index 84a02498f34..ce4f45e4d6d 100644 --- a/pkg/scalers/aws_kinesis_stream_test.go +++ b/pkg/scalers/aws_kinesis_stream_test.go @@ -2,8 +2,15 @@ package scalers import ( "context" + "errors" "reflect" "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/labels" ) const ( @@ -12,6 +19,7 @@ const ( testAWSKinesisSecretAccessKey = "none" testAWSKinesisStreamName = "test" testAWSRegion = "eu-west-1" + testAWSKinesisErrorStream = "Error" ) var testAWSKinesisAuthentication = map[string]string{ @@ -34,6 +42,22 @@ type awsKinesisMetricIdentifier struct { name string } +type mockKinesis struct { + kinesisiface.KinesisAPI +} + +func (m *mockKinesis) DescribeStreamSummary(input *kinesis.DescribeStreamSummaryInput) (*kinesis.DescribeStreamSummaryOutput, error) { + if *input.StreamName == "Error" { + return nil, errors.New("some error") + } + + return &kinesis.DescribeStreamSummaryOutput{ + StreamDescriptionSummary: &kinesis.StreamDescriptionSummary{ + OpenShardCount: aws.Int64(100), + }, + }, nil +} + var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{ { metadata: map[string]string{}, @@ -200,6 +224,11 @@ var awsKinesisMetricIdentifiers = []awsKinesisMetricIdentifier{ {&testAWSKinesisMetadata[1], 1, "s1-AWS-Kinesis-Stream-test"}, } +var awsKinesisGetMetricTestData = []*awsKinesisStreamMetadata{ + {streamName: "Good"}, + {streamName: testAWSKinesisErrorStream}, +} + func TestKinesisParseMetadata(t *testing.T) { for _, testData := range testAWSKinesisMetadata { result, err := parseAwsKinesisStreamMetadata(&ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAWSKinesisAuthentication, AuthParams: testData.authParams, ScalerIndex: testData.scalerIndex}) @@ -223,7 +252,7 @@ func TestAWSKinesisGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockAWSKinesisStreamScaler := awsKinesisStreamScaler{meta} + mockAWSKinesisStreamScaler := awsKinesisStreamScaler{meta, &mockKinesis{}} metricSpec := mockAWSKinesisStreamScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name @@ -232,3 +261,17 @@ func TestAWSKinesisGetMetricSpecForScaling(t *testing.T) { } } } + +func TestAWSKinesisStreamScalerGetMetrics(t *testing.T) { + var selector labels.Selector + for _, meta := range awsKinesisGetMetricTestData { + scaler := awsKinesisStreamScaler{meta, &mockKinesis{}} + value, err := scaler.GetMetrics(context.Background(), "MetricName", selector) + switch meta.streamName { + case testAWSKinesisErrorStream: + assert.Error(t, err, "expect error because of kinesis api error") + default: + assert.EqualValues(t, int64(100.0), value[0].Value.Value()) + } + } +} diff --git a/pkg/scalers/aws_sqs_queue_scaler.go b/pkg/scalers/aws_sqs_queue_scaler.go index 08e0747cd2b..8aaf016f004 100644 --- a/pkg/scalers/aws_sqs_queue_scaler.go +++ b/pkg/scalers/aws_sqs_queue_scaler.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" v2beta2 "k8s.io/api/autoscaling/v2beta2" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -36,7 +37,8 @@ var ( ) type awsSqsQueueScaler struct { - metadata *awsSqsQueueMetadata + metadata *awsSqsQueueMetadata + sqsClient sqsiface.SQSAPI } type awsSqsQueueMetadata struct { @@ -56,7 +58,8 @@ func NewAwsSqsQueueScaler(config *ScalerConfig) (Scaler, error) { } return &awsSqsQueueScaler{ - metadata: meta, + metadata: meta, + sqsClient: createSqsClient(meta), }, nil } @@ -111,6 +114,31 @@ func parseAwsSqsQueueMetadata(config *ScalerConfig) (*awsSqsQueueMetadata, error return &meta, nil } +func createSqsClient(metadata *awsSqsQueueMetadata) *sqs.SQS { + sess := session.Must(session.NewSession(&aws.Config{ + Region: aws.String(metadata.awsRegion), + })) + + var sqsClient *sqs.SQS + if metadata.awsAuthorization.podIdentityOwner { + creds := credentials.NewStaticCredentials(metadata.awsAuthorization.awsAccessKeyID, metadata.awsAuthorization.awsSecretAccessKey, "") + + if metadata.awsAuthorization.awsRoleArn != "" { + creds = stscreds.NewCredentials(sess, metadata.awsAuthorization.awsRoleArn) + } + + sqsClient = sqs.New(sess, &aws.Config{ + Region: aws.String(metadata.awsRegion), + Credentials: creds, + }) + } else { + sqsClient = sqs.New(sess, &aws.Config{ + Region: aws.String(metadata.awsRegion), + }) + } + return sqsClient +} + // IsActive determines if we need to scale from zero func (s *awsSqsQueueScaler) IsActive(ctx context.Context) (bool, error) { length, err := s.GetAwsSqsQueueLength() @@ -166,29 +194,7 @@ func (s *awsSqsQueueScaler) GetAwsSqsQueueLength() (int32, error) { QueueUrl: aws.String(s.metadata.queueURL), } - sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(s.metadata.awsRegion), - })) - - var sqsClient *sqs.SQS - if s.metadata.awsAuthorization.podIdentityOwner { - creds := credentials.NewStaticCredentials(s.metadata.awsAuthorization.awsAccessKeyID, s.metadata.awsAuthorization.awsSecretAccessKey, "") - - if s.metadata.awsAuthorization.awsRoleArn != "" { - creds = stscreds.NewCredentials(sess, s.metadata.awsAuthorization.awsRoleArn) - } - - sqsClient = sqs.New(sess, &aws.Config{ - Region: aws.String(s.metadata.awsRegion), - Credentials: creds, - }) - } else { - sqsClient = sqs.New(sess, &aws.Config{ - Region: aws.String(s.metadata.awsRegion), - }) - } - - output, err := sqsClient.GetQueueAttributes(input) + output, err := s.sqsClient.GetQueueAttributes(input) if err != nil { return -1, err } diff --git a/pkg/scalers/aws_sqs_queue_test.go b/pkg/scalers/aws_sqs_queue_test.go index 726ee357bc9..4e629b17966 100644 --- a/pkg/scalers/aws_sqs_queue_test.go +++ b/pkg/scalers/aws_sqs_queue_test.go @@ -2,7 +2,14 @@ package scalers import ( "context" + "errors" "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/labels" ) const ( @@ -13,6 +20,9 @@ const ( testAWSSQSProperQueueURL = "https://sqs.eu-west-1.amazonaws.com/account_id/DeleteArtifactQ" testAWSSQSImproperQueueURL1 = "https://sqs.eu-west-1.amazonaws.com/account_id" testAWSSQSImproperQueueURL2 = "https://sqs.eu-west-1.amazonaws.com" + + testAWSSQSErrorQueueURL = "https://sqs.eu-west-1.amazonaws.com/account_id/Error" + testAWSSQSBadDataQueueURL = "https://sqs.eu-west-1.amazonaws.com/account_id/BadData" ) var testAWSSQSAuthentication = map[string]string{ @@ -33,6 +43,31 @@ type awsSQSMetricIdentifier struct { name string } +type mockSqs struct { + sqsiface.SQSAPI +} + +func (m *mockSqs) GetQueueAttributes(input *sqs.GetQueueAttributesInput) (*sqs.GetQueueAttributesOutput, error) { + switch *input.QueueUrl { + case testAWSSQSErrorQueueURL: + return nil, errors.New("some error") + case testAWSSQSBadDataQueueURL: + return &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + "ApproximateNumberOfMessages": aws.String("NotInt"), + "ApproximateNumberOfMessagesNotVisible": aws.String("NotInt"), + }, + }, nil + } + + return &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + "ApproximateNumberOfMessages": aws.String("200"), + "ApproximateNumberOfMessagesNotVisible": aws.String("100"), + }, + }, nil +} + var testAWSSQSMetadata = []parseAWSSQSMetadataTestData{ {map[string]string{}, testAWSSQSAuthentication, @@ -137,6 +172,12 @@ var awsSQSMetricIdentifiers = []awsSQSMetricIdentifier{ {&testAWSSQSMetadata[1], 1, "s1-AWS-SQS-Queue-DeleteArtifactQ"}, } +var awsSQSGetMetricTestData = []*awsSqsQueueMetadata{ + {queueURL: testAWSSQSProperQueueURL}, + {queueURL: testAWSSQSErrorQueueURL}, + {queueURL: testAWSSQSBadDataQueueURL}, +} + func TestSQSParseMetadata(t *testing.T) { for _, testData := range testAWSSQSMetadata { _, err := parseAwsSqsQueueMetadata(&ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAWSSQSAuthentication, AuthParams: testData.authParams}) @@ -156,7 +197,7 @@ func TestAWSSQSGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockAWSSQSScaler := awsSqsQueueScaler{meta} + mockAWSSQSScaler := awsSqsQueueScaler{meta, &mockSqs{}} metricSpec := mockAWSSQSScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name @@ -165,3 +206,19 @@ func TestAWSSQSGetMetricSpecForScaling(t *testing.T) { } } } + +func TestAWSSQSScalerGetMetrics(t *testing.T) { + var selector labels.Selector + for _, meta := range awsSQSGetMetricTestData { + scaler := awsSqsQueueScaler{meta, &mockSqs{}} + value, err := scaler.GetMetrics(context.Background(), "MetricName", selector) + switch meta.queueURL { + case testAWSSQSErrorQueueURL: + assert.Error(t, err, "expect error because of sqs api error") + case testAWSSQSBadDataQueueURL: + assert.Error(t, err, "expect error because of bad data return from sqs") + default: + assert.EqualValues(t, int64(300.0), value[0].Value.Value()) + } + } +}