Skip to content

Commit

Permalink
update aws related scalers to reuse the aws clients (#2255)
Browse files Browse the repository at this point in the history
Signed-off-by: Xiayang Wu <xwu@rippling.com>
  • Loading branch information
fivesheep committed Nov 8, 2021
1 parent 89daa2f commit ca8d5ea
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 83 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
- Improve metric name creation to be unique using scaler index inside the scaler ([#2161](https://github.com/kedacore/keda/pull/2161))
- Improve error message if `IdleReplicaCount` are equal to `MinReplicaCount` to be the same as the check ([#2212](https://github.com/kedacore/keda/pull/2212))
- Improve Cloudwatch Scaler metric exporting logic ([#2243](https://github.com/kedacore/keda/pull/2243))
- Refactor aws related scalers to reuse the aws clients instead of creating a new one for every GetMetrics call([#2255](https://github.com/kedacore/keda/pull/2255))

### Breaking Changes

Expand Down
64 changes: 38 additions & 26 deletions pkg/scalers/aws_cloudwatch_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudwatch"
"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
"k8s.io/api/autoscaling/v2beta2"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -32,6 +33,7 @@ const (

type awsCloudwatchScaler struct {
metadata *awsCloudwatchMetadata
cwClient cloudwatchiface.CloudWatchAPI
}

type awsCloudwatchMetadata struct {
Expand Down Expand Up @@ -67,6 +69,7 @@ func NewAwsCloudwatchScaler(config *ScalerConfig) (Scaler, error) {

return &awsCloudwatchScaler{
metadata: meta,
cwClient: createCloudwatchClient(meta),
}, nil
}

Expand Down Expand Up @@ -102,6 +105,32 @@ func getFloatMetadataValue(metadata map[string]string, key string, required bool
return defaultValue, nil
}

func createCloudwatchClient(metadata *awsCloudwatchMetadata) *cloudwatch.CloudWatch {
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(metadata.awsRegion),
}))

var cloudwatchClient *cloudwatch.CloudWatch
if metadata.awsAuthorization.podIdentityOwner {
creds := credentials.NewStaticCredentials(metadata.awsAuthorization.awsAccessKeyID, metadata.awsAuthorization.awsSecretAccessKey, "")

if metadata.awsAuthorization.awsRoleArn != "" {
creds = stscreds.NewCredentials(sess, metadata.awsAuthorization.awsRoleArn)
}

cloudwatchClient = cloudwatch.New(sess, &aws.Config{
Region: aws.String(metadata.awsRegion),
Credentials: creds,
})
} else {
cloudwatchClient = cloudwatch.New(sess, &aws.Config{
Region: aws.String(metadata.awsRegion),
})
}

return cloudwatchClient
}

func parseAwsCloudwatchMetadata(config *ScalerConfig) (*awsCloudwatchMetadata, error) {
var err error
meta := awsCloudwatchMetadata{}
Expand Down Expand Up @@ -236,6 +265,12 @@ func checkMetricStatPeriod(period int64) error {
return nil
}

func computeQueryWindow(current time.Time, metricPeriodSec, metricEndTimeOffsetSec, metricCollectionTimeSec int64) (startTime, endTime time.Time) {
endTime = current.Add(time.Second * -1 * time.Duration(metricEndTimeOffsetSec)).Truncate(time.Duration(metricPeriodSec) * time.Second)
startTime = endTime.Add(time.Second * -1 * time.Duration(metricCollectionTimeSec))
return
}

func (c *awsCloudwatchScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) {
metricValue, err := c.GetCloudwatchMetrics()

Expand Down Expand Up @@ -283,28 +318,6 @@ func (c *awsCloudwatchScaler) Close(context.Context) error {
}

func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) {
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(c.metadata.awsRegion),
}))

var cloudwatchClient *cloudwatch.CloudWatch
if c.metadata.awsAuthorization.podIdentityOwner {
creds := credentials.NewStaticCredentials(c.metadata.awsAuthorization.awsAccessKeyID, c.metadata.awsAuthorization.awsSecretAccessKey, "")

if c.metadata.awsAuthorization.awsRoleArn != "" {
creds = stscreds.NewCredentials(sess, c.metadata.awsAuthorization.awsRoleArn)
}

cloudwatchClient = cloudwatch.New(sess, &aws.Config{
Region: aws.String(c.metadata.awsRegion),
Credentials: creds,
})
} else {
cloudwatchClient = cloudwatch.New(sess, &aws.Config{
Region: aws.String(c.metadata.awsRegion),
})
}

dimensions := []*cloudwatch.Dimension{}
for i := range c.metadata.dimensionName {
dimensions = append(dimensions, &cloudwatch.Dimension{
Expand All @@ -313,8 +326,7 @@ func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) {
})
}

endTime := time.Now().Add(time.Second * -1 * time.Duration(c.metadata.metricEndTimeOffset)).Truncate(time.Duration(c.metadata.metricStatPeriod) * time.Second)
startTime := endTime.Add(time.Second * -1 * time.Duration(c.metadata.metricCollectionTime))
startTime, endTime := computeQueryWindow(time.Now(), c.metadata.metricStatPeriod, c.metadata.metricEndTimeOffset, c.metadata.metricCollectionTime)

var metricUnit *string
if c.metadata.metricUnit != "" {
Expand Down Expand Up @@ -343,7 +355,7 @@ func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) {
},
}

output, err := cloudwatchClient.GetMetricData(&input)
output, err := c.cwClient.GetMetricData(&input)

if err != nil {
cloudwatchLog.Error(err, "Failed to get output")
Expand All @@ -352,7 +364,7 @@ func (c *awsCloudwatchScaler) GetCloudwatchMetrics() (float64, error) {

cloudwatchLog.V(1).Info("Received Metric Data", "data", output)
var metricValue float64
if output.MetricDataResults[0].Values != nil {
if len(output.MetricDataResults) > 0 && len(output.MetricDataResults[0].Values) > 0 {
metricValue = *output.MetricDataResults[0].Values[0]
} else {
return -1, fmt.Errorf("metric data not received")
Expand Down
153 changes: 148 additions & 5 deletions pkg/scalers/aws_cloudwatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@ package scalers

import (
"context"
"errors"
"testing"
)
"time"

var testAWSCloudwatchRoleArn = "none"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/cloudwatch"
"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/labels"
)

var testAWSCloudwatchAccessKeyID = "none"
var testAWSCloudwatchSecretAccessKey = "none"
const (
testAWSCloudwatchRoleArn = "none"
testAWSCloudwatchAccessKeyID = "none"
testAWSCloudwatchSecretAccessKey = "none"
testAWSCloudwatchErrorMetric = "Error"
testAWSCloudwatchNoValueMetric = "NoValue"
)

var testAWSCloudwatchResolvedEnv = map[string]string{
"AWS_ACCESS_KEY": "none",
Expand Down Expand Up @@ -314,6 +325,79 @@ var awsCloudwatchMetricIdentifiers = []awsCloudwatchMetricIdentifier{
{&testAWSCloudwatchMetadata[1], 3, "s3-aws-cloudwatch-AWS-SQS-QueueName-keda"},
}

var awsCloudwatchGetMetricTestData = []awsCloudwatchMetadata{
{
namespace: "Custom",
metricsName: "HasData",
dimensionName: []string{"DIM"},
dimensionValue: []string{"DIM_VALUE"},
targetMetricValue: 100,
minMetricValue: 0,
metricCollectionTime: 60,
metricStat: "Average",
metricUnit: "SampleCount",
metricStatPeriod: 60,
metricEndTimeOffset: 60,
awsRegion: "us-west-2",
awsAuthorization: awsAuthorizationMetadata{podIdentityOwner: false},
scalerIndex: 0,
},
{
namespace: "Custom",
metricsName: "HasDataNoUnit",
dimensionName: []string{"DIM"},
dimensionValue: []string{"DIM_VALUE"},
targetMetricValue: 100,
minMetricValue: 0,
metricCollectionTime: 60,
metricStat: "Average",
metricUnit: "",
metricStatPeriod: 60,
metricEndTimeOffset: 60,
awsRegion: "us-west-2",
awsAuthorization: awsAuthorizationMetadata{podIdentityOwner: false},
scalerIndex: 0,
},
{
namespace: "Custom",
metricsName: "Error",
dimensionName: []string{"DIM"},
dimensionValue: []string{"DIM_VALUE"},
targetMetricValue: 100,
minMetricValue: 0,
metricCollectionTime: 60,
metricStat: "Average",
metricUnit: "",
metricStatPeriod: 60,
metricEndTimeOffset: 60,
awsRegion: "us-west-2",
awsAuthorization: awsAuthorizationMetadata{podIdentityOwner: false},
scalerIndex: 0,
},
}

type mockCloudwatch struct {
cloudwatchiface.CloudWatchAPI
}

func (m *mockCloudwatch) GetMetricData(input *cloudwatch.GetMetricDataInput) (*cloudwatch.GetMetricDataOutput, error) {
switch *input.MetricDataQueries[0].MetricStat.Metric.MetricName {
case testAWSCloudwatchErrorMetric:
return nil, errors.New("error")
case testAWSCloudwatchNoValueMetric:
return &cloudwatch.GetMetricDataOutput{
MetricDataResults: []*cloudwatch.MetricDataResult{},
}, nil
}
return &cloudwatch.GetMetricDataOutput{
MetricDataResults: []*cloudwatch.MetricDataResult{
{
Values: []*float64{aws.Float64(10)},
},
},
}, nil
}

func TestCloudwatchParseMetadata(t *testing.T) {
for _, testData := range testAWSCloudwatchMetadata {
_, err := parseAwsCloudwatchMetadata(&ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAWSCloudwatchResolvedEnv, AuthParams: testData.authParams})
Expand All @@ -333,7 +417,7 @@ func TestAWSCloudwatchGetMetricSpecForScaling(t *testing.T) {
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
mockAWSCloudwatchScaler := awsCloudwatchScaler{meta}
mockAWSCloudwatchScaler := awsCloudwatchScaler{meta, &mockCloudwatch{}}

metricSpec := mockAWSCloudwatchScaler.GetMetricSpecForScaling(ctx)
metricName := metricSpec[0].External.Metric.Name
Expand All @@ -342,3 +426,62 @@ func TestAWSCloudwatchGetMetricSpecForScaling(t *testing.T) {
}
}
}

func TestAWSCloudwatchScalerGetMetrics(t *testing.T) {
var selector labels.Selector
for _, meta := range awsCloudwatchGetMetricTestData {
mockAWSCloudwatchScaler := awsCloudwatchScaler{&meta, &mockCloudwatch{}}
value, err := mockAWSCloudwatchScaler.GetMetrics(context.Background(), meta.metricsName, selector)
switch meta.metricsName {
case testAWSCloudwatchErrorMetric:
assert.Error(t, err, "expect error because of cloudwatch api error")
case testAWSCloudwatchNoValueMetric:
assert.Error(t, err, "expect error because of no data return from cloudwatch")
default:
assert.EqualValues(t, int64(10.0), value[0].Value.Value())
}
}
}

type computeQueryWindowTestArgs struct {
name string
current string
metricPeriodSec int64
metricEndTimeOffsetSec int64
metricCollectionTimeSec int64
expectedStartTime string
expectedEndTime string
}

var awsCloudwatchComputeQueryWindowTestData = []computeQueryWindowTestArgs{
{
name: "normal",
current: "2021-11-07T15:04:05.999Z",
metricPeriodSec: 60,
metricEndTimeOffsetSec: 0,
metricCollectionTimeSec: 60,
expectedStartTime: "2021-11-07T15:03:00Z",
expectedEndTime: "2021-11-07T15:04:00Z",
},
{
name: "normal with offset",
current: "2021-11-07T15:04:05.999Z",
metricPeriodSec: 60,
metricEndTimeOffsetSec: 30,
metricCollectionTimeSec: 60,
expectedStartTime: "2021-11-07T15:02:00Z",
expectedEndTime: "2021-11-07T15:03:00Z",
},
}

func TestComputeQueryWindow(t *testing.T) {
for _, testData := range awsCloudwatchComputeQueryWindowTestData {
current, err := time.Parse(time.RFC3339Nano, testData.current)
if err != nil {
t.Errorf("unexpected input datetime format: %v", err)
}
startTime, endTime := computeQueryWindow(current, testData.metricPeriodSec, testData.metricEndTimeOffsetSec, testData.metricCollectionTimeSec)
assert.Equal(t, testData.expectedStartTime, startTime.UTC().Format(time.RFC3339Nano), "unexpected startTime", "name", testData.name)
assert.Equal(t, testData.expectedEndTime, endTime.UTC().Format(time.RFC3339Nano), "unexpected endTime", "name", testData.name)
}
}
Loading

0 comments on commit ca8d5ea

Please sign in to comment.