diff --git a/CHANGELOG.md b/CHANGELOG.md index 35aa92e8896..e252892b6e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,6 @@ ### New -- TODO ([#XXX](https://github.com/kedacore/keda/pull/XXX)) - ScaledJob: introduce MultipleScalersCalculation ([#2016](https://github.com/kedacore/keda/pull/2016)) - Add Graphite Scaler ([#1628](https://github.com/kedacore/keda/pull/2092)) @@ -53,6 +52,7 @@ ### Other +- Ensure that `context.Context` values are passed down the stack from all scaler gRPC handler implementation to scaler implementation code ([#2202](https://github.com/kedacore/keda/pull/2202)) - Migrate to Kubebuilder v3 ([#2082](https://github.com/kedacore/keda/pull/2082)) - API path has been changed: `github.com/kedacore/keda/v2/api/v1alpha1` -> `github.com/kedacore/keda/v2/apis/keda/v1alpha1` - Use Patch to set FallbackCondition on ScaledObject.Status ([#2037](https://github.com/kedacore/keda/pull/2037)) diff --git a/apis/keda/v1alpha1/zz_generated.deepcopy.go b/apis/keda/v1alpha1/zz_generated.deepcopy.go index d6e286bf9e8..53d00a01491 100644 --- a/apis/keda/v1alpha1/zz_generated.deepcopy.go +++ b/apis/keda/v1alpha1/zz_generated.deepcopy.go @@ -1,4 +1,3 @@ -//go:build !ignore_autogenerated // +build !ignore_autogenerated /* diff --git a/controllers/keda/hpa.go b/controllers/keda/hpa.go index d076d4d290c..3342f27b00e 100644 --- a/controllers/keda/hpa.go +++ b/controllers/keda/hpa.go @@ -40,10 +40,10 @@ const ( ) // createAndDeployNewHPA creates and deploy HPA in the cluster for specified ScaledObject -func (r *ScaledObjectReconciler) createAndDeployNewHPA(logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, gvkr *kedav1alpha1.GroupVersionKindResource) error { +func (r *ScaledObjectReconciler) createAndDeployNewHPA(ctx context.Context, logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, gvkr *kedav1alpha1.GroupVersionKindResource) error { hpaName := getHPAName(scaledObject) logger.Info("Creating a new HPA", "HPA.Namespace", scaledObject.Namespace, "HPA.Name", hpaName) - hpa, err := r.newHPAForScaledObject(logger, scaledObject, gvkr) + hpa, err := r.newHPAForScaledObject(ctx, logger, scaledObject, gvkr) if err != nil { logger.Error(err, "Failed to create new HPA resource", "HPA.Namespace", scaledObject.Namespace, "HPA.Name", hpaName) return err @@ -59,8 +59,8 @@ func (r *ScaledObjectReconciler) createAndDeployNewHPA(logger logr.Logger, scale } // newHPAForScaledObject returns HPA as it is specified in ScaledObject -func (r *ScaledObjectReconciler) newHPAForScaledObject(logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, gvkr *kedav1alpha1.GroupVersionKindResource) (*autoscalingv2beta2.HorizontalPodAutoscaler, error) { - scaledObjectMetricSpecs, err := r.getScaledObjectMetricSpecs(logger, scaledObject) +func (r *ScaledObjectReconciler) newHPAForScaledObject(ctx context.Context, logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, gvkr *kedav1alpha1.GroupVersionKindResource) (*autoscalingv2beta2.HorizontalPodAutoscaler, error) { + scaledObjectMetricSpecs, err := r.getScaledObjectMetricSpecs(ctx, logger, scaledObject) if err != nil { return nil, err } @@ -120,8 +120,8 @@ func (r *ScaledObjectReconciler) newHPAForScaledObject(logger logr.Logger, scale } // updateHPAIfNeeded checks whether update of HPA is needed -func (r *ScaledObjectReconciler) updateHPAIfNeeded(logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, foundHpa *autoscalingv2beta2.HorizontalPodAutoscaler, gvkr *kedav1alpha1.GroupVersionKindResource) error { - hpa, err := r.newHPAForScaledObject(logger, scaledObject, gvkr) +func (r *ScaledObjectReconciler) updateHPAIfNeeded(ctx context.Context, logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, foundHpa *autoscalingv2beta2.HorizontalPodAutoscaler, gvkr *kedav1alpha1.GroupVersionKindResource) error { + hpa, err := r.newHPAForScaledObject(ctx, logger, scaledObject, gvkr) if err != nil { logger.Error(err, "Failed to create new HPA resource", "HPA.Namespace", scaledObject.Namespace, "HPA.Name", getHPAName(scaledObject)) return err @@ -155,19 +155,19 @@ func (r *ScaledObjectReconciler) updateHPAIfNeeded(logger logr.Logger, scaledObj } // getScaledObjectMetricSpecs returns MetricSpec for HPA, generater from Triggers defitinion in ScaledObject -func (r *ScaledObjectReconciler) getScaledObjectMetricSpecs(logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject) ([]autoscalingv2beta2.MetricSpec, error) { +func (r *ScaledObjectReconciler) getScaledObjectMetricSpecs(ctx context.Context, logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject) ([]autoscalingv2beta2.MetricSpec, error) { var scaledObjectMetricSpecs []autoscalingv2beta2.MetricSpec var externalMetricNames []string var resourceMetricNames []string - scalers, err := r.scaleHandler.GetScalers(scaledObject) + scalers, err := r.scaleHandler.GetScalers(ctx, scaledObject) if err != nil { logger.Error(err, "Error getting scalers") return nil, err } for _, scaler := range scalers { - metricSpecs := scaler.GetMetricSpecForScaling() + metricSpecs := scaler.GetMetricSpecForScaling(ctx) for _, metricSpec := range metricSpecs { if metricSpec.Resource != nil { @@ -187,7 +187,7 @@ func (r *ScaledObjectReconciler) getScaledObjectMetricSpecs(logger logr.Logger, } } scaledObjectMetricSpecs = append(scaledObjectMetricSpecs, metricSpecs...) - scaler.Close() + scaler.Close(ctx) } // sort metrics in ScaledObject, this way we always check the same resource in Reconcile loop and we can prevent unnecessary HPA updates, diff --git a/controllers/keda/hpa_test.go b/controllers/keda/hpa_test.go index 4e266f0f842..51a092f4863 100644 --- a/controllers/keda/hpa_test.go +++ b/controllers/keda/hpa_test.go @@ -17,6 +17,8 @@ limitations under the License. package keda import ( + "context" + "github.com/go-logr/logr" "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" @@ -75,7 +77,7 @@ var _ = Describe("hpa", func() { capturedScaledObject = *scaledObject }) - _, err := reconciler.getScaledObjectMetricSpecs(logger, scaledObject) + _, err := reconciler.getScaledObjectMetricSpecs(context.Background(), logger, scaledObject) Expect(err).ToNot(HaveOccurred()) Expect(capturedScaledObject.Status.Health).To(BeEmpty()) @@ -102,7 +104,7 @@ var _ = Describe("hpa", func() { capturedScaledObject = *scaledObject }) - _, err := reconciler.getScaledObjectMetricSpecs(logger, scaledObject) + _, err := reconciler.getScaledObjectMetricSpecs(context.Background(), logger, scaledObject) expectedHealth := make(map[string]v1alpha1.HealthStatus) expectedHealth["some metric name"] = v1alpha1.HealthStatus{ @@ -136,9 +138,10 @@ func setupTest(health map[string]v1alpha1.HealthStatus, scaler *mock_scalers.Moc }, } metricSpecs := []v2beta2.MetricSpec{metricSpec} - scaler.EXPECT().GetMetricSpecForScaling().Return(metricSpecs) - scaler.EXPECT().Close() - scaleHandler.EXPECT().GetScalers(gomock.Eq(scaledObject)).Return(scalers, nil) + ctx := context.Background() + scaler.EXPECT().GetMetricSpecForScaling(ctx).Return(metricSpecs) + scaler.EXPECT().Close(ctx) + scaleHandler.EXPECT().GetScalers(context.Background(), gomock.Eq(scaledObject)).Return(scalers, nil) return scaledObject } diff --git a/controllers/keda/scaledjob_controller.go b/controllers/keda/scaledjob_controller.go index 37b546fd7a9..11544f8373b 100644 --- a/controllers/keda/scaledjob_controller.go +++ b/controllers/keda/scaledjob_controller.go @@ -110,7 +110,7 @@ func (r *ScaledJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( reqLogger.Error(err, "scaledJob.spec.jobTargetRef not found") return ctrl.Result{}, err } - msg, err := r.reconcileScaledJob(reqLogger, scaledJob) + msg, err := r.reconcileScaledJob(ctx, reqLogger, scaledJob) conditions := scaledJob.Status.Conditions.DeepCopy() if err != nil { reqLogger.Error(err, msg) @@ -133,14 +133,14 @@ func (r *ScaledJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( } // reconcileScaledJob implements reconciler logic for K8s Jobs based ScaledJob -func (r *ScaledJobReconciler) reconcileScaledJob(logger logr.Logger, scaledJob *kedav1alpha1.ScaledJob) (string, error) { +func (r *ScaledJobReconciler) reconcileScaledJob(ctx context.Context, logger logr.Logger, scaledJob *kedav1alpha1.ScaledJob) (string, error) { msg, err := r.deletePreviousVersionScaleJobs(logger, scaledJob) if err != nil { return msg, err } // Check ScaledJob is Ready or not - _, err = r.scaleHandler.GetScalers(scaledJob) + _, err = r.scaleHandler.GetScalers(ctx, scaledJob) if err != nil { logger.Error(err, "Error getting scalers") return "Failed to ensure ScaledJob is correctly created", err diff --git a/controllers/keda/scaledobject_controller.go b/controllers/keda/scaledobject_controller.go index db92b0adfb0..50eb0d3c8e1 100644 --- a/controllers/keda/scaledobject_controller.go +++ b/controllers/keda/scaledobject_controller.go @@ -175,7 +175,7 @@ func (r *ScaledObjectReconciler) Reconcile(ctx context.Context, req ctrl.Request } // reconcile ScaledObject and set status appropriately - msg, err := r.reconcileScaledObject(reqLogger, scaledObject) + msg, err := r.reconcileScaledObject(ctx, reqLogger, scaledObject) conditions := scaledObject.Status.Conditions.DeepCopy() if err != nil { reqLogger.Error(err, msg) @@ -199,7 +199,7 @@ func (r *ScaledObjectReconciler) Reconcile(ctx context.Context, req ctrl.Request } // reconcileScaledObject implements reconciler logic for ScaleObject -func (r *ScaledObjectReconciler) reconcileScaledObject(logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject) (string, error) { +func (r *ScaledObjectReconciler) reconcileScaledObject(ctx context.Context, logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject) (string, error) { // Check scale target Name is specified if scaledObject.Spec.ScaleTargetRef.Name == "" { err := fmt.Errorf("ScaledObject.spec.scaleTargetRef.name is missing") @@ -224,7 +224,7 @@ func (r *ScaledObjectReconciler) reconcileScaledObject(logger logr.Logger, scale } // Create a new HPA or update existing one according to ScaledObject - newHPACreated, err := r.ensureHPAForScaledObjectExists(logger, scaledObject, &gvkr) + newHPACreated, err := r.ensureHPAForScaledObjectExists(ctx, logger, scaledObject, &gvkr) if err != nil { return "Failed to ensure HPA is correctly created for ScaledObject", err } @@ -349,14 +349,14 @@ func (r *ScaledObjectReconciler) checkReplicaCountBoundsAreValid(scaledObject *k } // ensureHPAForScaledObjectExists ensures that in cluster exist up-to-date HPA for specified ScaledObject, returns true if a new HPA was created -func (r *ScaledObjectReconciler) ensureHPAForScaledObjectExists(logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, gvkr *kedav1alpha1.GroupVersionKindResource) (bool, error) { +func (r *ScaledObjectReconciler) ensureHPAForScaledObjectExists(ctx context.Context, logger logr.Logger, scaledObject *kedav1alpha1.ScaledObject, gvkr *kedav1alpha1.GroupVersionKindResource) (bool, error) { hpaName := getHPAName(scaledObject) foundHpa := &autoscalingv2beta2.HorizontalPodAutoscaler{} // Check if HPA for this ScaledObject already exists err := r.Client.Get(context.TODO(), types.NamespacedName{Name: hpaName, Namespace: scaledObject.Namespace}, foundHpa) if err != nil && errors.IsNotFound(err) { // HPA wasn't found -> let's create a new one - err = r.createAndDeployNewHPA(logger, scaledObject, gvkr) + err = r.createAndDeployNewHPA(ctx, logger, scaledObject, gvkr) if err != nil { return false, err } @@ -372,7 +372,7 @@ func (r *ScaledObjectReconciler) ensureHPAForScaledObjectExists(logger logr.Logg } // HPA was found -> let's check if we need to update it - err = r.updateHPAIfNeeded(logger, scaledObject, foundHpa, gvkr) + err = r.updateHPAIfNeeded(ctx, logger, scaledObject, foundHpa, gvkr) if err != nil { logger.Error(err, "Failed to check HPA for possible update") return false, err diff --git a/controllers/keda/scaledobject_controller_test.go b/controllers/keda/scaledobject_controller_test.go index d3ca4e5da84..6b18aec6744 100644 --- a/controllers/keda/scaledobject_controller_test.go +++ b/controllers/keda/scaledobject_controller_test.go @@ -100,7 +100,7 @@ var _ = Describe("ScaledObjectController", func() { } testScalers = append(testScalers, s) - for _, metricSpec := range s.GetMetricSpecForScaling() { + for _, metricSpec := range s.GetMetricSpecForScaling(context.Background()) { if metricSpec.External != nil { expectedExternalMetricNames = append(expectedExternalMetricNames, metricSpec.External.Metric.Name) } @@ -108,12 +108,12 @@ var _ = Describe("ScaledObjectController", func() { } // Set up expectations - mockScaleHandler.EXPECT().GetScalers(uniquelyNamedScaledObject).Return(testScalers, nil) + mockScaleHandler.EXPECT().GetScalers(context.Background(), uniquelyNamedScaledObject).Return(testScalers, nil) mockClient.EXPECT().Status().Return(mockStatusWriter) mockStatusWriter.EXPECT().Patch(gomock.Any(), gomock.Any(), gomock.Any()) // Call function to be tested - metricSpecs, err := metricNameTestReconciler.getScaledObjectMetricSpecs(testLogger, uniquelyNamedScaledObject) + metricSpecs, err := metricNameTestReconciler.getScaledObjectMetricSpecs(context.Background(), testLogger, uniquelyNamedScaledObject) // Test that the status was updated with metric names Ω(uniquelyNamedScaledObject.Status.ExternalMetricNames).Should(Equal(expectedExternalMetricNames)) @@ -139,19 +139,19 @@ var _ = Describe("ScaledObjectController", func() { if err != nil { Fail(err.Error()) } - for _, metricSpec := range s.GetMetricSpecForScaling() { + for _, metricSpec := range s.GetMetricSpecForScaling(context.Background()) { if metricSpec.External != nil { expectedExternalMetricNames = append(expectedExternalMetricNames, metricSpec.External.Metric.Name) } } // Set up expectations - mockScaleHandler.EXPECT().GetScalers(uniquelyNamedScaledObject).Return([]scalers.Scaler{s}, nil) + mockScaleHandler.EXPECT().GetScalers(context.Background(), uniquelyNamedScaledObject).Return([]scalers.Scaler{s}, nil) mockClient.EXPECT().Status().Return(mockStatusWriter) mockStatusWriter.EXPECT().Patch(gomock.Any(), gomock.Any(), gomock.Any()) // Call function to be tested - metricSpecs, err := metricNameTestReconciler.getScaledObjectMetricSpecs(testLogger, uniquelyNamedScaledObject) + metricSpecs, err := metricNameTestReconciler.getScaledObjectMetricSpecs(context.Background(), testLogger, uniquelyNamedScaledObject) // Test that the status was updated Ω(uniquelyNamedScaledObject.Status.ExternalMetricNames).Should(Equal(expectedExternalMetricNames)) @@ -186,10 +186,10 @@ var _ = Describe("ScaledObjectController", func() { } // Set up expectations - mockScaleHandler.EXPECT().GetScalers(duplicateNamedScaledObject).Return(testScalers, nil) + mockScaleHandler.EXPECT().GetScalers(context.Background(), duplicateNamedScaledObject).Return(testScalers, nil) // Call function tobe tested - metricSpecs, err := metricNameTestReconciler.getScaledObjectMetricSpecs(testLogger, duplicateNamedScaledObject) + metricSpecs, err := metricNameTestReconciler.getScaledObjectMetricSpecs(context.Background(), testLogger, duplicateNamedScaledObject) // Test that the status was not updated Ω(duplicateNamedScaledObject.Status.ExternalMetricNames).Should(BeNil()) diff --git a/pkg/mock/mock_client/mock_interfaces.go b/pkg/mock/mock_client/mock_interfaces.go index 16e08888025..b33495dcdf2 100644 --- a/pkg/mock/mock_client/mock_interfaces.go +++ b/pkg/mock/mock_client/mock_interfaces.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /go/pkg/mod/sigs.k8s.io/controller-runtime@v0.10.0/pkg/client/interfaces.go +// Source: /go/pkg/mod/sigs.k8s.io/controller-runtime@v0.10.2/pkg/client/interfaces.go // Package mock_client is a generated GoMock package. package mock_client diff --git a/pkg/mock/mock_scale/mock_interfaces.go b/pkg/mock/mock_scale/mock_interfaces.go index cdfe5255add..4cb4a636375 100644 --- a/pkg/mock/mock_scale/mock_interfaces.go +++ b/pkg/mock/mock_scale/mock_interfaces.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /go/pkg/mod/k8s.io/client-go@v0.22.1/scale/interfaces.go +// Source: /home/ecomaz/go/pkg/mod/k8s.io/client-go@v0.22.2/scale/interfaces.go // Package mock_scale is a generated GoMock package. package mock_scale diff --git a/pkg/mock/mock_scaler/mock_scaler.go b/pkg/mock/mock_scaler/mock_scaler.go index 917d39b3cdc..cfee6c54623 100644 --- a/pkg/mock/mock_scaler/mock_scaler.go +++ b/pkg/mock/mock_scaler/mock_scaler.go @@ -38,31 +38,31 @@ func (m *MockScaler) EXPECT() *MockScalerMockRecorder { } // Close mocks base method. -func (m *MockScaler) Close() error { +func (m *MockScaler) Close(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") + ret := m.ctrl.Call(m, "Close", ctx) ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. -func (mr *MockScalerMockRecorder) Close() *gomock.Call { +func (mr *MockScalerMockRecorder) Close(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockScaler)(nil).Close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockScaler)(nil).Close), ctx) } // GetMetricSpecForScaling mocks base method. -func (m *MockScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (m *MockScaler) GetMetricSpecForScaling(ctx context.Context) []v2beta2.MetricSpec { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMetricSpecForScaling") + ret := m.ctrl.Call(m, "GetMetricSpecForScaling", ctx) ret0, _ := ret[0].([]v2beta2.MetricSpec) return ret0 } // GetMetricSpecForScaling indicates an expected call of GetMetricSpecForScaling. -func (mr *MockScalerMockRecorder) GetMetricSpecForScaling() *gomock.Call { +func (mr *MockScalerMockRecorder) GetMetricSpecForScaling(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetricSpecForScaling", reflect.TypeOf((*MockScaler)(nil).GetMetricSpecForScaling)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetricSpecForScaling", reflect.TypeOf((*MockScaler)(nil).GetMetricSpecForScaling), ctx) } // GetMetrics mocks base method. @@ -119,31 +119,31 @@ func (m *MockPushScaler) EXPECT() *MockPushScalerMockRecorder { } // Close mocks base method. -func (m *MockPushScaler) Close() error { +func (m *MockPushScaler) Close(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") + ret := m.ctrl.Call(m, "Close", ctx) ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. -func (mr *MockPushScalerMockRecorder) Close() *gomock.Call { +func (mr *MockPushScalerMockRecorder) Close(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPushScaler)(nil).Close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPushScaler)(nil).Close), ctx) } // GetMetricSpecForScaling mocks base method. -func (m *MockPushScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (m *MockPushScaler) GetMetricSpecForScaling(ctx context.Context) []v2beta2.MetricSpec { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMetricSpecForScaling") + ret := m.ctrl.Call(m, "GetMetricSpecForScaling", ctx) ret0, _ := ret[0].([]v2beta2.MetricSpec) return ret0 } // GetMetricSpecForScaling indicates an expected call of GetMetricSpecForScaling. -func (mr *MockPushScalerMockRecorder) GetMetricSpecForScaling() *gomock.Call { +func (mr *MockPushScalerMockRecorder) GetMetricSpecForScaling(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetricSpecForScaling", reflect.TypeOf((*MockPushScaler)(nil).GetMetricSpecForScaling)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetricSpecForScaling", reflect.TypeOf((*MockPushScaler)(nil).GetMetricSpecForScaling), ctx) } // GetMetrics mocks base method. diff --git a/pkg/mock/mock_scaling/mock_interface.go b/pkg/mock/mock_scaling/mock_interface.go index 52cadd3266b..f398a17fc12 100644 --- a/pkg/mock/mock_scaling/mock_interface.go +++ b/pkg/mock/mock_scaling/mock_interface.go @@ -5,6 +5,7 @@ package mock_scaling import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -49,18 +50,18 @@ func (mr *MockScaleHandlerMockRecorder) DeleteScalableObject(scalableObject inte } // GetScalers mocks base method. -func (m *MockScaleHandler) GetScalers(scalableObject interface{}) ([]scalers.Scaler, error) { +func (m *MockScaleHandler) GetScalers(ctx context.Context, scalableObject interface{}) ([]scalers.Scaler, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetScalers", scalableObject) + ret := m.ctrl.Call(m, "GetScalers", ctx, scalableObject) ret0, _ := ret[0].([]scalers.Scaler) ret1, _ := ret[1].(error) return ret0, ret1 } // GetScalers indicates an expected call of GetScalers. -func (mr *MockScaleHandlerMockRecorder) GetScalers(scalableObject interface{}) *gomock.Call { +func (mr *MockScaleHandlerMockRecorder) GetScalers(ctx, scalableObject interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScalers", reflect.TypeOf((*MockScaleHandler)(nil).GetScalers), scalableObject) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScalers", reflect.TypeOf((*MockScaleHandler)(nil).GetScalers), ctx, scalableObject) } // HandleScalableObject mocks base method. diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index cdb130ca87a..2b2b663df09 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -93,14 +93,14 @@ func (p *KedaProvider) GetExternalMetric(ctx context.Context, namespace string, scaledObject := &scaledObjects.Items[0] matchingMetrics := []external_metrics.ExternalMetricValue{} - scalers, err := p.scaleHandler.GetScalers(scaledObject) + scalers, err := p.scaleHandler.GetScalers(ctx, scaledObject) metricsServer.RecordScalerObjectError(scaledObject.Namespace, scaledObject.Name, err) if err != nil { return nil, fmt.Errorf("error when getting scalers %s", err) } for scalerIndex, scaler := range scalers { - metricSpecs := scaler.GetMetricSpecForScaling() + metricSpecs := scaler.GetMetricSpecForScaling(ctx) scalerName := strings.Replace(fmt.Sprintf("%T", scaler), "*scalers.", "", 1) for _, metricSpec := range metricSpecs { @@ -124,7 +124,7 @@ func (p *KedaProvider) GetExternalMetric(ctx context.Context, namespace string, metricsServer.RecordHPAScalerError(namespace, scaledObject.Name, scalerName, scalerIndex, info.Metric, err) } } - scaler.Close() + scaler.Close(ctx) } if len(matchingMetrics) == 0 { diff --git a/pkg/scalers/artemis_scaler.go b/pkg/scalers/artemis_scaler.go index a9c1fc7c16e..78283457818 100644 --- a/pkg/scalers/artemis_scaler.go +++ b/pkg/scalers/artemis_scaler.go @@ -162,7 +162,7 @@ func parseArtemisMetadata(config *ScalerConfig) (*artemisMetadata, error) { // IsActive determines if we need to scale from zero func (s *artemisScaler) IsActive(ctx context.Context) (bool, error) { - messages, err := s.getQueueMessageCount() + messages, err := s.getQueueMessageCount(ctx) if err != nil { artemisLog.Error(err, "Unable to access the artemis management endpoint", "managementEndpoint", s.metadata.managementEndpoint) return false, err @@ -214,14 +214,14 @@ func (s *artemisScaler) getMonitoringEndpoint() string { return monitoringEndpoint } -func (s *artemisScaler) getQueueMessageCount() (int, error) { +func (s *artemisScaler) getQueueMessageCount(ctx context.Context) (int, error) { var monitoringInfo *artemisMonitoring messageCount := 0 client := s.httpClient url := s.getMonitoringEndpoint() - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) req.SetBasicAuth(s.metadata.username, s.metadata.password) req.Header.Set("Origin", s.metadata.corsHeader) @@ -250,7 +250,7 @@ func (s *artemisScaler) getQueueMessageCount() (int, error) { return messageCount, nil } -func (s *artemisScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *artemisScaler) GetMetricSpecForScaling(ctx context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(int64(s.metadata.queueLength), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -267,7 +267,7 @@ func (s *artemisScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *artemisScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - messages, err := s.getQueueMessageCount() + messages, err := s.getQueueMessageCount(ctx) if err != nil { artemisLog.Error(err, "Unable to access the artemis management endpoint", "managementEndpoint", s.metadata.managementEndpoint) @@ -284,6 +284,6 @@ func (s *artemisScaler) GetMetrics(ctx context.Context, metricName string, metri } // Nothing to close here. -func (s *artemisScaler) Close() error { +func (s *artemisScaler) Close(context.Context) error { return nil } diff --git a/pkg/scalers/artemis_scaler_test.go b/pkg/scalers/artemis_scaler_test.go index 7f127ddd6e2..6f5e3533e84 100644 --- a/pkg/scalers/artemis_scaler_test.go +++ b/pkg/scalers/artemis_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "net/http" "testing" ) @@ -143,6 +144,7 @@ func TestArtemisParseMetadata(t *testing.T) { func TestArtemisGetMetricSpecForScaling(t *testing.T) { for _, testData := range artemisMetricIdentifiers { + ctx := context.Background() meta, err := parseArtemisMetadata(&ScalerConfig{ResolvedEnv: sampleArtemisResolvedEnv, TriggerMetadata: testData.metadataTestData.metadata, AuthParams: nil, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) @@ -152,7 +154,7 @@ func TestArtemisGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockArtemisScaler.GetMetricSpecForScaling() + metricSpec := mockArtemisScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/aws_cloudwatch_scaler.go b/pkg/scalers/aws_cloudwatch_scaler.go index 3bba2e2ae9f..5f1d9fc3da7 100644 --- a/pkg/scalers/aws_cloudwatch_scaler.go +++ b/pkg/scalers/aws_cloudwatch_scaler.go @@ -213,7 +213,7 @@ func (c *awsCloudwatchScaler) GetMetrics(ctx context.Context, metricName string, return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func (c *awsCloudwatchScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (c *awsCloudwatchScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(int64(c.metadata.targetMetricValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -238,7 +238,7 @@ func (c *awsCloudwatchScaler) IsActive(ctx context.Context) (bool, error) { return val > c.metadata.minMetricValue, nil } -func (c *awsCloudwatchScaler) Close() error { +func (c *awsCloudwatchScaler) Close(context.Context) error { return nil } diff --git a/pkg/scalers/aws_cloudwatch_test.go b/pkg/scalers/aws_cloudwatch_test.go index d39e57977ff..2bfb1f2adaa 100644 --- a/pkg/scalers/aws_cloudwatch_test.go +++ b/pkg/scalers/aws_cloudwatch_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" ) @@ -252,13 +253,14 @@ func TestCloudwatchParseMetadata(t *testing.T) { func TestAWSCloudwatchGetMetricSpecForScaling(t *testing.T) { for _, testData := range awsCloudwatchMetricIdentifiers { + ctx := context.Background() meta, err := parseAwsCloudwatchMetadata(&ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAWSCloudwatchResolvedEnv, AuthParams: testData.metadataTestData.authParams, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) } mockAWSCloudwatchScaler := awsCloudwatchScaler{meta} - metricSpec := mockAWSCloudwatchScaler.GetMetricSpecForScaling() + metricSpec := mockAWSCloudwatchScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/aws_kinesis_stream_scaler.go b/pkg/scalers/aws_kinesis_stream_scaler.go index 4bd3482c40a..3aa82a20f21 100644 --- a/pkg/scalers/aws_kinesis_stream_scaler.go +++ b/pkg/scalers/aws_kinesis_stream_scaler.go @@ -100,11 +100,11 @@ func (s *awsKinesisStreamScaler) IsActive(ctx context.Context) (bool, error) { return count > 0, nil } -func (s *awsKinesisStreamScaler) Close() error { +func (s *awsKinesisStreamScaler) Close(context.Context) error { return nil } -func (s *awsKinesisStreamScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *awsKinesisStreamScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetShardCountQty := resource.NewQuantity(int64(s.metadata.targetShardCount), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/aws_kinesis_stream_test.go b/pkg/scalers/aws_kinesis_stream_test.go index 232d0d2755f..84a02498f34 100644 --- a/pkg/scalers/aws_kinesis_stream_test.go +++ b/pkg/scalers/aws_kinesis_stream_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "reflect" "testing" ) @@ -217,13 +218,14 @@ func TestKinesisParseMetadata(t *testing.T) { func TestAWSKinesisGetMetricSpecForScaling(t *testing.T) { for _, testData := range awsKinesisMetricIdentifiers { + ctx := context.Background() meta, err := parseAwsKinesisStreamMetadata(&ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAWSKinesisAuthentication, AuthParams: testData.metadataTestData.authParams, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) } mockAWSKinesisStreamScaler := awsKinesisStreamScaler{meta} - metricSpec := mockAWSKinesisStreamScaler.GetMetricSpecForScaling() + metricSpec := mockAWSKinesisStreamScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/aws_sqs_queue_scaler.go b/pkg/scalers/aws_sqs_queue_scaler.go index 2f4bc31d129..08e0747cd2b 100644 --- a/pkg/scalers/aws_sqs_queue_scaler.go +++ b/pkg/scalers/aws_sqs_queue_scaler.go @@ -122,11 +122,11 @@ func (s *awsSqsQueueScaler) IsActive(ctx context.Context) (bool, error) { return length > 0, nil } -func (s *awsSqsQueueScaler) Close() error { +func (s *awsSqsQueueScaler) Close(context.Context) error { return nil } -func (s *awsSqsQueueScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *awsSqsQueueScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueueLengthQty := resource.NewQuantity(int64(s.metadata.targetQueueLength), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/aws_sqs_queue_test.go b/pkg/scalers/aws_sqs_queue_test.go index 1fd34870fda..726ee357bc9 100644 --- a/pkg/scalers/aws_sqs_queue_test.go +++ b/pkg/scalers/aws_sqs_queue_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" ) @@ -150,13 +151,14 @@ func TestSQSParseMetadata(t *testing.T) { func TestAWSSQSGetMetricSpecForScaling(t *testing.T) { for _, testData := range awsSQSMetricIdentifiers { + ctx := context.Background() meta, err := parseAwsSqsQueueMetadata(&ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAWSSQSAuthentication, AuthParams: testData.metadataTestData.authParams, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) } mockAWSSQSScaler := awsSqsQueueScaler{meta} - metricSpec := mockAWSSQSScaler.GetMetricSpecForScaling() + metricSpec := mockAWSSQSScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure/azure_aad_podidentity.go b/pkg/scalers/azure/azure_aad_podidentity.go index ffee9159403..8a4f2566b8b 100644 --- a/pkg/scalers/azure/azure_aad_podidentity.go +++ b/pkg/scalers/azure/azure_aad_podidentity.go @@ -1,6 +1,7 @@ package azure import ( + "context" "encoding/json" "errors" "fmt" @@ -16,11 +17,11 @@ const ( ) // GetAzureADPodIdentityToken returns the AADToken for resource -func GetAzureADPodIdentityToken(httpClient util.HTTPDoer, audience string) (AADToken, error) { +func GetAzureADPodIdentityToken(ctx context.Context, httpClient util.HTTPDoer, audience string) (AADToken, error) { var token AADToken urlStr := fmt.Sprintf(msiURL, url.QueryEscape(audience)) - req, err := http.NewRequest("GET", urlStr, nil) + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { return token, err } diff --git a/pkg/scalers/azure/azure_blob.go b/pkg/scalers/azure/azure_blob.go index fdded6fe8a4..0902afe35c2 100644 --- a/pkg/scalers/azure/azure_blob.go +++ b/pkg/scalers/azure/azure_blob.go @@ -27,7 +27,7 @@ import ( // GetAzureBlobListLength returns the count of the blobs in blob container in int func GetAzureBlobListLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string, endpointSuffix string) (int, error) { - credential, endpoint, err := ParseAzureStorageBlobConnection(httpClient, podIdentity, connectionString, accountName, endpointSuffix) + credential, endpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, podIdentity, connectionString, accountName, endpointSuffix) if err != nil { return -1, err } diff --git a/pkg/scalers/azure/azure_eventhub_checkpoint.go b/pkg/scalers/azure/azure_eventhub_checkpoint.go index a6a60e8e88c..3fccadf938d 100644 --- a/pkg/scalers/azure/azure_eventhub_checkpoint.go +++ b/pkg/scalers/azure/azure_eventhub_checkpoint.go @@ -214,7 +214,7 @@ func (checkpointer *defaultCheckpointer) extractCheckpoint(get *azblob.DownloadR } func getCheckpoint(ctx context.Context, httpClient util.HTTPDoer, info EventHubInfo, checkpointer checkpointer) (Checkpoint, error) { - blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "", "") + blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "", "") if err != nil { return Checkpoint{}, err } diff --git a/pkg/scalers/azure/azure_eventhub_test.go b/pkg/scalers/azure/azure_eventhub_test.go index fe242e04e70..c3e7cf9950a 100644 --- a/pkg/scalers/azure/azure_eventhub_test.go +++ b/pkg/scalers/azure/azure_eventhub_test.go @@ -292,10 +292,11 @@ func TestShouldParseCheckpointForGoSdk(t *testing.T) { } func createNewCheckpointInStorage(urlPath string, containerName string, partitionID string, checkpoint string, metadata map[string]string) (context.Context, error) { - credential, endpoint, _ := ParseAzureStorageBlobConnection(http.DefaultClient, "none", StorageConnectionString, "", "") + ctx := context.Background() + + credential, endpoint, _ := ParseAzureStorageBlobConnection(ctx, http.DefaultClient, "none", StorageConnectionString, "", "") // Create container - ctx := context.Background() path, _ := url.Parse(containerName) url := endpoint.ResolveReference(path) containerURL := azblob.NewContainerURL(*url, azblob.NewPipeline(credential, azblob.PipelineOptions{})) diff --git a/pkg/scalers/azure/azure_queue.go b/pkg/scalers/azure/azure_queue.go index 6c89a7f7c8d..df549648165 100644 --- a/pkg/scalers/azure/azure_queue.go +++ b/pkg/scalers/azure/azure_queue.go @@ -27,7 +27,7 @@ import ( // GetAzureQueueLength returns the length of a queue in int func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName, accountName, endpointSuffix string) (int32, error) { - credential, endpoint, err := ParseAzureStorageQueueConnection(httpClient, podIdentity, connectionString, accountName, endpointSuffix) + credential, endpoint, err := ParseAzureStorageQueueConnection(ctx, httpClient, podIdentity, connectionString, accountName, endpointSuffix) if err != nil { return -1, err } @@ -40,7 +40,7 @@ func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdent return -1, err } - visibleMessageCount, err := getVisibleCount(&queueURL, 32) + visibleMessageCount, err := getVisibleCount(ctx, &queueURL, 32) if err != nil { return -1, err } @@ -53,9 +53,8 @@ func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdent return visibleMessageCount, nil } -func getVisibleCount(queueURL *azqueue.QueueURL, maxCount int32) (int32, error) { +func getVisibleCount(ctx context.Context, queueURL *azqueue.QueueURL, maxCount int32) (int32, error) { messagesURL := queueURL.NewMessagesURL() - ctx := context.Background() queue, err := messagesURL.Peek(ctx, maxCount) if err != nil { return 0, err diff --git a/pkg/scalers/azure/azure_storage.go b/pkg/scalers/azure/azure_storage.go index 4af2b97c3bb..d71ef08f1cc 100644 --- a/pkg/scalers/azure/azure_storage.go +++ b/pkg/scalers/azure/azure_storage.go @@ -17,6 +17,7 @@ limitations under the License. package azure import ( + "context" "errors" "fmt" "net/url" @@ -77,10 +78,10 @@ func ParseAzureStorageEndpointSuffix(metadata map[string]string, endpointType St } // ParseAzureStorageQueueConnection parses queue connection string and returns credential and resource url -func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) { +func ParseAzureStorageQueueConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) { switch podIdentity { case kedav1alpha1.PodIdentityProviderAzure: - token, endpoint, err := parseAcessTokenAndEndpoint(httpClient, accountName, endpointSuffix) + token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix) if err != nil { return nil, nil, err } @@ -105,10 +106,10 @@ func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity keda } // ParseAzureStorageBlobConnection parses blob connection string and returns credential and resource url -func ParseAzureStorageBlobConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) { +func ParseAzureStorageBlobConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) { switch podIdentity { case kedav1alpha1.PodIdentityProviderAzure: - token, endpoint, err := parseAcessTokenAndEndpoint(httpClient, accountName, endpointSuffix) + token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix) if err != nil { return nil, nil, err } @@ -189,9 +190,9 @@ func parseAzureStorageConnectionString(connectionString string, endpointType Sto return u, name, key, nil } -func parseAcessTokenAndEndpoint(httpClient util.HTTPDoer, accountName string, endpointSuffix string) (string, *url.URL, error) { +func parseAcessTokenAndEndpoint(ctx context.Context, httpClient util.HTTPDoer, accountName string, endpointSuffix string) (string, *url.URL, error) { // Azure storage resource is "https://storage.azure.com/" in all cloud environments - token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/") + token, err := GetAzureADPodIdentityToken(ctx, httpClient, "https://storage.azure.com/") if err != nil { return "", nil, err } diff --git a/pkg/scalers/azure_blob_scaler.go b/pkg/scalers/azure_blob_scaler.go index 764cb0c7b67..11efc79712f 100644 --- a/pkg/scalers/azure_blob_scaler.go +++ b/pkg/scalers/azure_blob_scaler.go @@ -180,11 +180,11 @@ func (s *azureBlobScaler) IsActive(ctx context.Context) (bool, error) { return length > 0, nil } -func (s *azureBlobScaler) Close() error { +func (s *azureBlobScaler) Close(context.Context) error { return nil } -func (s *azureBlobScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *azureBlobScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetBlobCount := resource.NewQuantity(int64(s.metadata.targetBlobCount), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/azure_blob_scaler_test.go b/pkg/scalers/azure_blob_scaler_test.go index 2b4a96613c2..ebe489c3aef 100644 --- a/pkg/scalers/azure_blob_scaler_test.go +++ b/pkg/scalers/azure_blob_scaler_test.go @@ -17,6 +17,7 @@ limitations under the License. package scalers import ( + "context" "net/http" "testing" @@ -95,6 +96,7 @@ func TestAzBlobParseMetadata(t *testing.T) { func TestAzBlobGetMetricSpecForScaling(t *testing.T) { for _, testData := range azBlobMetricIdentifiers { + ctx := context.Background() meta, podIdentity, err := parseAzureBlobMetadata(&ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, PodIdentity: testData.metadataTestData.podIdentity, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) @@ -105,7 +107,7 @@ func TestAzBlobGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockAzBlobScaler.GetMetricSpecForScaling() + metricSpec := mockAzBlobScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure_eventhub_scaler.go b/pkg/scalers/azure_eventhub_scaler.go index 90633b7071a..566e8ee96d1 100644 --- a/pkg/scalers/azure_eventhub_scaler.go +++ b/pkg/scalers/azure_eventhub_scaler.go @@ -251,7 +251,7 @@ func (scaler *azureEventHubScaler) IsActive(ctx context.Context) (bool, error) { } // GetMetricSpecForScaling returns metric spec -func (scaler *azureEventHubScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (scaler *azureEventHubScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricVal := resource.NewQuantity(scaler.metadata.threshold, resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -319,7 +319,7 @@ func getTotalLagRelatedToPartitionAmount(unprocessedEventsCount int64, partition } // Close closes Azure Event Hub Scaler -func (scaler *azureEventHubScaler) Close() error { +func (scaler *azureEventHubScaler) Close(context.Context) error { if scaler.client != nil { err := scaler.client.Close(context.TODO()) if err != nil { diff --git a/pkg/scalers/azure_eventhub_scaler_test.go b/pkg/scalers/azure_eventhub_scaler_test.go index a7e0748007e..de76a33af5c 100644 --- a/pkg/scalers/azure_eventhub_scaler_test.go +++ b/pkg/scalers/azure_eventhub_scaler_test.go @@ -105,6 +105,7 @@ func TestParseEventHubMetadata(t *testing.T) { } func TestGetUnprocessedEventCountInPartition(t *testing.T) { + ctx := context.Background() t.Log("This test will use the environment variable EVENTHUB_CONNECTION_STRING and STORAGE_CONNECTION_STRING if it is set.") t.Log("If set, it will connect to the storage account and event hub to determine how many messages are in the event hub.") t.Logf("EventHub has 1 message in partition 0 and 0 messages in partition 1") @@ -114,7 +115,7 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) { if eventHubKey != "" && storageConnectionString != "" { eventHubConnectionString := fmt.Sprintf("Endpoint=sb://%s.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=%s;EntityPath=%s", testEventHubNamespace, eventHubKey, testEventHubName) - storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(http.DefaultClient, "none", storageConnectionString, "", "") + storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(ctx, http.DefaultClient, "none", storageConnectionString, "", "") if err != nil { t.Error(err) t.FailNow() @@ -447,7 +448,7 @@ func TestEventHubGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockEventHubScaler.GetMetricSpecForScaling() + metricSpec := mockEventHubScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure_log_analytics_scaler.go b/pkg/scalers/azure_log_analytics_scaler.go index 6d69067be62..91ee3a160c0 100644 --- a/pkg/scalers/azure_log_analytics_scaler.go +++ b/pkg/scalers/azure_log_analytics_scaler.go @@ -206,7 +206,7 @@ func getParameterFromConfig(config *ScalerConfig, parameter string, checkAuthPar // IsActive determines if we need to scale from zero func (s *azureLogAnalyticsScaler) IsActive(ctx context.Context) (bool, error) { - err := s.updateCache() + err := s.updateCache(ctx) if err != nil { return false, fmt.Errorf("failed to execute IsActive function. Scaled object: %s. Namespace: %s. Inner Error: %v", s.name, s.namespace, err) @@ -215,8 +215,8 @@ func (s *azureLogAnalyticsScaler) IsActive(ctx context.Context) (bool, error) { return s.cache.metricValue > 0, nil } -func (s *azureLogAnalyticsScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { - err := s.updateCache() +func (s *azureLogAnalyticsScaler) GetMetricSpecForScaling(ctx context.Context) []v2beta2.MetricSpec { + err := s.updateCache(ctx) if err != nil { logAnalyticsLog.V(1).Info("failed to get metric spec.", "Scaled object", s.name, "Namespace", s.namespace, "Inner Error", err) @@ -238,7 +238,7 @@ func (s *azureLogAnalyticsScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *azureLogAnalyticsScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - receivedMetric, err := s.getMetricData() + receivedMetric, err := s.getMetricData(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("failed to get metrics. Scaled object: %s. Namespace: %s. Inner Error: %v", s.name, s.namespace, err) @@ -253,13 +253,13 @@ func (s *azureLogAnalyticsScaler) GetMetrics(ctx context.Context, metricName str return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func (s *azureLogAnalyticsScaler) Close() error { +func (s *azureLogAnalyticsScaler) Close(context.Context) error { return nil } -func (s *azureLogAnalyticsScaler) updateCache() error { +func (s *azureLogAnalyticsScaler) updateCache(ctx context.Context) error { if s.cache.metricValue < 0 { - receivedMetric, err := s.getMetricData() + receivedMetric, err := s.getMetricData(ctx) if err != nil { return err @@ -277,13 +277,13 @@ func (s *azureLogAnalyticsScaler) updateCache() error { return nil } -func (s *azureLogAnalyticsScaler) getMetricData() (metricsData, error) { - tokenInfo, err := s.getAccessToken() +func (s *azureLogAnalyticsScaler) getMetricData(ctx context.Context) (metricsData, error) { + tokenInfo, err := s.getAccessToken(ctx) if err != nil { return metricsData{}, err } - metricsInfo, err := s.executeQuery(s.metadata.query, tokenInfo) + metricsInfo, err := s.executeQuery(ctx, s.metadata.query, tokenInfo) if err != nil { return metricsData{}, err } @@ -293,7 +293,7 @@ func (s *azureLogAnalyticsScaler) getMetricData() (metricsData, error) { return metricsInfo, nil } -func (s *azureLogAnalyticsScaler) getAccessToken() (tokenData, error) { +func (s *azureLogAnalyticsScaler) getAccessToken(ctx context.Context) (tokenData, error) { // if there is no token yet or it will be expired in less, that 30 secs currentTimeSec := time.Now().Unix() tokenInfo := tokenData{} @@ -305,7 +305,7 @@ func (s *azureLogAnalyticsScaler) getAccessToken() (tokenData, error) { } if currentTimeSec+30 > tokenInfo.ExpiresOn { - newTokenInfo, err := s.refreshAccessToken() + newTokenInfo, err := s.refreshAccessToken(ctx) if err != nil { return tokenData{}, err } @@ -323,17 +323,17 @@ func (s *azureLogAnalyticsScaler) getAccessToken() (tokenData, error) { return tokenInfo, nil } -func (s *azureLogAnalyticsScaler) executeQuery(query string, tokenInfo tokenData) (metricsData, error) { +func (s *azureLogAnalyticsScaler) executeQuery(ctx context.Context, query string, tokenInfo tokenData) (metricsData, error) { queryData := queryResult{} var body []byte var statusCode int var err error - body, statusCode, err = s.executeLogAnalyticsREST(query, tokenInfo) + body, statusCode, err = s.executeLogAnalyticsREST(ctx, query, tokenInfo) // Handle expired token if statusCode == 403 || (len(body) > 0 && strings.Contains(string(body), "TokenExpired")) { - tokenInfo, err = s.refreshAccessToken() + tokenInfo, err = s.refreshAccessToken(ctx) if err != nil { return metricsData{}, err } @@ -347,7 +347,7 @@ func (s *azureLogAnalyticsScaler) executeQuery(query string, tokenInfo tokenData } if err == nil { - body, statusCode, err = s.executeLogAnalyticsREST(query, tokenInfo) + body, statusCode, err = s.executeLogAnalyticsREST(ctx, query, tokenInfo) } else { return metricsData{}, err } @@ -431,8 +431,8 @@ func parseTableValueToInt64(value interface{}, dataType string) (int64, error) { return 0, fmt.Errorf("error validating Log Analytics request. Details: value is empty, check your query") } -func (s *azureLogAnalyticsScaler) refreshAccessToken() (tokenData, error) { - tokenInfo, err := s.getAuthorizationToken() +func (s *azureLogAnalyticsScaler) refreshAccessToken(ctx context.Context) (tokenData, error) { + tokenInfo, err := s.getAuthorizationToken(ctx) if err != nil { return tokenData{}, err @@ -453,16 +453,16 @@ func (s *azureLogAnalyticsScaler) refreshAccessToken() (tokenData, error) { return tokenInfo, nil } -func (s *azureLogAnalyticsScaler) getAuthorizationToken() (tokenData, error) { +func (s *azureLogAnalyticsScaler) getAuthorizationToken(ctx context.Context) (tokenData, error) { var body []byte var statusCode int var err error var tokenInfo tokenData if s.metadata.podIdentity == "" { - body, statusCode, err = s.executeAADApicall() + body, statusCode, err = s.executeAADApicall(ctx) } else { - body, statusCode, err = s.executeIMDSApicall() + body, statusCode, err = s.executeIMDSApicall(ctx) } if err != nil { @@ -483,7 +483,7 @@ func (s *azureLogAnalyticsScaler) getAuthorizationToken() (tokenData, error) { return tokenData{}, fmt.Errorf("error getting access token. Details: unknown error. HTTP code: %d. Body: %s", statusCode, string(body)) } -func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(query string, tokenInfo tokenData) ([]byte, int, error) { +func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(ctx context.Context, query string, tokenInfo tokenData) ([]byte, int, error) { m := map[string]interface{}{"query": query} jsonBytes, err := json.Marshal(m) @@ -491,7 +491,7 @@ func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(query string, tokenInf return nil, 0, fmt.Errorf("can't construct JSON for request to Log Analytics API. Inner Error: %v", err) } - request, err := http.NewRequest(http.MethodPost, fmt.Sprintf(laQueryEndpoint, s.metadata.workspaceID), bytes.NewBuffer(jsonBytes)) // URL-encoded payload + request, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf(laQueryEndpoint, s.metadata.workspaceID), bytes.NewBuffer(jsonBytes)) // URL-encoded payload if err != nil { return nil, 0, fmt.Errorf("can't construct HTTP request to Log Analytics API. Inner Error: %v", err) } @@ -503,7 +503,7 @@ func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(query string, tokenInf return s.runHTTP(request, "Log Analytics REST api") } -func (s *azureLogAnalyticsScaler) executeAADApicall() ([]byte, int, error) { +func (s *azureLogAnalyticsScaler) executeAADApicall(ctx context.Context) ([]byte, int, error) { data := url.Values{ "grant_type": {"client_credentials"}, "client_id": {s.metadata.clientID}, @@ -512,7 +512,7 @@ func (s *azureLogAnalyticsScaler) executeAADApicall() ([]byte, int, error) { "client_secret": {s.metadata.clientSecret}, } - request, err := http.NewRequest(http.MethodPost, fmt.Sprintf(aadTokenEndpoint, s.metadata.tenantID), strings.NewReader(data.Encode())) // URL-encoded payload + request, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf(aadTokenEndpoint, s.metadata.tenantID), strings.NewReader(data.Encode())) // URL-encoded payload if err != nil { return nil, 0, fmt.Errorf("can't construct HTTP request to Azure Active Directory. Inner Error: %v", err) } @@ -523,8 +523,8 @@ func (s *azureLogAnalyticsScaler) executeAADApicall() ([]byte, int, error) { return s.runHTTP(request, "AAD") } -func (s *azureLogAnalyticsScaler) executeIMDSApicall() ([]byte, int, error) { - request, err := http.NewRequest(http.MethodGet, miEndpoint, nil) +func (s *azureLogAnalyticsScaler) executeIMDSApicall(ctx context.Context) ([]byte, int, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, miEndpoint, nil) if err != nil { return nil, 0, fmt.Errorf("can't construct HTTP request to Azure Instance Metadata service. Inner Error: %v", err) } diff --git a/pkg/scalers/azure_log_analytics_scaler_test.go b/pkg/scalers/azure_log_analytics_scaler_test.go index bb93d7af4da..090ef6ae795 100644 --- a/pkg/scalers/azure_log_analytics_scaler_test.go +++ b/pkg/scalers/azure_log_analytics_scaler_test.go @@ -17,6 +17,7 @@ limitations under the License. package scalers import ( + "context" "net/http" "testing" @@ -174,7 +175,7 @@ func TestLogAnalyticsGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockLogAnalyticsScaler.GetMetricSpecForScaling() + metricSpec := mockLogAnalyticsScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure_monitor_scaler.go b/pkg/scalers/azure_monitor_scaler.go index 2f75c12edcd..9d9fd313aeb 100644 --- a/pkg/scalers/azure_monitor_scaler.go +++ b/pkg/scalers/azure_monitor_scaler.go @@ -186,11 +186,11 @@ func (s *azureMonitorScaler) IsActive(ctx context.Context) (bool, error) { return val > 0, nil } -func (s *azureMonitorScaler) Close() error { +func (s *azureMonitorScaler) Close(context.Context) error { return nil } -func (s *azureMonitorScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *azureMonitorScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricVal := resource.NewQuantity(int64(s.metadata.targetValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/azure_monitor_scaler_test.go b/pkg/scalers/azure_monitor_scaler_test.go index 8c20e532370..c7a9f42fe48 100644 --- a/pkg/scalers/azure_monitor_scaler_test.go +++ b/pkg/scalers/azure_monitor_scaler_test.go @@ -17,6 +17,7 @@ limitations under the License. package scalers import ( + "context" "testing" kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1" @@ -105,7 +106,7 @@ func TestAzMonitorGetMetricSpecForScaling(t *testing.T) { } mockAzMonitorScaler := azureMonitorScaler{meta, testData.metadataTestData.podIdentity} - metricSpec := mockAzMonitorScaler.GetMetricSpecForScaling() + metricSpec := mockAzMonitorScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure_pipelines_scaler.go b/pkg/scalers/azure_pipelines_scaler.go index 7700c54b1e0..4b78ea9467f 100644 --- a/pkg/scalers/azure_pipelines_scaler.go +++ b/pkg/scalers/azure_pipelines_scaler.go @@ -120,7 +120,7 @@ func (s *azurePipelinesScaler) GetMetrics(ctx context.Context, metricName string func (s *azurePipelinesScaler) GetAzurePipelinesQueueLength(ctx context.Context) (int, error) { url := fmt.Sprintf("%s/_apis/distributedtask/pools/%s/jobrequests", s.metadata.organizationURL, s.metadata.poolID) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return -1, err } @@ -165,7 +165,7 @@ func (s *azurePipelinesScaler) GetAzurePipelinesQueueLength(ctx context.Context) return count, err } -func (s *azurePipelinesScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *azurePipelinesScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetPipelinesQueueLengthQty := resource.NewQuantity(int64(s.metadata.targetPipelinesQueueLength), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -191,6 +191,6 @@ func (s *azurePipelinesScaler) IsActive(ctx context.Context) (bool, error) { return queuelen > 0, nil } -func (s *azurePipelinesScaler) Close() error { +func (s *azurePipelinesScaler) Close(context.Context) error { return nil } diff --git a/pkg/scalers/azure_pipelines_scaler_test.go b/pkg/scalers/azure_pipelines_scaler_test.go index 50842d0d8c9..4eb3cad85ad 100644 --- a/pkg/scalers/azure_pipelines_scaler_test.go +++ b/pkg/scalers/azure_pipelines_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "net/http" "testing" ) @@ -66,7 +67,7 @@ func TestAzurePipelinesGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockAzurePipelinesScaler.GetMetricSpecForScaling() + metricSpec := mockAzurePipelinesScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure_queue_scaler.go b/pkg/scalers/azure_queue_scaler.go index f930eb3f840..573460d5035 100644 --- a/pkg/scalers/azure_queue_scaler.go +++ b/pkg/scalers/azure_queue_scaler.go @@ -158,11 +158,11 @@ func (s *azureQueueScaler) IsActive(ctx context.Context) (bool, error) { return length > 0, nil } -func (s *azureQueueScaler) Close() error { +func (s *azureQueueScaler) Close(context.Context) error { return nil } -func (s *azureQueueScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *azureQueueScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueueLengthQty := resource.NewQuantity(int64(s.metadata.targetQueueLength), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/azure_queue_scaler_test.go b/pkg/scalers/azure_queue_scaler_test.go index 3eb5ed0ca86..45a8cca6fcf 100644 --- a/pkg/scalers/azure_queue_scaler_test.go +++ b/pkg/scalers/azure_queue_scaler_test.go @@ -17,6 +17,7 @@ limitations under the License. package scalers import ( + "context" "net/http" "testing" @@ -108,7 +109,7 @@ func TestAzQueueGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockAzQueueScaler.GetMetricSpecForScaling() + metricSpec := mockAzQueueScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/azure_servicebus_scaler.go b/pkg/scalers/azure_servicebus_scaler.go index bf66910bb9e..fba48780102 100755 --- a/pkg/scalers/azure_servicebus_scaler.go +++ b/pkg/scalers/azure_servicebus_scaler.go @@ -173,12 +173,12 @@ func (s *azureServiceBusScaler) IsActive(ctx context.Context) (bool, error) { } // Close - nothing to close for SB -func (s *azureServiceBusScaler) Close() error { +func (s *azureServiceBusScaler) Close(context.Context) error { return nil } // Returns the metric spec to be used by the HPA -func (s *azureServiceBusScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *azureServiceBusScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetLengthQty := resource.NewQuantity(int64(s.metadata.targetLength), resource.DecimalSI) namespace, err := s.getServiceBusNamespace() if err != nil { @@ -229,8 +229,9 @@ type azureTokenProvider struct { // GetToken implements TokenProvider interface for azureTokenProvider func (a azureTokenProvider) GetToken(uri string) (*auth.Token, error) { + ctx := context.Background() // Service bus resource id is "https://servicebus.azure.net/" in all cloud environments - token, err := azure.GetAzureADPodIdentityToken(a.httpClient, "https://servicebus.azure.net/") + token, err := azure.GetAzureADPodIdentityToken(ctx, a.httpClient, "https://servicebus.azure.net/") if err != nil { return nil, err } diff --git a/pkg/scalers/azure_servicebus_scaler_test.go b/pkg/scalers/azure_servicebus_scaler_test.go index b9d6b85e046..65757057d81 100755 --- a/pkg/scalers/azure_servicebus_scaler_test.go +++ b/pkg/scalers/azure_servicebus_scaler_test.go @@ -200,7 +200,7 @@ func TestAzServiceBusGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockAzServiceBusScalerScaler.GetMetricSpecForScaling() + metricSpec := mockAzServiceBusScalerScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/cpu_memory_scaler.go b/pkg/scalers/cpu_memory_scaler.go index 5f3c52e63f1..813cfe3d22b 100644 --- a/pkg/scalers/cpu_memory_scaler.go +++ b/pkg/scalers/cpu_memory_scaler.go @@ -76,12 +76,12 @@ func (s *cpuMemoryScaler) IsActive(ctx context.Context) (bool, error) { } // Close no need for cpuMemory scaler -func (s *cpuMemoryScaler) Close() error { +func (s *cpuMemoryScaler) Close(context.Context) error { return nil } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *cpuMemoryScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *cpuMemoryScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { cpuMemoryMetric := &v2beta2.ResourceMetricSource{ Name: s.resourceName, Target: v2beta2.MetricTarget{ diff --git a/pkg/scalers/cpu_memory_scaler_test.go b/pkg/scalers/cpu_memory_scaler_test.go index e6c32b3d46e..9866b8b21d4 100644 --- a/pkg/scalers/cpu_memory_scaler_test.go +++ b/pkg/scalers/cpu_memory_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -49,7 +50,7 @@ func TestGetMetricSpecForScaling(t *testing.T) { TriggerMetadata: validCPUMemoryMetadata, } scaler, _ := NewCPUMemoryScaler(v1.ResourceCPU, config) - metricSpec := scaler.GetMetricSpecForScaling() + metricSpec := scaler.GetMetricSpecForScaling(context.Background()) assert.Equal(t, metricSpec[0].Type, v2beta2.ResourceMetricSourceType) assert.Equal(t, metricSpec[0].Resource.Name, v1.ResourceCPU) diff --git a/pkg/scalers/cron_scaler.go b/pkg/scalers/cron_scaler.go index be749e0a362..99eea10f2b0 100644 --- a/pkg/scalers/cron_scaler.go +++ b/pkg/scalers/cron_scaler.go @@ -139,7 +139,7 @@ func (s *cronScaler) IsActive(ctx context.Context) (bool, error) { } } -func (s *cronScaler) Close() error { +func (s *cronScaler) Close(context.Context) error { return nil } @@ -152,7 +152,7 @@ func parseCronTimeFormat(s string) string { } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *cronScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *cronScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { specReplicas := 1 targetMetricValue := resource.NewQuantity(int64(specReplicas), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ diff --git a/pkg/scalers/cron_scaler_test.go b/pkg/scalers/cron_scaler_test.go index 30efad9cbf8..fa525cb13b6 100644 --- a/pkg/scalers/cron_scaler_test.go +++ b/pkg/scalers/cron_scaler_test.go @@ -120,7 +120,7 @@ func TestCronGetMetricSpecForScaling(t *testing.T) { } mockCronScaler := cronScaler{meta} - metricSpec := mockCronScaler.GetMetricSpecForScaling() + metricSpec := mockCronScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/external_scaler.go b/pkg/scalers/external_scaler.go index 914a18906fb..e25a03358d2 100644 --- a/pkg/scalers/external_scaler.go +++ b/pkg/scalers/external_scaler.go @@ -133,12 +133,12 @@ func (s *externalScaler) IsActive(ctx context.Context) (bool, error) { return response.Result, nil } -func (s *externalScaler) Close() error { +func (s *externalScaler) Close(context.Context) error { return nil } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *externalScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *externalScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { var result []v2beta2.MetricSpec grpcClient, done, err := getClientForConnectionPool(s.metadata) diff --git a/pkg/scalers/gcp_pub_sub_scaler.go b/pkg/scalers/gcp_pub_sub_scaler.go index 2beb2baaccc..00c88290ad9 100644 --- a/pkg/scalers/gcp_pub_sub_scaler.go +++ b/pkg/scalers/gcp_pub_sub_scaler.go @@ -95,7 +95,7 @@ func (s *pubsubScaler) IsActive(ctx context.Context) (bool, error) { return size > 0, nil } -func (s *pubsubScaler) Close() error { +func (s *pubsubScaler) Close(context.Context) error { if s.client != nil { err := s.client.metricsClient.Close() s.client = nil @@ -108,7 +108,7 @@ func (s *pubsubScaler) Close() error { } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *pubsubScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *pubsubScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { // Construct the target subscription size as a quantity targetSubscriptionSizeQty := resource.NewQuantity(int64(s.metadata.targetSubscriptionSize), resource.DecimalSI) diff --git a/pkg/scalers/gcp_pubsub_scaler_test.go b/pkg/scalers/gcp_pubsub_scaler_test.go index 655e0c869cb..b3d50e49f77 100644 --- a/pkg/scalers/gcp_pubsub_scaler_test.go +++ b/pkg/scalers/gcp_pubsub_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" ) @@ -61,7 +62,7 @@ func TestGcpPubSubGetMetricSpecForScaling(t *testing.T) { } mockGcpPubSubScaler := pubsubScaler{nil, meta} - metricSpec := mockGcpPubSubScaler.GetMetricSpecForScaling() + metricSpec := mockGcpPubSubScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/graphite_scaler.go b/pkg/scalers/graphite_scaler.go index beb82d35fa5..31b6f04425e 100644 --- a/pkg/scalers/graphite_scaler.go +++ b/pkg/scalers/graphite_scaler.go @@ -130,7 +130,7 @@ func parseGraphiteMetadata(config *ScalerConfig) (*graphiteMetadata, error) { } func (s *graphiteScaler) IsActive(ctx context.Context) (bool, error) { - val, err := s.ExecuteGrapQuery() + val, err := s.ExecuteGrapQuery(ctx) if err != nil { graphiteLog.Error(err, "error executing graphite query") return false, err @@ -139,11 +139,11 @@ func (s *graphiteScaler) IsActive(ctx context.Context) (bool, error) { return val > 0, nil } -func (s *graphiteScaler) Close() error { +func (s *graphiteScaler) Close(context.Context) error { return nil } -func (s *graphiteScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *graphiteScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(int64(s.metadata.threshold), resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s", "graphite", s.metadata.metricName)) externalMetric := &v2beta2.ExternalMetricSource{ @@ -161,10 +161,10 @@ func (s *graphiteScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { return []v2beta2.MetricSpec{metricSpec} } -func (s *graphiteScaler) ExecuteGrapQuery() (float64, error) { +func (s *graphiteScaler) ExecuteGrapQuery(ctx context.Context) (float64, error) { queryEscaped := url_pkg.QueryEscape(s.metadata.query) url := fmt.Sprintf("%s/render?from=%s&target=%s&format=json", s.metadata.serverAddress, s.metadata.from, queryEscaped) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return -1, err } @@ -201,7 +201,7 @@ func (s *graphiteScaler) ExecuteGrapQuery() (float64, error) { } func (s *graphiteScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - val, err := s.ExecuteGrapQuery() + val, err := s.ExecuteGrapQuery(ctx) if err != nil { graphiteLog.Error(err, "error executing graphite query") return []external_metrics.ExternalMetricValue{}, err diff --git a/pkg/scalers/graphite_scaler_test.go b/pkg/scalers/graphite_scaler_test.go index 8f903d3edd1..6e22230a217 100644 --- a/pkg/scalers/graphite_scaler_test.go +++ b/pkg/scalers/graphite_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "strings" "testing" ) @@ -66,6 +67,7 @@ func TestGraphiteParseMetadata(t *testing.T) { func TestGraphiteGetMetricSpecForScaling(t *testing.T) { for _, testData := range graphiteMetricIdentifiers { + ctx := context.Background() meta, err := parseGraphiteMetadata(&ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) @@ -74,7 +76,7 @@ func TestGraphiteGetMetricSpecForScaling(t *testing.T) { metadata: meta, } - metricSpec := mockGraphiteScaler.GetMetricSpecForScaling() + metricSpec := mockGraphiteScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/huawei_cloudeye_scaler.go b/pkg/scalers/huawei_cloudeye_scaler.go index a91af257d6e..ab54df36e98 100644 --- a/pkg/scalers/huawei_cloudeye_scaler.go +++ b/pkg/scalers/huawei_cloudeye_scaler.go @@ -241,7 +241,7 @@ func (h *huaweiCloudeyeScaler) GetMetrics(ctx context.Context, metricName string return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func (h *huaweiCloudeyeScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (h *huaweiCloudeyeScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(int64(h.metadata.targetMetricValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -269,7 +269,7 @@ func (h *huaweiCloudeyeScaler) IsActive(ctx context.Context) (bool, error) { return val > h.metadata.minMetricValue, nil } -func (h *huaweiCloudeyeScaler) Close() error { +func (h *huaweiCloudeyeScaler) Close(context.Context) error { return nil } diff --git a/pkg/scalers/huawei_cloudeye_test.go b/pkg/scalers/huawei_cloudeye_test.go index 89b12ee0443..6fda5e4c0ab 100644 --- a/pkg/scalers/huawei_cloudeye_test.go +++ b/pkg/scalers/huawei_cloudeye_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" ) @@ -165,7 +166,7 @@ func TestHuaweiCloudeyeGetMetricSpecForScaling(t *testing.T) { } mockHuaweiCloudeyeScaler := huaweiCloudeyeScaler{meta} - metricSpec := mockHuaweiCloudeyeScaler.GetMetricSpecForScaling() + metricSpec := mockHuaweiCloudeyeScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName, "wanted:", testData.name) diff --git a/pkg/scalers/ibmmq_scaler.go b/pkg/scalers/ibmmq_scaler.go index 57c445bbaa7..a45c462fc35 100644 --- a/pkg/scalers/ibmmq_scaler.go +++ b/pkg/scalers/ibmmq_scaler.go @@ -75,7 +75,7 @@ func NewIBMMQScaler(config *ScalerConfig) (Scaler, error) { } // Close closes and returns nil -func (s *IBMMQScaler) Close() error { +func (s *IBMMQScaler) Close(context.Context) error { return nil } @@ -150,7 +150,7 @@ func parseIBMMQMetadata(config *ScalerConfig) (*IBMMQMetadata, error) { // IsActive returns true if there are messages to be processed/if we need to scale from zero func (s *IBMMQScaler) IsActive(ctx context.Context) (bool, error) { - queueDepth, err := s.getQueueDepthViaHTTP() + queueDepth, err := s.getQueueDepthViaHTTP(ctx) if err != nil { return false, fmt.Errorf("error inspecting IBM MQ queue depth: %s", err) } @@ -158,12 +158,12 @@ func (s *IBMMQScaler) IsActive(ctx context.Context) (bool, error) { } // getQueueDepthViaHTTP returns the depth of the MQ Queue from the Admin endpoint -func (s *IBMMQScaler) getQueueDepthViaHTTP() (int, error) { +func (s *IBMMQScaler) getQueueDepthViaHTTP(ctx context.Context) (int, error) { queue := s.metadata.queueName url := s.metadata.host var requestJSON = []byte(`{"type": "runCommandJSON", "command": "display", "qualifier": "qlocal", "name": "` + queue + `", "responseParameters" : ["CURDEPTH"]}`) - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestJSON)) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestJSON)) if err != nil { return 0, fmt.Errorf("failed to request queue depth: %s", err) } @@ -201,7 +201,7 @@ func (s *IBMMQScaler) getQueueDepthViaHTTP() (int, error) { } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler -func (s *IBMMQScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *IBMMQScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueueLengthQty := resource.NewQuantity(int64(s.metadata.targetQueueDepth), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -218,7 +218,7 @@ func (s *IBMMQScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *IBMMQScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - queueDepth, err := s.getQueueDepthViaHTTP() + queueDepth, err := s.getQueueDepthViaHTTP(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error inspecting IBM MQ queue depth: %s", err) } diff --git a/pkg/scalers/ibmmq_scaler_test.go b/pkg/scalers/ibmmq_scaler_test.go index 13ec97c4c47..46ffd97a422 100644 --- a/pkg/scalers/ibmmq_scaler_test.go +++ b/pkg/scalers/ibmmq_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "fmt" "testing" "time" @@ -115,7 +116,7 @@ func TestIBMMQGetMetricSpecForScaling(t *testing.T) { metadata: metadata, defaultHTTPTimeout: httpTimeout, } - metricSpec := mockIBMMQScaler.GetMetricSpecForScaling() + metricSpec := mockIBMMQScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { diff --git a/pkg/scalers/influxdb_scaler.go b/pkg/scalers/influxdb_scaler.go index 1661e011d2e..f050ac39438 100644 --- a/pkg/scalers/influxdb_scaler.go +++ b/pkg/scalers/influxdb_scaler.go @@ -162,7 +162,7 @@ func parseInfluxDBMetadata(config *ScalerConfig) (*influxDBMetadata, error) { func (s *influxDBScaler) IsActive(ctx context.Context) (bool, error) { queryAPI := s.client.QueryAPI(s.metadata.organizationName) - value, err := queryInfluxDB(queryAPI, s.metadata.query) + value, err := queryInfluxDB(ctx, queryAPI, s.metadata.query) if err != nil { return false, err } @@ -171,7 +171,7 @@ func (s *influxDBScaler) IsActive(ctx context.Context) (bool, error) { } // Close closes the connection of the client to the server -func (s *influxDBScaler) Close() error { +func (s *influxDBScaler) Close(context.Context) error { s.client.Close() return nil } @@ -179,8 +179,8 @@ func (s *influxDBScaler) Close() error { // queryInfluxDB runs the query against the associated influxdb database // there is an implicit assumption here that the first value returned from the iterator // will be the value of interest -func queryInfluxDB(queryAPI api.QueryAPI, query string) (float64, error) { - result, err := queryAPI.Query(context.Background(), query) +func queryInfluxDB(ctx context.Context, queryAPI api.QueryAPI, query string) (float64, error) { + result, err := queryAPI.Query(ctx, query) if err != nil { return 0, err } @@ -205,7 +205,7 @@ func (s *influxDBScaler) GetMetrics(ctx context.Context, metricName string, metr // Grab QueryAPI to make queries to influxdb instance queryAPI := s.client.QueryAPI(s.metadata.organizationName) - value, err := queryInfluxDB(queryAPI, s.metadata.query) + value, err := queryInfluxDB(ctx, queryAPI, s.metadata.query) if err != nil { return []external_metrics.ExternalMetricValue{}, err } @@ -220,7 +220,7 @@ func (s *influxDBScaler) GetMetrics(ctx context.Context, metricName string, metr } // GetMetricSpecForScaling returns the metric spec for the Horizontal Pod Autoscaler -func (s *influxDBScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *influxDBScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(int64(s.metadata.thresholdValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/influxdb_scaler_test.go b/pkg/scalers/influxdb_scaler_test.go index d528b7e19ea..42350d6b691 100644 --- a/pkg/scalers/influxdb_scaler_test.go +++ b/pkg/scalers/influxdb_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" influxdb2 "github.com/influxdata/influxdb-client-go/v2" @@ -73,7 +74,7 @@ func TestInfluxDBGetMetricSpecForScaling(t *testing.T) { } mockInfluxDBScaler := influxDBScaler{influxdb2.NewClient("https://influxdata.com", "myToken"), meta} - metricSpec := mockInfluxDBScaler.GetMetricSpecForScaling() + metricSpec := mockInfluxDBScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Errorf("Wrong External metric source name: %s, expected: %s", metricName, testData.name) diff --git a/pkg/scalers/kafka_scaler.go b/pkg/scalers/kafka_scaler.go index b5186e95c24..950c61d90e6 100644 --- a/pkg/scalers/kafka_scaler.go +++ b/pkg/scalers/kafka_scaler.go @@ -339,7 +339,7 @@ func (s *kafkaScaler) getLagForPartition(partition int32, offsets *sarama.Offset } // Close closes the kafka admin and client -func (s *kafkaScaler) Close() error { +func (s *kafkaScaler) Close(context.Context) error { // underlying client will also be closed on admin's Close() call err := s.admin.Close() if err != nil { @@ -349,7 +349,7 @@ func (s *kafkaScaler) Close() error { return nil } -func (s *kafkaScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *kafkaScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(s.metadata.lagThreshold, resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/kafka_scaler_test.go b/pkg/scalers/kafka_scaler_test.go index 8c19bf26500..58a5b940c28 100644 --- a/pkg/scalers/kafka_scaler_test.go +++ b/pkg/scalers/kafka_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "reflect" "testing" ) @@ -197,7 +198,7 @@ func TestKafkaGetMetricSpecForScaling(t *testing.T) { } mockKafkaScaler := kafkaScaler{meta, nil, nil} - metricSpec := mockKafkaScaler.GetMetricSpecForScaling() + metricSpec := mockKafkaScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/kubernetes_workload_scaler.go b/pkg/scalers/kubernetes_workload_scaler.go index 513177e8bcd..79afcfe3126 100644 --- a/pkg/scalers/kubernetes_workload_scaler.go +++ b/pkg/scalers/kubernetes_workload_scaler.go @@ -76,12 +76,12 @@ func (s *kubernetesWorkloadScaler) IsActive(ctx context.Context) (bool, error) { } // Close no need for kubernetes workload scaler -func (s *kubernetesWorkloadScaler) Close() error { +func (s *kubernetesWorkloadScaler) Close(context.Context) error { return nil } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *kubernetesWorkloadScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *kubernetesWorkloadScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(s.metadata.value, resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/kubernetes_workload_scaler_test.go b/pkg/scalers/kubernetes_workload_scaler_test.go index b7c20194b1d..3a6f368b88d 100644 --- a/pkg/scalers/kubernetes_workload_scaler_test.go +++ b/pkg/scalers/kubernetes_workload_scaler_test.go @@ -113,7 +113,7 @@ func TestWorkloadGetMetricSpecForScaling(t *testing.T) { ScalerIndex: testData.scalerIndex, }, ) - metric := s.GetMetricSpecForScaling() + metric := s.GetMetricSpecForScaling(context.Background()) if metric[0].External.Metric.Name != testData.name { t.Errorf("Expected '%s' as metric name and got '%s'", testData.name, metric[0].External.Metric.Name) diff --git a/pkg/scalers/liiklus_scaler.go b/pkg/scalers/liiklus_scaler.go index 0743102daa7..4a7b7fc8195 100644 --- a/pkg/scalers/liiklus_scaler.go +++ b/pkg/scalers/liiklus_scaler.go @@ -82,7 +82,7 @@ func (s *liiklusScaler) GetMetrics(ctx context.Context, metricName string, metri }, nil } -func (s *liiklusScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *liiklusScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(s.metadata.lagThreshold, resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -97,7 +97,7 @@ func (s *liiklusScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { return []v2beta2.MetricSpec{metricSpec} } -func (s *liiklusScaler) Close() error { +func (s *liiklusScaler) Close(context.Context) error { err := s.connection.Close() if err != nil { return err @@ -129,7 +129,7 @@ func (s *liiklusScaler) getLag(ctx context.Context) (uint64, map[uint32]uint64, return 0, nil, err } - ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) + ctx2, cancel2 := context.WithTimeout(ctx, 10*time.Second) defer cancel2() geor, err := s.client.GetEndOffsets(ctx2, &liiklus_service.GetEndOffsetsRequest{ Topic: s.metadata.topic, diff --git a/pkg/scalers/liiklus_scaler_test.go b/pkg/scalers/liiklus_scaler_test.go index 95454485807..20271b6f242 100644 --- a/pkg/scalers/liiklus_scaler_test.go +++ b/pkg/scalers/liiklus_scaler_test.go @@ -173,7 +173,7 @@ func TestLiiklusGetMetricSpecForScaling(t *testing.T) { } mockLiiklusScaler := liiklusScaler{meta, nil, nil} - metricSpec := mockLiiklusScaler.GetMetricSpecForScaling() + metricSpec := mockLiiklusScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/metrics_api_scaler.go b/pkg/scalers/metrics_api_scaler.go index a04231f2d8a..a05360be0c7 100644 --- a/pkg/scalers/metrics_api_scaler.go +++ b/pkg/scalers/metrics_api_scaler.go @@ -200,8 +200,8 @@ func GetValueFromResponse(body []byte, valueLocation string) (*resource.Quantity return resource.NewQuantity(int64(r.Num), resource.DecimalSI), nil } -func (s *metricsAPIScaler) getMetricValue() (*resource.Quantity, error) { - request, err := getMetricAPIServerRequest(s.metadata) +func (s *metricsAPIScaler) getMetricValue(ctx context.Context) (*resource.Quantity, error) { + request, err := getMetricAPIServerRequest(ctx, s.metadata) if err != nil { return nil, err } @@ -229,13 +229,13 @@ func (s *metricsAPIScaler) getMetricValue() (*resource.Quantity, error) { } // Close does nothing in case of metricsAPIScaler -func (s *metricsAPIScaler) Close() error { +func (s *metricsAPIScaler) Close(context.Context) error { return nil } // IsActive returns true if there are pending messages to be processed func (s *metricsAPIScaler) IsActive(ctx context.Context) (bool, error) { - v, err := s.getMetricValue() + v, err := s.getMetricValue(ctx) if err != nil { httpLog.Error(err, fmt.Sprintf("Error when checking metric value: %s", err)) return false, err @@ -245,7 +245,7 @@ func (s *metricsAPIScaler) IsActive(ctx context.Context) (bool, error) { } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler -func (s *metricsAPIScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *metricsAPIScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetValue := resource.NewQuantity(int64(s.metadata.targetValue), resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s-%s", "http", s.metadata.url, s.metadata.valueLocation)) externalMetric := &v2beta2.ExternalMetricSource{ @@ -265,7 +265,7 @@ func (s *metricsAPIScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *metricsAPIScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - v, err := s.getMetricValue() + v, err := s.getMetricValue(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error requesting metrics endpoint: %s", err) } @@ -279,7 +279,7 @@ func (s *metricsAPIScaler) GetMetrics(ctx context.Context, metricName string, me return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func getMetricAPIServerRequest(meta *metricsAPIScalerMetadata) (*http.Request, error) { +func getMetricAPIServerRequest(ctx context.Context, meta *metricsAPIScalerMetadata) (*http.Request, error) { var req *http.Request var err error @@ -295,13 +295,13 @@ func getMetricAPIServerRequest(meta *metricsAPIScalerMetadata) (*http.Request, e } url.RawQuery = queryString.Encode() - req, err = http.NewRequest("GET", url.String(), nil) + req, err = http.NewRequestWithContext(ctx, "GET", url.String(), nil) if err != nil { return nil, err } } else { // default behaviour is to use header method - req, err = http.NewRequest("GET", meta.url, nil) + req, err = http.NewRequestWithContext(ctx, "GET", meta.url, nil) if err != nil { return nil, err } @@ -313,20 +313,20 @@ func getMetricAPIServerRequest(meta *metricsAPIScalerMetadata) (*http.Request, e } } case meta.enableBaseAuth: - req, err = http.NewRequest("GET", meta.url, nil) + req, err = http.NewRequestWithContext(ctx, "GET", meta.url, nil) if err != nil { return nil, err } req.SetBasicAuth(meta.username, meta.password) case meta.enableBearerAuth: - req, err = http.NewRequest("GET", meta.url, nil) + req, err = http.NewRequestWithContext(ctx, "GET", meta.url, nil) if err != nil { return nil, err } req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", meta.bearerToken)) default: - req, err = http.NewRequest("GET", meta.url, nil) + req, err = http.NewRequestWithContext(ctx, "GET", meta.url, nil) if err != nil { return nil, err } diff --git a/pkg/scalers/mongo_scaler.go b/pkg/scalers/mongo_scaler.go index 3066515c56b..6387f6d6f83 100644 --- a/pkg/scalers/mongo_scaler.go +++ b/pkg/scalers/mongo_scaler.go @@ -75,8 +75,8 @@ const ( var mongoDBLog = logf.Log.WithName("mongodb_scaler") // NewMongoDBScaler creates a new mongoDB scaler -func NewMongoDBScaler(config *ScalerConfig) (Scaler, error) { - ctx, cancel := context.WithTimeout(context.Background(), mongoDBDefaultTimeOut) +func NewMongoDBScaler(ctx context.Context, config *ScalerConfig) (Scaler, error) { + ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut) defer cancel() meta, connStr, err := parseMongoDBMetadata(config) @@ -190,7 +190,7 @@ func parseMongoDBMetadata(config *ScalerConfig) (*mongoDBMetadata, string, error } func (s *mongoDBScaler) IsActive(ctx context.Context) (bool, error) { - result, err := s.getQueryResult() + result, err := s.getQueryResult(ctx) if err != nil { mongoDBLog.Error(err, fmt.Sprintf("failed to get query result by mongoDB, because of %v", err)) return false, err @@ -199,7 +199,7 @@ func (s *mongoDBScaler) IsActive(ctx context.Context) (bool, error) { } // Close disposes of mongoDB connections -func (s *mongoDBScaler) Close() error { +func (s *mongoDBScaler) Close(context.Context) error { if s.client != nil { err := s.client.Disconnect(context.TODO()) if err != nil { @@ -212,8 +212,8 @@ func (s *mongoDBScaler) Close() error { } // getQueryResult query mongoDB by meta.query -func (s *mongoDBScaler) getQueryResult() (int, error) { - ctx, cancel := context.WithTimeout(context.Background(), mongoDBDefaultTimeOut) +func (s *mongoDBScaler) getQueryResult(ctx context.Context) (int, error) { + ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut) defer cancel() filter, err := json2BsonDoc(s.metadata.query) @@ -233,7 +233,7 @@ func (s *mongoDBScaler) getQueryResult() (int, error) { // GetMetrics query from mongoDB,and return to external metrics func (s *mongoDBScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - num, err := s.getQueryResult() + num, err := s.getQueryResult(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("failed to inspect momgoDB, because of %v", err) } @@ -248,7 +248,7 @@ func (s *mongoDBScaler) GetMetrics(ctx context.Context, metricName string, metri } // GetMetricSpecForScaling get the query value for scaling -func (s *mongoDBScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *mongoDBScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueryValue := resource.NewQuantity(int64(s.metadata.queryValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ diff --git a/pkg/scalers/mongo_scaler_test.go b/pkg/scalers/mongo_scaler_test.go index 10a481eded1..06447349129 100644 --- a/pkg/scalers/mongo_scaler_test.go +++ b/pkg/scalers/mongo_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" "go.mongodb.org/mongo-driver/mongo" @@ -80,7 +81,7 @@ func TestMongoDBGetMetricSpecForScaling(t *testing.T) { } mockMongoDBScaler := mongoDBScaler{meta, &mongo.Client{}} - metricSpec := mockMongoDBScaler.GetMetricSpecForScaling() + metricSpec := mockMongoDBScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/mssql_scaler.go b/pkg/scalers/mssql_scaler.go index 7167d7f1327..7aa8cca0ed9 100644 --- a/pkg/scalers/mssql_scaler.go +++ b/pkg/scalers/mssql_scaler.go @@ -215,7 +215,7 @@ func getMSSQLConnectionString(meta *mssqlMetadata) string { } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler -func (s *mssqlScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueryValue := resource.NewQuantity(int64(s.metadata.targetValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ @@ -236,7 +236,7 @@ func (s *mssqlScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns a value for a supported metric or an error if there is a problem getting the metric func (s *mssqlScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - num, err := s.getQueryResult() + num, err := s.getQueryResult(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error inspecting mssql: %s", err) } @@ -251,9 +251,9 @@ func (s *mssqlScaler) GetMetrics(ctx context.Context, metricName string, metricS } // getQueryResult returns the result of the scaler query -func (s *mssqlScaler) getQueryResult() (int, error) { +func (s *mssqlScaler) getQueryResult(ctx context.Context) (int, error) { var value int - err := s.connection.QueryRow(s.metadata.query).Scan(&value) + err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&value) switch { case err == sql.ErrNoRows: value = 0 @@ -267,7 +267,7 @@ func (s *mssqlScaler) getQueryResult() (int, error) { // IsActive returns true if there are pending events to be processed func (s *mssqlScaler) IsActive(ctx context.Context) (bool, error) { - messages, err := s.getQueryResult() + messages, err := s.getQueryResult(ctx) if err != nil { return false, fmt.Errorf("error inspecting mssql: %s", err) } @@ -276,7 +276,7 @@ func (s *mssqlScaler) IsActive(ctx context.Context) (bool, error) { } // Close closes the mssql database connections -func (s *mssqlScaler) Close() error { +func (s *mssqlScaler) Close(context.Context) error { err := s.connection.Close() if err != nil { mssqlLog.Error(err, "Error closing mssql connection") diff --git a/pkg/scalers/mysql_scaler.go b/pkg/scalers/mysql_scaler.go index 149f0d9de67..0bd04435b8f 100644 --- a/pkg/scalers/mysql_scaler.go +++ b/pkg/scalers/mysql_scaler.go @@ -167,7 +167,7 @@ func parseMySQLDbNameFromConnectionStr(connectionString string) string { } // Close disposes of MySQL connections -func (s *mySQLScaler) Close() error { +func (s *mySQLScaler) Close(context.Context) error { err := s.connection.Close() if err != nil { mySQLLog.Error(err, "Error closing MySQL connection") @@ -178,7 +178,7 @@ func (s *mySQLScaler) Close() error { // IsActive returns true if there are pending messages to be processed func (s *mySQLScaler) IsActive(ctx context.Context) (bool, error) { - messages, err := s.getQueryResult() + messages, err := s.getQueryResult(ctx) if err != nil { mySQLLog.Error(err, fmt.Sprintf("Error inspecting MySQL: %s", err)) return false, err @@ -187,9 +187,9 @@ func (s *mySQLScaler) IsActive(ctx context.Context) (bool, error) { } // getQueryResult returns result of the scaler query -func (s *mySQLScaler) getQueryResult() (int, error) { +func (s *mySQLScaler) getQueryResult(ctx context.Context) (int, error) { var value int - err := s.connection.QueryRow(s.metadata.query).Scan(&value) + err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&value) if err != nil { mySQLLog.Error(err, fmt.Sprintf("Could not query MySQL database: %s", err)) return 0, err @@ -198,7 +198,7 @@ func (s *mySQLScaler) getQueryResult() (int, error) { } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler -func (s *mySQLScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *mySQLScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueryValue := resource.NewQuantity(int64(s.metadata.queryValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ @@ -218,7 +218,7 @@ func (s *mySQLScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *mySQLScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - num, err := s.getQueryResult() + num, err := s.getQueryResult(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error inspecting MySQL: %s", err) } diff --git a/pkg/scalers/openstack/keystone_authentication.go b/pkg/scalers/openstack/keystone_authentication.go index 94956a93182..c613b28d92f 100644 --- a/pkg/scalers/openstack/keystone_authentication.go +++ b/pkg/scalers/openstack/keystone_authentication.go @@ -2,6 +2,7 @@ package openstack import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -100,7 +101,7 @@ type endpoint struct { } // IsTokenValid checks if a authentication token is valid -func (client *Client) IsTokenValid() (bool, error) { +func (client *Client) IsTokenValid(ctx context.Context) (bool, error) { var token = client.Token if token == "" { @@ -115,7 +116,7 @@ func (client *Client) IsTokenValid() (bool, error) { tokenURL.Path = path.Join(tokenURL.Path, tokensEndpoint) - checkTokenRequest, err := http.NewRequest("HEAD", tokenURL.String(), nil) + checkTokenRequest, err := http.NewRequestWithContext(ctx, "HEAD", tokenURL.String(), nil) checkTokenRequest.Header.Set("X-Subject-Token", token) checkTokenRequest.Header.Set("X-Auth-Token", token) @@ -139,8 +140,8 @@ func (client *Client) IsTokenValid() (bool, error) { } // RenewToken retrives another token from Keystone -func (client *Client) RenewToken() error { - token, err := client.authMetadata.getToken() +func (client *Client) RenewToken(ctx context.Context) error { + token, err := client.authMetadata.getToken(ctx) if err != nil { return err @@ -218,13 +219,13 @@ func NewAppCredentialsAuth(authURL string, id string, secret string, httpTimeout // If an OpenStack project name is provided as first parameter, it will try to retrieve its API URL using the current credentials. // If an OpenStack region or availability zone is provided as second parameter, it will retrieve the service API URL for that region. // Otherwise, if the service API URL was found, it retrieves the first public URL for that service. -func (keystone *KeystoneAuthRequest) RequestClient(projectProps ...string) (Client, error) { +func (keystone *KeystoneAuthRequest) RequestClient(ctx context.Context, projectProps ...string) (Client, error) { var client = Client{ HTTPClient: kedautil.CreateHTTPClient(keystone.HTTPClientTimeout), authMetadata: keystone, } - token, err := keystone.getToken() + token, err := keystone.getToken(ctx) if err != nil { return client, err @@ -236,9 +237,9 @@ func (keystone *KeystoneAuthRequest) RequestClient(projectProps ...string) (Clie switch len(projectProps) { case 2: - serviceURL, err = keystone.getServiceURL(token, projectProps[0], projectProps[1]) + serviceURL, err = keystone.getServiceURL(ctx, token, projectProps[0], projectProps[1]) case 1: - serviceURL, err = keystone.getServiceURL(token, projectProps[0], "") + serviceURL, err = keystone.getServiceURL(ctx, token, projectProps[0], "") default: serviceURL = "" } @@ -252,7 +253,7 @@ func (keystone *KeystoneAuthRequest) RequestClient(projectProps ...string) (Clie return client, nil } -func (keystone *KeystoneAuthRequest) getToken() (string, error) { +func (keystone *KeystoneAuthRequest) getToken(ctx context.Context) (string, error) { var httpClient = kedautil.CreateHTTPClient(keystone.HTTPClientTimeout) jsonBody, err := json.Marshal(keystone) @@ -271,7 +272,7 @@ func (keystone *KeystoneAuthRequest) getToken() (string, error) { tokenURL.Path = path.Join(tokenURL.Path, tokensEndpoint) - tokenRequest, err := http.NewRequest("POST", tokenURL.String(), jsonBodyReader) + tokenRequest, err := http.NewRequestWithContext(ctx, "POST", tokenURL.String(), jsonBodyReader) if err != nil { return "", err @@ -299,7 +300,7 @@ func (keystone *KeystoneAuthRequest) getToken() (string, error) { } // getCatalog retrives the OpenStack catalog according to the current authorization -func (keystone *KeystoneAuthRequest) getCatalog(token string) ([]service, error) { +func (keystone *KeystoneAuthRequest) getCatalog(ctx context.Context, token string) ([]service, error) { var httpClient = kedautil.CreateHTTPClient(keystone.HTTPClientTimeout) catalogURL, err := url.Parse(keystone.AuthURL) @@ -310,7 +311,7 @@ func (keystone *KeystoneAuthRequest) getCatalog(token string) ([]service, error) catalogURL.Path = path.Join(catalogURL.Path, catalogEndpoint) - getCatalog, err := http.NewRequest("GET", catalogURL.String(), nil) + getCatalog, err := http.NewRequestWithContext(ctx, "GET", catalogURL.String(), nil) if err != nil { return nil, err @@ -348,14 +349,14 @@ func (keystone *KeystoneAuthRequest) getCatalog(token string) ([]service, error) } // getServiceURL retrieves a public URL for an OpenStack project from the OpenStack catalog -func (keystone *KeystoneAuthRequest) getServiceURL(token string, projectName string, region string) (string, error) { - serviceTypes, err := openstackutil.GetServiceTypes(projectName) +func (keystone *KeystoneAuthRequest) getServiceURL(ctx context.Context, token string, projectName string, region string) (string, error) { + serviceTypes, err := openstackutil.GetServiceTypes(ctx, projectName) if err != nil { return "", err } - serviceCatalog, err := keystone.getCatalog(token) + serviceCatalog, err := keystone.getCatalog(ctx, token) if err != nil { return "", err diff --git a/pkg/scalers/openstack/utils/serviceTypes.go b/pkg/scalers/openstack/utils/serviceTypes.go index ba63f074791..27508ecb161 100644 --- a/pkg/scalers/openstack/utils/serviceTypes.go +++ b/pkg/scalers/openstack/utils/serviceTypes.go @@ -1,6 +1,7 @@ package utils import ( + "context" "encoding/json" "fmt" "net/http" @@ -33,14 +34,14 @@ type serviceMapping struct { } // GetServiceTypes retrieves all historical OpenStack Service Types for a given OpenStack project -func GetServiceTypes(projectName string) ([]string, error) { +func GetServiceTypes(ctx context.Context, projectName string) ([]string, error) { var serviceTypesRequest serviceTypesRequest var httpClient = kedautil.CreateHTTPClient(defaultHTTPClientTimeout * time.Second) var url = serviceTypesAuthorityEndpoint - getServiceTypes, err := http.NewRequest("GET", url, nil) + getServiceTypes, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return []string{}, err diff --git a/pkg/scalers/openstack_metrics_scaler.go b/pkg/scalers/openstack_metrics_scaler.go index 27f1a66e34a..cf012d41d9a 100644 --- a/pkg/scalers/openstack_metrics_scaler.go +++ b/pkg/scalers/openstack_metrics_scaler.go @@ -60,7 +60,7 @@ type measureResult struct { var openstackMetricLog = logf.Log.WithName("openstack_metric_scaler") // NewOpenstackMetricScaler creates new openstack metrics scaler instance -func NewOpenstackMetricScaler(config *ScalerConfig) (Scaler, error) { +func NewOpenstackMetricScaler(ctx context.Context, config *ScalerConfig) (Scaler, error) { var keystoneAuth *openstack.KeystoneAuthRequest var metricsClient openstack.Client @@ -96,7 +96,7 @@ func NewOpenstackMetricScaler(config *ScalerConfig) (Scaler, error) { } } - metricsClient, err = keystoneAuth.RequestClient() + metricsClient, err = keystoneAuth.RequestClient(ctx) if err != nil { openstackMetricLog.Error(err, "Fail to retrieve new keystone clinet for openstack metrics scaler") return nil, err @@ -195,7 +195,7 @@ func parseOpenstackMetricAuthenticationMetadata(config *ScalerConfig) (openstack return authMeta, nil } -func (a *openstackMetricScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (a *openstackMetricScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricVal := resource.NewQuantity(int64(a.metadata.threshold), resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("openstack-metric-%s-%s-%s", a.metadata.metricID, strconv.FormatFloat(a.metadata.threshold, 'f', 0, 32), a.metadata.aggregationMethod)) @@ -218,7 +218,7 @@ func (a *openstackMetricScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { } func (a *openstackMetricScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - val, err := a.readOpenstackMetrics() + val, err := a.readOpenstackMetrics(ctx) if err != nil { openstackMetricLog.Error(err, "Error collecting metric value") @@ -235,7 +235,7 @@ func (a *openstackMetricScaler) GetMetrics(ctx context.Context, metricName strin } func (a *openstackMetricScaler) IsActive(ctx context.Context) (bool, error) { - val, err := a.readOpenstackMetrics() + val, err := a.readOpenstackMetrics(ctx) if err != nil { return false, err @@ -244,15 +244,15 @@ func (a *openstackMetricScaler) IsActive(ctx context.Context) (bool, error) { return val > 0, nil } -func (a *openstackMetricScaler) Close() error { +func (a *openstackMetricScaler) Close(context.Context) error { return nil } // Gets measureament from API as float64, converts it to int and return the value. -func (a *openstackMetricScaler) readOpenstackMetrics() (float64, error) { +func (a *openstackMetricScaler) readOpenstackMetrics(ctx context.Context) (float64, error) { var metricURL = a.metadata.metricsURL - isValid, validationError := a.metricClient.IsTokenValid() + isValid, validationError := a.metricClient.IsTokenValid(ctx) if validationError != nil { openstackMetricLog.Error(validationError, "Unable to check token validity.") @@ -260,7 +260,7 @@ func (a *openstackMetricScaler) readOpenstackMetrics() (float64, error) { } if !isValid { - tokenRequestError := a.metricClient.RenewToken() + tokenRequestError := a.metricClient.RenewToken(ctx) if tokenRequestError != nil { openstackMetricLog.Error(tokenRequestError, "The token being used is invalid") return defaultValueWhenError, tokenRequestError @@ -304,7 +304,7 @@ func (a *openstackMetricScaler) readOpenstackMetrics() (float64, error) { openstackMetricsURL.RawQuery = queryParameter.Encode() - openstackMetricRequest, newReqErr := http.NewRequest("GET", openstackMetricsURL.String(), nil) + openstackMetricRequest, newReqErr := http.NewRequestWithContext(ctx, "GET", openstackMetricsURL.String(), nil) if newReqErr != nil { openstackMetricLog.Error(newReqErr, "Could not build metrics request", nil) } diff --git a/pkg/scalers/openstack_metrics_scaler_test.go b/pkg/scalers/openstack_metrics_scaler_test.go index 59ce658cd97..5ca8dd1a853 100644 --- a/pkg/scalers/openstack_metrics_scaler_test.go +++ b/pkg/scalers/openstack_metrics_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" "github.com/kedacore/keda/v2/pkg/scalers/openstack" @@ -114,7 +115,7 @@ func TestOpenstackMetricsGetMetricsForSpecScaling(t *testing.T) { } mockMetricsScaler := openstackMetricScaler{meta, openstack.Client{}} - metricsSpec := mockMetricsScaler.GetMetricSpecForScaling() + metricsSpec := mockMetricsScaler.GetMetricSpecForScaling(context.Background()) metricName := metricsSpec[0].External.Metric.Name if metricName != testData.name { diff --git a/pkg/scalers/openstack_swift_scaler.go b/pkg/scalers/openstack_swift_scaler.go index d1e965ccf2a..b691f279043 100644 --- a/pkg/scalers/openstack_swift_scaler.go +++ b/pkg/scalers/openstack_swift_scaler.go @@ -59,11 +59,11 @@ type openstackSwiftScaler struct { var openstackSwiftLog = logf.Log.WithName("openstack_swift_scaler") -func (s *openstackSwiftScaler) getOpenstackSwiftContainerObjectCount() (int, error) { +func (s *openstackSwiftScaler) getOpenstackSwiftContainerObjectCount(ctx context.Context) (int, error) { var containerName = s.metadata.containerName var swiftURL = s.metadata.swiftURL - isValid, err := s.swiftClient.IsTokenValid() + isValid, err := s.swiftClient.IsTokenValid(ctx) if err != nil { openstackSwiftLog.Error(err, "scaler could not validate the token for authentication") @@ -71,7 +71,7 @@ func (s *openstackSwiftScaler) getOpenstackSwiftContainerObjectCount() (int, err } if !isValid { - err := s.swiftClient.RenewToken() + err := s.swiftClient.RenewToken(ctx) if err != nil { openstackSwiftLog.Error(err, "error requesting token for authentication") @@ -90,7 +90,7 @@ func (s *openstackSwiftScaler) getOpenstackSwiftContainerObjectCount() (int, err swiftContainerURL.Path = path.Join(swiftContainerURL.Path, containerName) - swiftRequest, _ := http.NewRequest("GET", swiftContainerURL.String(), nil) + swiftRequest, _ := http.NewRequestWithContext(ctx, "GET", swiftContainerURL.String(), nil) swiftRequest.Header.Set("X-Auth-Token", token) @@ -177,7 +177,7 @@ func (s *openstackSwiftScaler) getOpenstackSwiftContainerObjectCount() (int, err } // NewOpenstackSwiftScaler creates a new OpenStack Swift scaler -func NewOpenstackSwiftScaler(config *ScalerConfig) (Scaler, error) { +func NewOpenstackSwiftScaler(ctx context.Context, config *ScalerConfig) (Scaler, error) { var authRequest *openstack.KeystoneAuthRequest var swiftClient openstack.Client @@ -214,7 +214,7 @@ func NewOpenstackSwiftScaler(config *ScalerConfig) (Scaler, error) { if openstackSwiftMetadata.swiftURL == "" { // Request a Client with a token and the Swift API endpoint - swiftClient, err = authRequest.RequestClient("swift", authMetadata.regionName) + swiftClient, err = authRequest.RequestClient(ctx, "swift", authMetadata.regionName) if err != nil { return nil, fmt.Errorf("swiftURL was not provided and the scaler could not retrieve it dinamically using the OpenStack catalog: %s", err.Error()) @@ -223,7 +223,7 @@ func NewOpenstackSwiftScaler(config *ScalerConfig) (Scaler, error) { openstackSwiftMetadata.swiftURL = swiftClient.URL } else { // Request a Client with a token, but not the Swift API endpoint - swiftClient, err = authRequest.RequestClient() + swiftClient, err = authRequest.RequestClient(ctx) if err != nil { return nil, err @@ -351,7 +351,7 @@ func parseOpenstackSwiftAuthenticationMetadata(config *ScalerConfig) (*openstack } func (s *openstackSwiftScaler) IsActive(ctx context.Context) (bool, error) { - objectCount, err := s.getOpenstackSwiftContainerObjectCount() + objectCount, err := s.getOpenstackSwiftContainerObjectCount(ctx) if err != nil { return false, err @@ -360,12 +360,12 @@ func (s *openstackSwiftScaler) IsActive(ctx context.Context) (bool, error) { return objectCount > 0, nil } -func (s *openstackSwiftScaler) Close() error { +func (s *openstackSwiftScaler) Close(context.Context) error { return nil } func (s *openstackSwiftScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - objectCount, err := s.getOpenstackSwiftContainerObjectCount() + objectCount, err := s.getOpenstackSwiftContainerObjectCount(ctx) if err != nil { openstackSwiftLog.Error(err, "error getting objectCount") @@ -381,7 +381,7 @@ func (s *openstackSwiftScaler) GetMetrics(ctx context.Context, metricName string return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func (s *openstackSwiftScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *openstackSwiftScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetObjectCount := resource.NewQuantity(int64(s.metadata.objectCount), resource.DecimalSI) var metricName string diff --git a/pkg/scalers/openstack_swift_scaler_test.go b/pkg/scalers/openstack_swift_scaler_test.go index 1167d6ea650..75a41245c84 100644 --- a/pkg/scalers/openstack_swift_scaler_test.go +++ b/pkg/scalers/openstack_swift_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" "github.com/kedacore/keda/v2/pkg/scalers/openstack" @@ -110,7 +111,7 @@ func TestOpenstackSwiftGetMetricSpecForScaling(t *testing.T) { mockSwiftScaler := openstackSwiftScaler{meta, openstack.Client{}} - metricSpec := mockSwiftScaler.GetMetricSpecForScaling() + metricSpec := mockSwiftScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/postgresql_scaler.go b/pkg/scalers/postgresql_scaler.go index 223b5d569e4..1ffb7a4d4f3 100644 --- a/pkg/scalers/postgresql_scaler.go +++ b/pkg/scalers/postgresql_scaler.go @@ -162,7 +162,7 @@ func getConnection(meta *postgreSQLMetadata) (*sql.DB, error) { } // Close disposes of postgres connections -func (s *postgreSQLScaler) Close() error { +func (s *postgreSQLScaler) Close(context.Context) error { err := s.connection.Close() if err != nil { postgreSQLLog.Error(err, "Error closing postgreSQL connection") @@ -173,7 +173,7 @@ func (s *postgreSQLScaler) Close() error { // IsActive returns true if there are pending messages to be processed func (s *postgreSQLScaler) IsActive(ctx context.Context) (bool, error) { - messages, err := s.getActiveNumber() + messages, err := s.getActiveNumber(ctx) if err != nil { return false, fmt.Errorf("error inspecting postgreSQL: %s", err) } @@ -181,9 +181,9 @@ func (s *postgreSQLScaler) IsActive(ctx context.Context) (bool, error) { return messages > 0, nil } -func (s *postgreSQLScaler) getActiveNumber() (int, error) { +func (s *postgreSQLScaler) getActiveNumber(ctx context.Context) (int, error) { var id int - err := s.connection.QueryRow(s.metadata.query).Scan(&id) + err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&id) if err != nil { postgreSQLLog.Error(err, fmt.Sprintf("could not query postgreSQL: %s", err)) return 0, fmt.Errorf("could not query postgreSQL: %s", err) @@ -192,7 +192,7 @@ func (s *postgreSQLScaler) getActiveNumber() (int, error) { } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler -func (s *postgreSQLScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *postgreSQLScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetQueryValue := resource.NewQuantity(int64(s.metadata.targetQueryValue), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ @@ -212,7 +212,7 @@ func (s *postgreSQLScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *postgreSQLScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - num, err := s.getActiveNumber() + num, err := s.getActiveNumber(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error inspecting postgreSQL: %s", err) } diff --git a/pkg/scalers/postgresql_scaler_test.go b/pkg/scalers/postgresql_scaler_test.go index a86d2517786..f03cd6f636b 100644 --- a/pkg/scalers/postgresql_scaler_test.go +++ b/pkg/scalers/postgresql_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "testing" ) @@ -48,7 +49,7 @@ func TestPosgresSQLGetMetricSpecForScaling(t *testing.T) { } mockPostgresSQLScaler := postgreSQLScaler{meta, nil} - metricSpec := mockPostgresSQLScaler.GetMetricSpecForScaling() + metricSpec := mockPostgresSQLScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/prometheus_scaler.go b/pkg/scalers/prometheus_scaler.go index 1a0d9e45a0a..0a049a9b196 100644 --- a/pkg/scalers/prometheus_scaler.go +++ b/pkg/scalers/prometheus_scaler.go @@ -187,7 +187,7 @@ func parsePrometheusMetadata(config *ScalerConfig) (*prometheusMetadata, error) } func (s *prometheusScaler) IsActive(ctx context.Context) (bool, error) { - val, err := s.ExecutePromQuery() + val, err := s.ExecutePromQuery(ctx) if err != nil { prometheusLog.Error(err, "error executing prometheus query") return false, err @@ -196,11 +196,11 @@ func (s *prometheusScaler) IsActive(ctx context.Context) (bool, error) { return val > 0, nil } -func (s *prometheusScaler) Close() error { +func (s *prometheusScaler) Close(context.Context) error { return nil } -func (s *prometheusScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *prometheusScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(int64(s.metadata.threshold), resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s", "prometheus", s.metadata.metricName)) externalMetric := &v2beta2.ExternalMetricSource{ @@ -218,11 +218,11 @@ func (s *prometheusScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { return []v2beta2.MetricSpec{metricSpec} } -func (s *prometheusScaler) ExecutePromQuery() (float64, error) { +func (s *prometheusScaler) ExecutePromQuery(ctx context.Context) (float64, error) { t := time.Now().UTC().Format(time.RFC3339) queryEscaped := url_pkg.QueryEscape(s.metadata.query) url := fmt.Sprintf("%s/api/v1/query?query=%s&time=%s", s.metadata.serverAddress, queryEscaped, t) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return -1, err } @@ -277,7 +277,7 @@ func (s *prometheusScaler) ExecutePromQuery() (float64, error) { } func (s *prometheusScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - val, err := s.ExecutePromQuery() + val, err := s.ExecutePromQuery(ctx) if err != nil { prometheusLog.Error(err, "error executing prometheus query") return []external_metrics.ExternalMetricValue{}, err diff --git a/pkg/scalers/prometheus_scaler_test.go b/pkg/scalers/prometheus_scaler_test.go index 6f641db7eae..8ec064c9dde 100644 --- a/pkg/scalers/prometheus_scaler_test.go +++ b/pkg/scalers/prometheus_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "net/http" "strings" "testing" @@ -90,7 +91,7 @@ func TestPrometheusGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockPrometheusScaler.GetMetricSpecForScaling() + metricSpec := mockPrometheusScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/rabbitmq_scaler.go b/pkg/scalers/rabbitmq_scaler.go index 81e51c2aa15..6b89bbb4a83 100644 --- a/pkg/scalers/rabbitmq_scaler.go +++ b/pkg/scalers/rabbitmq_scaler.go @@ -336,7 +336,7 @@ func getConnectionAndChannel(host string) (*amqp.Connection, *amqp.Channel, erro } // Close disposes of RabbitMQ connections -func (s *rabbitMQScaler) Close() error { +func (s *rabbitMQScaler) Close(context.Context) error { if s.connection != nil { err := s.connection.Close() if err != nil { @@ -446,7 +446,7 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) { } // GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler -func (s *rabbitMQScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { metricValue := resource.NewQuantity(int64(s.metadata.value), resource.DecimalSI) externalMetric := &v2beta2.ExternalMetricSource{ Metric: v2beta2.MetricIdentifier{ diff --git a/pkg/scalers/rabbitmq_scaler_test.go b/pkg/scalers/rabbitmq_scaler_test.go index 2677851aab0..d41dc3feef6 100644 --- a/pkg/scalers/rabbitmq_scaler_test.go +++ b/pkg/scalers/rabbitmq_scaler_test.go @@ -460,7 +460,7 @@ func TestRabbitMQGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockRabbitMQScaler.GetMetricSpecForScaling() + metricSpec := mockRabbitMQScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName, "wanted:", testData.name) diff --git a/pkg/scalers/redis_scaler.go b/pkg/scalers/redis_scaler.go index 27903b87b80..d38ec74274e 100644 --- a/pkg/scalers/redis_scaler.go +++ b/pkg/scalers/redis_scaler.go @@ -29,7 +29,7 @@ type redisAddressParser func(metadata, resolvedEnv, authParams map[string]string type redisScaler struct { metadata *redisMetadata closeFn func() error - getListLengthFn func() (int64, error) + getListLengthFn func(context.Context) (int64, error) } type redisConnectionInfo struct { @@ -51,7 +51,7 @@ type redisMetadata struct { var redisLog = logf.Log.WithName("redis_scaler") // NewRedisScaler creates a new redisScaler -func NewRedisScaler(isClustered bool, config *ScalerConfig) (Scaler, error) { +func NewRedisScaler(ctx context.Context, isClustered bool, config *ScalerConfig) (Scaler, error) { luaScript := ` local listName = KEYS[1] local listType = redis.call('type', listName).ok @@ -70,17 +70,17 @@ func NewRedisScaler(isClustered bool, config *ScalerConfig) (Scaler, error) { if err != nil { return nil, fmt.Errorf("error parsing redis metadata: %s", err) } - return createClusteredRedisScaler(meta, luaScript) + return createClusteredRedisScaler(ctx, meta, luaScript) } meta, err := parseRedisMetadata(config, parseRedisAddress) if err != nil { return nil, fmt.Errorf("error parsing redis metadata: %s", err) } - return createRedisScaler(meta, luaScript) + return createRedisScaler(ctx, meta, luaScript) } -func createClusteredRedisScaler(meta *redisMetadata, script string) (Scaler, error) { - client, err := getRedisClusterClient(meta.connectionInfo) +func createClusteredRedisScaler(ctx context.Context, meta *redisMetadata, script string) (Scaler, error) { + client, err := getRedisClusterClient(ctx, meta.connectionInfo) if err != nil { return nil, fmt.Errorf("connection to redis cluster failed: %s", err) } @@ -93,8 +93,9 @@ func createClusteredRedisScaler(meta *redisMetadata, script string) (Scaler, err return nil } - listLengthFn := func() (int64, error) { - cmd := client.Eval(script, []string{meta.listName}) + listLengthFn := func(ctx context.Context) (int64, error) { + cl := client.WithContext(ctx) + cmd := cl.Eval(script, []string{meta.listName}) if cmd.Err() != nil { return -1, cmd.Err() } @@ -109,8 +110,8 @@ func createClusteredRedisScaler(meta *redisMetadata, script string) (Scaler, err }, nil } -func createRedisScaler(meta *redisMetadata, script string) (Scaler, error) { - client, err := getRedisClient(meta.connectionInfo, meta.databaseIndex) +func createRedisScaler(ctx context.Context, meta *redisMetadata, script string) (Scaler, error) { + client, err := getRedisClient(ctx, meta.connectionInfo, meta.databaseIndex) if err != nil { return nil, fmt.Errorf("connection to redis failed: %s", err) } @@ -123,8 +124,9 @@ func createRedisScaler(meta *redisMetadata, script string) (Scaler, error) { return nil } - listLengthFn := func() (int64, error) { - cmd := client.Eval(script, []string{meta.listName}) + listLengthFn := func(ctx context.Context) (int64, error) { + cl := client.WithContext(ctx) + cmd := cl.Eval(script, []string{meta.listName}) if cmd.Err() != nil { return -1, cmd.Err() } @@ -177,7 +179,7 @@ func parseRedisMetadata(config *ScalerConfig, parserFn redisAddressParser) (*red // IsActive checks if there is any element in the Redis list func (s *redisScaler) IsActive(ctx context.Context) (bool, error) { - length, err := s.getListLengthFn() + length, err := s.getListLengthFn(ctx) if err != nil { redisLog.Error(err, "error") @@ -187,12 +189,12 @@ func (s *redisScaler) IsActive(ctx context.Context) (bool, error) { return length > 0, nil } -func (s *redisScaler) Close() error { +func (s *redisScaler) Close(context.Context) error { return s.closeFn() } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *redisScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *redisScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetListLengthQty := resource.NewQuantity(int64(s.metadata.targetListLength), resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s", "redis", s.metadata.listName)) externalMetric := &v2beta2.ExternalMetricSource{ @@ -212,7 +214,7 @@ func (s *redisScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics connects to Redis and finds the length of the list func (s *redisScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - listLen, err := s.getListLengthFn() + listLen, err := s.getListLengthFn(ctx) if err != nil { redisLog.Error(err, "error getting list length") @@ -343,7 +345,7 @@ func parseRedisClusterAddress(metadata, resolvedEnv, authParams map[string]strin return info, nil } -func getRedisClusterClient(info redisConnectionInfo) (*redis.ClusterClient, error) { +func getRedisClusterClient(ctx context.Context, info redisConnectionInfo) (*redis.ClusterClient, error) { options := &redis.ClusterOptions{ Addrs: info.addresses, Password: info.password, @@ -356,14 +358,13 @@ func getRedisClusterClient(info redisConnectionInfo) (*redis.ClusterClient, erro // confirm if connected c := redis.NewClusterClient(options) - err := c.Ping().Err() - if err != nil { + if err := c.WithContext(ctx).Ping().Err(); err != nil { return nil, err } return c, nil } -func getRedisClient(info redisConnectionInfo, dbIndex int) (*redis.Client, error) { +func getRedisClient(ctx context.Context, info redisConnectionInfo, dbIndex int) (*redis.Client, error) { options := &redis.Options{ Addr: info.addresses[0], Password: info.password, @@ -377,8 +378,7 @@ func getRedisClient(info redisConnectionInfo, dbIndex int) (*redis.Client, error // confirm if connected c := redis.NewClient(options) - err := c.Ping().Err() - if err != nil { + if err := c.WithContext(ctx).Ping().Err(); err != nil { return nil, err } return c, nil diff --git a/pkg/scalers/redis_scaler_test.go b/pkg/scalers/redis_scaler_test.go index a3b1cd23420..235ed6fe9e6 100644 --- a/pkg/scalers/redis_scaler_test.go +++ b/pkg/scalers/redis_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "errors" "testing" @@ -77,14 +78,14 @@ func TestRedisGetMetricSpecForScaling(t *testing.T) { t.Fatal("Could not parse metadata:", err) } closeFn := func() error { return nil } - lengthFn := func() (int64, error) { return -1, nil } + lengthFn := func(context.Context) (int64, error) { return -1, nil } mockRedisScaler := redisScaler{ meta, closeFn, lengthFn, } - metricSpec := mockRedisScaler.GetMetricSpecForScaling() + metricSpec := mockRedisScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/redis_streams_scaler.go b/pkg/scalers/redis_streams_scaler.go index 3183088b1a2..59b5954706d 100644 --- a/pkg/scalers/redis_streams_scaler.go +++ b/pkg/scalers/redis_streams_scaler.go @@ -47,23 +47,23 @@ type redisStreamsMetadata struct { var redisStreamsLog = logf.Log.WithName("redis_streams_scaler") // NewRedisStreamsScaler creates a new redisStreamsScaler -func NewRedisStreamsScaler(isClustered bool, config *ScalerConfig) (Scaler, error) { +func NewRedisStreamsScaler(ctx context.Context, isClustered bool, config *ScalerConfig) (Scaler, error) { if isClustered { meta, err := parseRedisStreamsMetadata(config, parseRedisClusterAddress) if err != nil { return nil, fmt.Errorf("error parsing redis streams metadata: %s", err) } - return createClusteredRedisStreamsScaler(meta) + return createClusteredRedisStreamsScaler(ctx, meta) } meta, err := parseRedisStreamsMetadata(config, parseRedisAddress) if err != nil { return nil, fmt.Errorf("error parsing redis streams metadata: %s", err) } - return createRedisStreamsScaler(meta) + return createRedisStreamsScaler(ctx, meta) } -func createClusteredRedisStreamsScaler(meta *redisStreamsMetadata) (Scaler, error) { - client, err := getRedisClusterClient(meta.connectionInfo) +func createClusteredRedisStreamsScaler(ctx context.Context, meta *redisStreamsMetadata) (Scaler, error) { + client, err := getRedisClusterClient(ctx, meta.connectionInfo) if err != nil { return nil, fmt.Errorf("connection to redis cluster failed: %s", err) } @@ -91,8 +91,8 @@ func createClusteredRedisStreamsScaler(meta *redisStreamsMetadata) (Scaler, erro }, nil } -func createRedisStreamsScaler(meta *redisStreamsMetadata) (Scaler, error) { - client, err := getRedisClient(meta.connectionInfo, meta.databaseIndex) +func createRedisStreamsScaler(ctx context.Context, meta *redisStreamsMetadata) (Scaler, error) { + client, err := getRedisClient(ctx, meta.connectionInfo, meta.databaseIndex) if err != nil { return nil, fmt.Errorf("connection to redis failed: %s", err) } @@ -176,12 +176,12 @@ func (s *redisStreamsScaler) IsActive(ctx context.Context) (bool, error) { return count > 0, nil } -func (s *redisStreamsScaler) Close() error { +func (s *redisStreamsScaler) Close(context.Context) error { return s.closeFn() } // GetMetricSpecForScaling returns the metric spec for the HPA -func (s *redisStreamsScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *redisStreamsScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetPendingEntriesCount := resource.NewQuantity(int64(s.metadata.targetPendingEntriesCount), resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s-%s", "redis-streams", s.metadata.streamName, s.metadata.consumerGroupName)) externalMetric := &v2beta2.ExternalMetricSource{ diff --git a/pkg/scalers/redis_streams_scaler_test.go b/pkg/scalers/redis_streams_scaler_test.go index d840494f996..a5890b699f4 100644 --- a/pkg/scalers/redis_streams_scaler_test.go +++ b/pkg/scalers/redis_streams_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "errors" "strconv" "testing" @@ -141,7 +142,7 @@ func TestRedisStreamsGetMetricSpecForScaling(t *testing.T) { getPendingEntriesCountFn := func() (int64, error) { return -1, nil } mockRedisStreamsScaler := redisStreamsScaler{meta, closeFn, getPendingEntriesCountFn} - metricSpec := mockRedisStreamsScaler.GetMetricSpecForScaling() + metricSpec := mockRedisStreamsScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scalers/scaler.go b/pkg/scalers/scaler.go index d5941be6f38..5a75110daf9 100644 --- a/pkg/scalers/scaler.go +++ b/pkg/scalers/scaler.go @@ -43,12 +43,12 @@ type Scaler interface { // Returns the metrics based on which this scaler determines that the ScaleTarget scales. This is used to construct the HPA spec that is created for // this scaled object. The labels used should match the selectors used in GetMetrics - GetMetricSpecForScaling() []v2beta2.MetricSpec + GetMetricSpecForScaling(ctx context.Context) []v2beta2.MetricSpec IsActive(ctx context.Context) (bool, error) // Close any resources that need disposing when scaler is no longer used or destroyed - Close() error + Close(ctx context.Context) error } // PushScaler interface diff --git a/pkg/scalers/selenium_grid_scaler.go b/pkg/scalers/selenium_grid_scaler.go index 2084ae8448d..eda6c234c8e 100644 --- a/pkg/scalers/selenium_grid_scaler.go +++ b/pkg/scalers/selenium_grid_scaler.go @@ -106,12 +106,12 @@ func parseSeleniumGridScalerMetadata(config *ScalerConfig) (*seleniumGridScalerM } // No cleanup required for selenium grid scaler -func (s *seleniumGridScaler) Close() error { +func (s *seleniumGridScaler) Close(context.Context) error { return nil } func (s *seleniumGridScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - v, err := s.getSessionsCount() + v, err := s.getSessionsCount(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("error requesting selenium grid endpoint: %s", err) } @@ -125,7 +125,7 @@ func (s *seleniumGridScaler) GetMetrics(ctx context.Context, metricName string, return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func (s *seleniumGridScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *seleniumGridScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetValue := resource.NewQuantity(s.metadata.targetValue, resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s-%s-%s", "seleniumgrid", s.metadata.url, s.metadata.browserName, s.metadata.browserVersion)) externalMetric := &v2beta2.ExternalMetricSource{ @@ -144,7 +144,7 @@ func (s *seleniumGridScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { } func (s *seleniumGridScaler) IsActive(ctx context.Context) (bool, error) { - v, err := s.getSessionsCount() + v, err := s.getSessionsCount(ctx) if err != nil { return false, err } @@ -152,7 +152,7 @@ func (s *seleniumGridScaler) IsActive(ctx context.Context) (bool, error) { return v.AsApproximateFloat64() > 0.0, nil } -func (s *seleniumGridScaler) getSessionsCount() (*resource.Quantity, error) { +func (s *seleniumGridScaler) getSessionsCount(ctx context.Context) (*resource.Quantity, error) { body, err := json.Marshal(map[string]string{ "query": "{ sessionsInfo { sessionQueueRequests, sessions { id, capabilities, nodeId } } }", }) @@ -161,7 +161,7 @@ func (s *seleniumGridScaler) getSessionsCount() (*resource.Quantity, error) { return nil, err } - req, err := http.NewRequest("POST", s.metadata.url, bytes.NewBuffer(body)) + req, err := http.NewRequestWithContext(ctx, "POST", s.metadata.url, bytes.NewBuffer(body)) if err != nil { return nil, err } diff --git a/pkg/scalers/solace_scaler.go b/pkg/scalers/solace_scaler.go index 393bff7df78..36d1076d09a 100644 --- a/pkg/scalers/solace_scaler.go +++ b/pkg/scalers/solace_scaler.go @@ -239,7 +239,7 @@ func getSolaceSempCredentials(config *ScalerConfig) (u string, p string, err err // METRIC IDENTIFIER HAS THE SIGNATURE: // - solace-[VPN_Name]-[Queue_Name]-[metric_type] // e.g. solace-myvpn-QUEUE1-msgCount -func (s *SolaceScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *SolaceScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { var metricSpecList []v2beta2.MetricSpec // Message Count Target Spec if s.metadata.msgCountTarget > 0 { @@ -277,7 +277,7 @@ func (s *SolaceScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { } // returns SolaceMetricValues struct populated from broker SEMP endpoint -func (s *SolaceScaler) getSolaceQueueMetricsFromSEMP() (SolaceMetricValues, error) { +func (s *SolaceScaler) getSolaceQueueMetricsFromSEMP(ctx context.Context) (SolaceMetricValues, error) { var scaledMetricEndpointURL = s.metadata.endpointURL var httpClient = s.httpClient var sempResponse solaceSEMPResponse @@ -285,7 +285,7 @@ func (s *SolaceScaler) getSolaceQueueMetricsFromSEMP() (SolaceMetricValues, erro // RETRIEVE METRICS FROM SOLACE SEMP API // Define HTTP Request - request, err := http.NewRequest("GET", scaledMetricEndpointURL, nil) + request, err := http.NewRequestWithContext(ctx, "GET", scaledMetricEndpointURL, nil) if err != nil { return SolaceMetricValues{}, fmt.Errorf("failed attempting request to solace semp api: %s", err) } @@ -327,7 +327,7 @@ func (s *SolaceScaler) getSolaceQueueMetricsFromSEMP() (SolaceMetricValues, erro func (s *SolaceScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { var metricValues, mv SolaceMetricValues var mve error - if mv, mve = s.getSolaceQueueMetricsFromSEMP(); mve != nil { + if mv, mve = s.getSolaceQueueMetricsFromSEMP(ctx); mve != nil { solaceLog.Error(mve, "call to semp endpoint failed") return []external_metrics.ExternalMetricValue{}, mve } @@ -360,7 +360,7 @@ func (s *SolaceScaler) GetMetrics(ctx context.Context, metricName string, metric // Call SEMP API to retrieve metrics // IsActive returns true if queue messageCount > 0 || msgSpoolUsage > 0 func (s *SolaceScaler) IsActive(ctx context.Context) (bool, error) { - metricValues, err := s.getSolaceQueueMetricsFromSEMP() + metricValues, err := s.getSolaceQueueMetricsFromSEMP(ctx) if err != nil { solaceLog.Error(err, "call to semp endpoint failed") return false, err @@ -369,6 +369,6 @@ func (s *SolaceScaler) IsActive(ctx context.Context) (bool, error) { } // Do Nothing - Satisfies Interface -func (s *SolaceScaler) Close() error { +func (s *SolaceScaler) Close(context.Context) error { return nil } diff --git a/pkg/scalers/solace_scaler_test.go b/pkg/scalers/solace_scaler_test.go index 553c51b18eb..f6220e9ec16 100644 --- a/pkg/scalers/solace_scaler_test.go +++ b/pkg/scalers/solace_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "fmt" "net/http" "testing" @@ -433,7 +434,7 @@ func TestSolaceGetMetricSpec(t *testing.T) { } var metric []v2beta2.MetricSpec - if metric = testSolaceScaler.GetMetricSpecForScaling(); len(metric) == 0 { + if metric = testSolaceScaler.GetMetricSpecForScaling(context.Background()); len(metric) == 0 { err = fmt.Errorf("metric value not found") } else { metricName := metric[0].External.Metric.Name diff --git a/pkg/scalers/stan_scaler.go b/pkg/scalers/stan_scaler.go index fce827343db..79a50a534af 100644 --- a/pkg/scalers/stan_scaler.go +++ b/pkg/scalers/stan_scaler.go @@ -115,7 +115,7 @@ func parseStanMetadata(config *ScalerConfig) (stanMetadata, error) { func (s *stanScaler) IsActive(ctx context.Context) (bool, error) { monitoringEndpoint := s.getMonitoringEndpoint() - req, err := http.NewRequest("GET", monitoringEndpoint, nil) + req, err := http.NewRequestWithContext(ctx, "GET", monitoringEndpoint, nil) if err != nil { return false, err } @@ -126,7 +126,7 @@ func (s *stanScaler) IsActive(ctx context.Context) (bool, error) { } if resp.StatusCode == 404 { - req, err := http.NewRequest("GET", s.getSTANChannelsEndpoint(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.getSTANChannelsEndpoint(), nil) if err != nil { return false, err } @@ -196,7 +196,7 @@ func (s *stanScaler) hasPendingMessage() bool { return false } -func (s *stanScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { +func (s *stanScaler) GetMetricSpecForScaling(context.Context) []v2beta2.MetricSpec { targetMetricValue := resource.NewQuantity(s.metadata.lagThreshold, resource.DecimalSI) metricName := kedautil.NormalizeString(fmt.Sprintf("%s-%s-%s-%s", "stan", s.metadata.queueGroup, s.metadata.durableName, s.metadata.subject)) externalMetric := &v2beta2.ExternalMetricSource{ @@ -216,7 +216,7 @@ func (s *stanScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *stanScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - req, err := http.NewRequest("GET", s.getMonitoringEndpoint(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.getMonitoringEndpoint(), nil) if err != nil { return nil, err } @@ -244,6 +244,6 @@ func (s *stanScaler) GetMetrics(ctx context.Context, metricName string, metricSe } // Nothing to close here. -func (s *stanScaler) Close() error { +func (s *stanScaler) Close(context.Context) error { return nil } diff --git a/pkg/scalers/stan_scaler_test.go b/pkg/scalers/stan_scaler_test.go index e5d8dab8ff2..96ff76891ed 100644 --- a/pkg/scalers/stan_scaler_test.go +++ b/pkg/scalers/stan_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "context" "net/http" "testing" ) @@ -52,6 +53,7 @@ func TestStanParseMetadata(t *testing.T) { func TestStanGetMetricSpecForScaling(t *testing.T) { for _, testData := range stanMetricIdentifiers { + ctx := context.Background() meta, err := parseStanMetadata(&ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ScalerIndex: testData.scalerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) @@ -62,7 +64,7 @@ func TestStanGetMetricSpecForScaling(t *testing.T) { httpClient: http.DefaultClient, } - metricSpec := mockStanScaler.GetMetricSpecForScaling() + metricSpec := mockStanScaler.GetMetricSpecForScaling(ctx) metricName := metricSpec[0].External.Metric.Name if metricName != testData.name { t.Error("Wrong External metric source name:", metricName) diff --git a/pkg/scaling/scale_handler.go b/pkg/scaling/scale_handler.go index 859ba16b15e..8ef9780a545 100644 --- a/pkg/scaling/scale_handler.go +++ b/pkg/scaling/scale_handler.go @@ -44,7 +44,7 @@ import ( type ScaleHandler interface { HandleScalableObject(scalableObject interface{}) error DeleteScalableObject(scalableObject interface{}) error - GetScalers(scalableObject interface{}) ([]scalers.Scaler, error) + GetScalers(ctx context.Context, scalableObject interface{}) ([]scalers.Scaler, error) } type scaleHandler struct { @@ -68,7 +68,7 @@ func NewScaleHandler(client client.Client, scaleClient scale.ScalesGetter, recon } } -func (h *scaleHandler) GetScalers(scalableObject interface{}) ([]scalers.Scaler, error) { +func (h *scaleHandler) GetScalers(ctx context.Context, scalableObject interface{}) ([]scalers.Scaler, error) { withTriggers, err := asDuckWithTriggers(scalableObject) if err != nil { return nil, err @@ -79,7 +79,7 @@ func (h *scaleHandler) GetScalers(scalableObject interface{}) ([]scalers.Scaler, return nil, err } - return h.buildScalers(withTriggers, podTemplateSpec, containerName) + return h.buildScalers(ctx, withTriggers, podTemplateSpec, containerName) } func (h *scaleHandler) HandleScalableObject(scalableObject interface{}) error { @@ -169,7 +169,7 @@ func (h *scaleHandler) startScaleLoop(ctx context.Context, withTriggers *kedav1a func (h *scaleHandler) startPushScalers(ctx context.Context, withTriggers *kedav1alpha1.WithTriggers, scalableObject interface{}, scalingMutex sync.Locker) { logger := h.logger.WithValues("type", withTriggers.Kind, "namespace", withTriggers.Namespace, "name", withTriggers.Name) - ss, err := h.GetScalers(scalableObject) + ss, err := h.GetScalers(ctx, scalableObject) if err != nil { logger.Error(err, "Error getting scalers", "object", scalableObject) return @@ -178,14 +178,14 @@ func (h *scaleHandler) startPushScalers(ctx context.Context, withTriggers *kedav for _, s := range ss { scaler, ok := s.(scalers.PushScaler) if !ok { - s.Close() + s.Close(ctx) continue } go func() { activeCh := make(chan bool) go scaler.Run(ctx, activeCh) - defer scaler.Close() + defer scaler.Close(ctx) for { select { case <-ctx.Done(): @@ -208,7 +208,7 @@ func (h *scaleHandler) startPushScalers(ctx context.Context, withTriggers *kedav // checkScalers contains the main logic for the ScaleHandler scaling logic. // It'll check each trigger active status then call RequestScale func (h *scaleHandler) checkScalers(ctx context.Context, scalableObject interface{}, scalingMutex sync.Locker) { - scalers, err := h.GetScalers(scalableObject) + scalers, err := h.GetScalers(ctx, scalableObject) if err != nil { h.logger.Error(err, "Error getting scalers", "object", scalableObject) return @@ -241,7 +241,7 @@ func (h *scaleHandler) isScaledObjectActive(ctx context.Context, scalers []scale isError := false for i, scaler := range scalers { isTriggerActive, err := scaler.IsActive(ctx) - scaler.Close() + scaler.Close(ctx) if err != nil { h.logger.V(1).Info("Error getting scale decision", "Error", err) @@ -250,13 +250,13 @@ func (h *scaleHandler) isScaledObjectActive(ctx context.Context, scalers []scale continue } else if isTriggerActive { isActive = true - if externalMetricsSpec := scaler.GetMetricSpecForScaling()[0].External; externalMetricsSpec != nil { + if externalMetricsSpec := scaler.GetMetricSpecForScaling(ctx)[0].External; externalMetricsSpec != nil { h.logger.V(1).Info("Scaler for scaledObject is active", "Metrics Name", externalMetricsSpec.Metric.Name) } - if resourceMetricsSpec := scaler.GetMetricSpecForScaling()[0].Resource; resourceMetricsSpec != nil { + if resourceMetricsSpec := scaler.GetMetricSpecForScaling(ctx)[0].Resource; resourceMetricsSpec != nil { h.logger.V(1).Info("Scaler for scaledObject is active", "Metrics Name", resourceMetricsSpec.Name) } - closeScalers(scalers[i+1:]) + closeScalers(ctx, scalers[i+1:]) break } } @@ -268,7 +268,7 @@ func (h *scaleHandler) isScaledJobActive(ctx context.Context, scalers []scalers. } // buildScalers returns list of Scalers for the specified triggers -func (h *scaleHandler) buildScalers(withTriggers *kedav1alpha1.WithTriggers, podTemplateSpec *corev1.PodTemplateSpec, containerName string) ([]scalers.Scaler, error) { +func (h *scaleHandler) buildScalers(ctx context.Context, withTriggers *kedav1alpha1.WithTriggers, podTemplateSpec *corev1.PodTemplateSpec, containerName string) ([]scalers.Scaler, error) { logger := h.logger.WithValues("type", withTriggers.Kind, "namespace", withTriggers.Namespace, "name", withTriggers.Name) var scalersRes []scalers.Scaler var err error @@ -293,13 +293,13 @@ func (h *scaleHandler) buildScalers(withTriggers *kedav1alpha1.WithTriggers, pod config.AuthParams, config.PodIdentity, err = resolver.ResolveAuthRefAndPodIdentity(h.client, logger, trigger.AuthenticationRef, podTemplateSpec, withTriggers.Namespace) if err != nil { - closeScalers(scalersRes) + closeScalers(ctx, scalersRes) return []scalers.Scaler{}, err } - scaler, err := buildScaler(h.client, trigger.Type, config) + scaler, err := buildScaler(ctx, h.client, trigger.Type, config) if err != nil { - closeScalers(scalersRes) + closeScalers(ctx, scalersRes) h.recorder.Event(withTriggers, corev1.EventTypeWarning, eventreason.KEDAScalerFailed, err.Error()) return []scalers.Scaler{}, fmt.Errorf("error getting scaler for trigger #%d: %s", scalerIndex, err) } @@ -310,7 +310,7 @@ func (h *scaleHandler) buildScalers(withTriggers *kedav1alpha1.WithTriggers, pod return scalersRes, nil } -func buildScaler(client client.Client, triggerType string, config *scalers.ScalerConfig) (scalers.Scaler, error) { +func buildScaler(ctx context.Context, client client.Client, triggerType string, config *scalers.ScalerConfig) (scalers.Scaler, error) { // TRIGGERS-START switch triggerType { case "artemis-queue": @@ -364,15 +364,15 @@ func buildScaler(client client.Client, triggerType string, config *scalers.Scale case "metrics-api": return scalers.NewMetricsAPIScaler(config) case "mongodb": - return scalers.NewMongoDBScaler(config) + return scalers.NewMongoDBScaler(ctx, config) case "mssql": return scalers.NewMSSQLScaler(config) case "mysql": return scalers.NewMySQLScaler(config) case "openstack-metric": - return scalers.NewOpenstackMetricScaler(config) + return scalers.NewOpenstackMetricScaler(ctx, config) case "openstack-swift": - return scalers.NewOpenstackSwiftScaler(config) + return scalers.NewOpenstackSwiftScaler(ctx, config) case "postgresql": return scalers.NewPostgreSQLScaler(config) case "prometheus": @@ -380,13 +380,13 @@ func buildScaler(client client.Client, triggerType string, config *scalers.Scale case "rabbitmq": return scalers.NewRabbitMQScaler(config) case "redis": - return scalers.NewRedisScaler(false, config) + return scalers.NewRedisScaler(ctx, false, config) case "redis-cluster": - return scalers.NewRedisScaler(true, config) + return scalers.NewRedisScaler(ctx, true, config) case "redis-cluster-streams": - return scalers.NewRedisStreamsScaler(true, config) + return scalers.NewRedisStreamsScaler(ctx, true, config) case "redis-streams": - return scalers.NewRedisStreamsScaler(false, config) + return scalers.NewRedisStreamsScaler(ctx, false, config) case "selenium-grid": return scalers.NewSeleniumGridScaler(config) case "solace-event-queue": @@ -425,8 +425,8 @@ func asDuckWithTriggers(scalableObject interface{}) (*kedav1alpha1.WithTriggers, } } -func closeScalers(scalers []scalers.Scaler) { +func closeScalers(ctx context.Context, scalers []scalers.Scaler) { for _, scaler := range scalers { - defer scaler.Close() + defer scaler.Close(ctx) } } diff --git a/pkg/scaling/scale_handler_test.go b/pkg/scaling/scale_handler_test.go index 0eb291d9004..eaf1cbc1c8c 100644 --- a/pkg/scaling/scale_handler_test.go +++ b/pkg/scaling/scale_handler_test.go @@ -55,7 +55,7 @@ func TestCheckScaledObjectScalersWithError(t *testing.T) { scaledObject := &kedav1alpha1.ScaledObject{} scaler.EXPECT().IsActive(gomock.Any()).Return(false, errors.New("Some error")) - scaler.EXPECT().Close() + scaler.EXPECT().Close(gomock.Any()) isActive, isError := scaleHandler.isScaledObjectActive(context.TODO(), scalers, scaledObject) @@ -85,9 +85,9 @@ func TestCheckScaledObjectFindFirstActiveIgnoringOthers(t *testing.T) { metricsSpecs := []v2beta2.MetricSpec{createMetricSpec(1)} activeScaler.EXPECT().IsActive(gomock.Any()).Return(true, nil) - activeScaler.EXPECT().GetMetricSpecForScaling().Times(2).Return(metricsSpecs) - activeScaler.EXPECT().Close() - failingScaler.EXPECT().Close() + activeScaler.EXPECT().GetMetricSpecForScaling(gomock.Any()).Times(2).Return(metricsSpecs) + activeScaler.EXPECT().Close(gomock.Any()) + failingScaler.EXPECT().Close(gomock.Any()) isActive, isError := scaleHandler.isScaledObjectActive(context.TODO(), scalers, scaledObject) diff --git a/pkg/scaling/scaledjob/scale_metrics.go b/pkg/scaling/scaledjob/scale_metrics.go index bcd60f8f6a0..1c702bd5042 100644 --- a/pkg/scaling/scaledjob/scale_metrics.go +++ b/pkg/scaling/scaledjob/scale_metrics.go @@ -89,7 +89,7 @@ func getScalersMetrics(ctx context.Context, scalers []scalers.Scaler, scaledJob scalerLogger := logger.WithValues("ScaledJob", scaledJob.Name, "Scaler", scalerType) - metricSpecs := scaler.GetMetricSpecForScaling() + metricSpecs := scaler.GetMetricSpecForScaling(ctx) // skip scaler that doesn't return any metric specs (usually External scaler with incorrect metadata) // or skip cpu/memory resource scaler @@ -101,7 +101,7 @@ func getScalersMetrics(ctx context.Context, scalers []scalers.Scaler, scaledJob if err != nil { scalerLogger.V(1).Info("Error getting scaler.IsActive, but continue", "Error", err) recorder.Event(scaledJob, corev1.EventTypeWarning, eventreason.KEDAScalerFailed, err.Error()) - scaler.Close() + scaler.Close(ctx) continue } @@ -111,7 +111,7 @@ func getScalersMetrics(ctx context.Context, scalers []scalers.Scaler, scaledJob if err != nil { scalerLogger.V(1).Info("Error getting scaler metrics, but continue", "Error", err) recorder.Event(scaledJob, corev1.EventTypeWarning, eventreason.KEDAScalerFailed, err.Error()) - scaler.Close() + scaler.Close(ctx) continue } @@ -125,7 +125,7 @@ func getScalersMetrics(ctx context.Context, scalers []scalers.Scaler, scaledJob } scalerLogger.V(1).Info("Scaler Metric value", "isTriggerActive", isTriggerActive, "queueLength", queueLength, "targetAverageValue", targetAverageValue) - scaler.Close() + scaler.Close(ctx) if isTriggerActive { isActive = true diff --git a/pkg/scaling/scaledjob/scale_metrics_test.go b/pkg/scaling/scaledjob/scale_metrics_test.go index 3e4a410db3b..64bbecd556b 100644 --- a/pkg/scaling/scaledjob/scale_metrics_test.go +++ b/pkg/scaling/scaledjob/scale_metrics_test.go @@ -199,8 +199,8 @@ func createScaler(ctrl *gomock.Controller, queueLength int64, averageValue int32 }, } scaler.EXPECT().IsActive(gomock.Any()).Return(isActive, nil) - scaler.EXPECT().GetMetricSpecForScaling().Return(metricsSpecs) + scaler.EXPECT().GetMetricSpecForScaling(gomock.Any()).Return(metricsSpecs) scaler.EXPECT().GetMetrics(gomock.Any(), metricName, nil).Return(metrics, nil) - scaler.EXPECT().Close() + scaler.EXPECT().Close(gomock.Any()) return scaler }