diff --git a/aws_config.go b/aws_config.go index 9092525d..cbb696cc 100644 --- a/aws_config.go +++ b/aws_config.go @@ -38,9 +38,15 @@ func configCommonLogging(ctx context.Context) context.Context { func GetAwsConfig(ctx context.Context, c *Config) (context.Context, aws.Config, diag.Diagnostics) { var diags diag.Diagnostics + + var logger logging.Logger = logging.NullLogger{} + if c.Logger != nil { + logger = c.Logger + } + ctx = logging.RegisterLogger(ctx, logger) ctx = configCommonLogging(ctx) - baseCtx, logger := logging.New(ctx, loggerName) + baseCtx, logger := logger.SubLogger(ctx, loggerName) baseCtx = logging.RegisterLogger(baseCtx, logger) logger.Trace(baseCtx, "Resolving AWS configuration") @@ -209,8 +215,13 @@ func (r *networkErrorShortcutter) RetryDelay(attempt int, err error) (time.Durat func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, c *Config) (string, string, diag.Diagnostics) { var diags diag.Diagnostics + + var logger logging.Logger = logging.NullLogger{} + if c.Logger != nil { + logger = c.Logger + } ctx = configCommonLogging(ctx) - ctx, logger := logging.New(ctx, loggerName) + ctx, logger = logger.SubLogger(ctx, loggerName) ctx = logging.RegisterLogger(ctx, logger) if !c.SkipCredsValidation { diff --git a/aws_config_test.go b/aws_config_test.go index 9a89f6fd..1c564ee2 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "io" "net" "net/http" "os" @@ -31,9 +32,12 @@ import ( "github.com/hashicorp/aws-sdk-go-base/v2/internal/awsconfig" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" + "github.com/hashicorp/aws-sdk-go-base/v2/logging" "github.com/hashicorp/aws-sdk-go-base/v2/mockdata" "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" "github.com/hashicorp/aws-sdk-go-base/v2/useragent" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/hashicorp/terraform-plugin-log/tflogtest" ) @@ -3466,15 +3470,19 @@ func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error) return 0 * time.Second, nil } -func TestLogger(t *testing.T) { +func TestLogger_TfLog(t *testing.T) { + ctx := context.Background() var buf bytes.Buffer - ctx := tflogtest.RootLogger(context.Background(), &buf) + ctx = tflogtest.RootLogger(ctx, &buf) oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) + ctx, logger := logging.NewTfLogger(ctx) + config := &Config{ AccessKey: servicemocks.MockStaticAccessKey, + Logger: logger, Region: "us-east-1", SecretKey: servicemocks.MockStaticSecretKey, } @@ -3497,6 +3505,9 @@ func TestLogger(t *testing.T) { t.Fatalf("GetAwsConfig: decoding log lines: %s", err) } + if len(lines) == 0 { + t.Fatalf("expected log entries, had none") + } for i, line := range lines { if a, e := line["@module"], expectedName; a != e { t.Errorf("GetAwsConfig: line %d: expected module %q, got %q", i+1, e, a) @@ -3513,9 +3524,145 @@ func TestLogger(t *testing.T) { t.Fatalf("GetAwsAccountIDAndPartition: decoding log lines: %s", err) } + if len(lines) == 0 { + t.Fatalf("expected log entries, had none") + } for i, line := range lines { if a, e := line["@module"], expectedName; a != e { t.Errorf("GetAwsAccountIDAndPartition: line %d: expected module %q, got %q", i+1, e, a) } } } + +func TestLoggerDefaultMasking_TfLog(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + ctx = tflogtest.RootLogger(ctx, &buf) + + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + config := &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + } + + ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityValidEndpoint, + }) + defer ts.Close() + config.StsEndpoint = ts.URL + + ctx, _, diags := GetAwsConfig(ctx, config) + if diags.HasError() { + t.Fatalf("error in GetAwsConfig(): %v", diags) + } + + buf.Reset() + + tflog.Info(ctx, "message", map[string]any{ + "id": "AKIAI44QH8DHBEXAMPLE", + }) + + lines, err := tflogtest.MultilineJSONDecode(&buf) + if err != nil { + t.Fatalf("decoding log lines: %s", err) + } + + if l := len(lines); l != 1 { + t.Fatalf("expected 1 log entry, got %d", l) + } + + line := lines[0] + if a, e := line["id"], "***"; a != e { + t.Errorf("expected %q, got %q", e, a) + } +} + +func TestLogger_HcLog(t *testing.T) { + ctx := context.Background() + + rootName := "hc-log-test" + expectedName := rootName + "." + loggerName + + var buf bytes.Buffer + hclogger := configureHcLogger(rootName, &buf) + + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + ctx, logger := logging.NewHcLogger(ctx, hclogger) + + config := &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Logger: logger, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + } + + ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityValidEndpoint, + }) + defer ts.Close() + config.StsEndpoint = ts.URL + + ctx, awsConfig, diags := GetAwsConfig(ctx, config) + if diags.HasError() { + t.Fatalf("error in GetAwsConfig(): %v", diags) + } + + lines, err := tflogtest.MultilineJSONDecode(&buf) + if err != nil { + t.Fatalf("GetAwsConfig: decoding log lines: %s", err) + } + + if len(lines) == 0 { + t.Fatalf("expected log entries, had none") + } + for i, line := range lines { + if a, e := line["@module"], expectedName; a != e { + t.Errorf("GetAwsConfig: line %d: expected module %q, got %q", i+1, e, a) + } + } + + _, _, diags = GetAwsAccountIDAndPartition(ctx, awsConfig, config) + if diags.HasError() { + t.Fatalf("GetAwsAccountIDAndPartition: unexpected '%[1]T': %[1]s", err) + } + + lines, err = tflogtest.MultilineJSONDecode(&buf) + if err != nil { + t.Fatalf("GetAwsAccountIDAndPartition: decoding log lines: %s", err) + } + + if len(lines) == 0 { + t.Fatalf("expected log entries, had none") + } + for i, line := range lines { + if a, e := line["@module"], expectedName; a != e { + t.Errorf("GetAwsAccountIDAndPartition: line %d: expected module %q, got %q", i+1, e, a) + } + } +} + +// configureHcLogger configures the default logger with settings suitable for testing: +// +// - Log level set to TRACE +// - Written to the io.Writer passed in, such as a bytes.Buffer +// - Log entries are in JSON format, and can be decoded using multilineJSONDecode +// - Caller information is not included +// - Timestamp is not included +func configureHcLogger(name string, output io.Writer) hclog.Logger { + logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ + Name: name, + Level: hclog.Trace, + Output: output, + IndependentLevels: true, + JSONFormat: true, + IncludeLocation: false, + DisableTime: true, + }) + + return logger +} diff --git a/go.mod b/go.mod index 45722679..015f5a2c 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.21.2 github.com/aws/smithy-go v1.14.1 github.com/google/go-cmp v0.5.9 + github.com/hashicorp/go-hclog v1.5.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/terraform-plugin-log v0.9.0 github.com/mitchellh/go-homedir v1.1.0 @@ -27,7 +28,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.2 // indirect github.com/fatih/color v1.15.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/hashicorp/go-hclog v1.5.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.17 // indirect github.com/mitchellh/go-testing-interface v1.14.1 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index fc2accea..d613acc0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,6 +16,7 @@ import ( awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" + "github.com/hashicorp/aws-sdk-go-base/v2/logging" ) type Config struct { @@ -33,6 +34,7 @@ type Config struct { HTTPProxy string IamEndpoint string Insecure bool + Logger logging.Logger MaxRetries int Profile string Region string diff --git a/logging/context.go b/logging/context.go index 8ff753a2..0d18f3f8 100644 --- a/logging/context.go +++ b/logging/context.go @@ -11,14 +11,14 @@ type loggerKeyT string const loggerKey loggerKeyT = "logger-key" -func RegisterLogger(ctx context.Context, logger TfLogger) context.Context { +func RegisterLogger(ctx context.Context, logger Logger) context.Context { return context.WithValue(ctx, loggerKey, logger) } -func RetrieveLogger(ctx context.Context) TfLogger { - logger, ok := ctx.Value(loggerKey).(TfLogger) +func RetrieveLogger(ctx context.Context) Logger { + logger, ok := ctx.Value(loggerKey).(Logger) if !ok { - return TfLogger("") + return NullLogger{} } return logger } diff --git a/logging/hc_logger.go b/logging/hc_logger.go new file mode 100644 index 00000000..d8ce74d3 --- /dev/null +++ b/logging/hc_logger.go @@ -0,0 +1,71 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logging + +import ( + "context" + + "github.com/hashicorp/go-hclog" +) + +type HcLogger struct{} + +var _ Logger = HcLogger{} + +func NewHcLogger(ctx context.Context, logger hclog.Logger) (context.Context, HcLogger) { + ctx = hclog.WithContext(ctx, logger) + + return ctx, HcLogger{} +} + +func (l HcLogger) SubLogger(ctx context.Context, name string) (context.Context, Logger) { + logger := hclog.FromContext(ctx) + logger = logger.Named(name) + ctx = hclog.WithContext(ctx, logger) + + return ctx, HcLogger{} +} + +func (l HcLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { + logger := hclog.FromContext(ctx) + logger.Warn(msg, flattenFields(fields...)...) +} + +func (l HcLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { + logger := hclog.FromContext(ctx) + logger.Info(msg, flattenFields(fields...)...) +} + +func (l HcLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { + logger := hclog.FromContext(ctx) + logger.Debug(msg, flattenFields(fields...)...) +} + +func (l HcLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { + logger := hclog.FromContext(ctx) + logger.Trace(msg, flattenFields(fields...)...) +} + +// TODO: how to handle duplicates +func flattenFields(fields ...map[string]any) []any { + var totalLen int + for _, m := range fields { + totalLen = len(m) + } + f := make([]any, 0, totalLen*2) //nolint:gomnd + + for _, m := range fields { + for k, v := range m { + f = append(f, k, v) + } + } + return f +} + +func (l HcLogger) SetField(ctx context.Context, key string, value any) context.Context { + logger := hclog.FromContext(ctx) + logger = logger.With(key, value) + ctx = hclog.WithContext(ctx, logger) + return ctx +} diff --git a/logging/hc_logger_test.go b/logging/hc_logger_test.go new file mode 100644 index 00000000..4602256a --- /dev/null +++ b/logging/hc_logger_test.go @@ -0,0 +1,52 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logging + +import ( + "context" + "io" + "testing" + + "github.com/hashicorp/go-hclog" +) + +const hclogRootName = "hc-log-test" + +func TestHcLoggerWarn(t *testing.T) { + testLoggerWarn(t, hclogRootName, hcLoggerFactory) +} + +func TestHcLoggerSetField(t *testing.T) { + testLoggerSetField(t, hclogRootName, hcLoggerFactory) +} + +func hcLoggerFactory(ctx context.Context, name string, output io.Writer) (context.Context, Logger) { + hclogger := configureHcLogger(output) + + ctx, rootLogger := NewHcLogger(ctx, hclogger) + ctx, logger := rootLogger.SubLogger(ctx, name) + + return ctx, logger +} + +// configureHcLogger configures the default logger with settings suitable for testing: +// +// - Log level set to TRACE +// - Written to the io.Writer passed in, such as a bytes.Buffer +// - Log entries are in JSON format, and can be decoded using multilineJSONDecode +// - Caller information is not included +// - Timestamp is not included +func configureHcLogger(output io.Writer) hclog.Logger { + logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ + Name: hclogRootName, + Level: hclog.Trace, + Output: output, + IndependentLevels: true, + JSONFormat: true, + IncludeLocation: false, + DisableTime: true, + }) + + return logger +} diff --git a/logging/logger.go b/logging/logger.go index f3c7a883..c6f21519 100644 --- a/logging/logger.go +++ b/logging/logger.go @@ -5,55 +5,15 @@ package logging import ( "context" - - "github.com/hashicorp/terraform-plugin-log/tflog" ) -func New(ctx context.Context, name string) (context.Context, TfLogger) { - ctx = tflog.NewSubsystem(ctx, name, tflog.WithRootFields()) - logger := TfLogger(name) - - return ctx, logger -} - -type TfLogger string - -func (l TfLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { - if l == "" { - tflog.Warn(ctx, msg, fields...) - } else { - tflog.SubsystemWarn(ctx, string(l), msg, fields...) - } -} +type Logger interface { + Warn(ctx context.Context, msg string, fields ...map[string]any) + Info(ctx context.Context, msg string, fields ...map[string]any) + Debug(ctx context.Context, msg string, fields ...map[string]any) + Trace(ctx context.Context, msg string, fields ...map[string]any) -func (l TfLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { - if l == "" { - tflog.Info(ctx, msg, fields...) - } else { - tflog.SubsystemInfo(ctx, string(l), msg, fields...) - } -} - -func (l TfLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { - if l == "" { - tflog.Debug(ctx, msg, fields...) - } else { - tflog.SubsystemDebug(ctx, string(l), msg, fields...) - } -} - -func (l TfLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { - if l == "" { - tflog.Trace(ctx, msg, fields...) - } else { - tflog.SubsystemTrace(ctx, string(l), msg, fields...) - } -} + SetField(ctx context.Context, key string, value any) context.Context -func (l TfLogger) SetField(ctx context.Context, key string, value any) context.Context { - if l == "" { - return tflog.SetField(ctx, key, value) - } else { - return tflog.SubsystemSetField(ctx, string(l), key, value) - } + SubLogger(ctx context.Context, name string) (context.Context, Logger) } diff --git a/logging/logger_test.go b/logging/logger_test.go new file mode 100644 index 00000000..dab7a4b9 --- /dev/null +++ b/logging/logger_test.go @@ -0,0 +1,97 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logging + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/terraform-plugin-log/tflogtest" +) + +func testLoggerWarn(t *testing.T, rootName string, factory func(ctx context.Context, name string, output io.Writer) (context.Context, Logger)) { + t.Helper() + + loggerName := "test" + expectedModule := rootName + "." + loggerName + + var buf bytes.Buffer + ctx := context.Background() + ctx, logger := factory(ctx, loggerName, &buf) + + logger.Warn(ctx, "message", map[string]any{ + "one": int(1), + "two": "two", + }) + + lines, err := tflogtest.MultilineJSONDecode(&buf) + if err != nil { + t.Fatalf("decoding log lines: %s", err) + } + + expected := []map[string]any{ + { + "@level": "warn", + "@module": expectedModule, + "@message": "message", + "one": float64(1), + "two": "two", + }, + } + + if diff := cmp.Diff(expected, lines); diff != "" { + t.Errorf("unexpected logger output difference: %s", diff) + } +} + +func testLoggerSetField(t *testing.T, rootName string, factory func(ctx context.Context, name string, output io.Writer) (context.Context, Logger)) { + t.Helper() + + loggerName := "test" + expectedModule := rootName + "." + loggerName + + var buf bytes.Buffer + originalCtx := context.Background() + originalCtx, logger := factory(originalCtx, loggerName, &buf) + + newCtx := logger.SetField(originalCtx, "key", "value") + + logger.Warn(newCtx, "new logger") + logger.Warn(newCtx, "new logger", map[string]any{ + "key": "other value", + }) + logger.Warn(originalCtx, "original logger") + + lines, err := tflogtest.MultilineJSONDecode(&buf) + if err != nil { + t.Fatalf("ctxWithField: decoding log lines: %s", err) + } + + expected := []map[string]any{ + { + "@level": "warn", + "@module": expectedModule, + "@message": "new logger", + "key": "value", + }, + { + "@level": "warn", + "@module": expectedModule, + "@message": "new logger", + "key": "other value", + }, + { + "@level": "warn", + "@module": expectedModule, + "@message": "original logger", + }, + } + + if diff := cmp.Diff(expected, lines); diff != "" { + t.Errorf("unexpected logger output difference: %s", diff) + } +} diff --git a/logging/null_logger.go b/logging/null_logger.go new file mode 100644 index 00000000..ee703801 --- /dev/null +++ b/logging/null_logger.go @@ -0,0 +1,33 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logging + +import ( + "context" +) + +type NullLogger struct { +} + +var _ Logger = NullLogger{} + +func (l NullLogger) SubLogger(ctx context.Context, name string) (context.Context, Logger) { + return ctx, l +} + +func (l NullLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { +} + +func (l NullLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { +} + +func (l NullLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { +} + +func (l NullLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { +} + +func (l NullLogger) SetField(ctx context.Context, key string, value any) context.Context { + return ctx +} diff --git a/logging/tf_logger.go b/logging/tf_logger.go new file mode 100644 index 00000000..57a55e68 --- /dev/null +++ b/logging/tf_logger.go @@ -0,0 +1,73 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logging + +import ( + "context" + + "github.com/hashicorp/terraform-plugin-log/tflog" +) + +type TfLogger string + +var _ Logger = TfLogger("") + +func NewTfLogger(ctx context.Context) (context.Context, TfLogger) { + return ctx, TfLogger("") +} + +func (l TfLogger) SubLogger(ctx context.Context, name string) (context.Context, Logger) { + ctx = tflog.NewSubsystem(ctx, name, tflog.WithRootFields()) + logger := TfLogger(name) + + return ctx, logger +} + +func (l TfLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { + if l == "" { + tflog.Warn(ctx, msg, fields...) + } else { + tflog.SubsystemWarn(ctx, string(l), msg, fields...) + } +} + +func (l TfLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { + if l == "" { + tflog.Info(ctx, msg, fields...) + } else { + tflog.SubsystemInfo(ctx, string(l), msg, fields...) + } +} + +func (l TfLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { + if l == "" { + tflog.Debug(ctx, msg, fields...) + } else { + tflog.SubsystemDebug(ctx, string(l), msg, fields...) + } +} + +func (l TfLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { + if l == "" { + tflog.Trace(ctx, msg, fields...) + } else { + tflog.SubsystemTrace(ctx, string(l), msg, fields...) + } +} + +func (l TfLogger) SetField(ctx context.Context, key string, value any) context.Context { + if l == "" { + return tflog.SetField(ctx, key, value) + } else { + return tflog.SubsystemSetField(ctx, string(l), key, value) + } +} + +// func (l TfLogger) MaskAllFieldValuesRegexes(ctx context.Context, expressions ...*regexp.Regexp) context.Context { +// if l == "" { +// return tflog.MaskAllFieldValuesRegexes(ctx, expressions...) +// } else { +// return tflog.SubsystemMaskAllFieldValuesRegexes(ctx, string(l), expressions...) +// } +// } diff --git a/logging/tf_logger_test.go b/logging/tf_logger_test.go new file mode 100644 index 00000000..344e89d9 --- /dev/null +++ b/logging/tf_logger_test.go @@ -0,0 +1,31 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logging + +import ( + "context" + "io" + "testing" + + "github.com/hashicorp/terraform-plugin-log/tflogtest" +) + +const tflogRootName = "provider" + +func TestTfLoggerWarn(t *testing.T) { + testLoggerWarn(t, tflogRootName, tfLoggerFactory) +} + +func TestTfLoggerSetField(t *testing.T) { + testLoggerSetField(t, tflogRootName, tfLoggerFactory) +} + +func tfLoggerFactory(ctx context.Context, name string, output io.Writer) (context.Context, Logger) { + ctx = tflogtest.RootLogger(ctx, output) + + ctx, rootLogger := NewTfLogger(ctx) + ctx, logger := rootLogger.SubLogger(ctx, name) + + return ctx, logger +} diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 18cc6072..56e239ea 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -82,8 +82,11 @@ const loggerName string = "aws-base-v1" func GetSession(ctx context.Context, awsC *awsv2.Config, c *awsbase.Config) (*session.Session, diag.Diagnostics) { var diags diag.Diagnostics - // var loggerFactory tfLoggerFactory - ctx, logger := logging.New(ctx, loggerName) + var logger logging.Logger = logging.NullLogger{} + if c.Logger != nil { + logger = c.Logger + } + ctx, logger = logger.SubLogger(ctx, loggerName) ctx = logging.RegisterLogger(ctx, logger) options, err := getSessionOptions(ctx, awsC, c) diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index af9cb26c..b3b0b683 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -36,6 +36,7 @@ import ( "github.com/hashicorp/aws-sdk-go-base/v2/diag" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" + "github.com/hashicorp/aws-sdk-go-base/v2/logging" "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" "github.com/hashicorp/aws-sdk-go-base/v2/useragent" "github.com/hashicorp/terraform-plugin-log/tflogtest" @@ -2564,6 +2565,7 @@ func TestSessionRetryHandlers(t *testing.T) { request, _ := iamconn.GetUserRequest(&iam.GetUserInput{}) request.RetryCount = testcase.RetryCount request.Error = testcase.Error + request.SetContext(ctx) // Prevent the retryer from using the default retry delay retryer := request.Retryer.(client.DefaultRetryer) @@ -2599,8 +2601,11 @@ func TestLogger(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) + ctx, logger := logging.NewTfLogger(ctx) + config := &awsbase.Config{ AccessKey: servicemocks.MockStaticAccessKey, + Logger: logger, Region: "us-east-1", SecretKey: servicemocks.MockStaticSecretKey, }