diff --git a/aws_config_test.go b/aws_config_test.go index b3f9abc8..e1ec6724 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -2457,6 +2457,184 @@ ca_bundle = no-such-file } } +func TestAssumeRole(t *testing.T) { + testCases := map[string]struct { + Config *Config + SharedConfigurationFile string + ExpectedCredentialsValue aws.Credentials + ExpectedError func(err error) bool + MockStsEndpoints []*servicemocks.MockEndpoint + }{ + "config": { + Config: &Config{ + AssumeRole: &AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpoint, + }, + }, + + "shared configuration file": { + Config: &Config{}, + SharedConfigurationFile: fmt.Sprintf(` +[default] +role_arn = %[1]s +role_session_name = %[2]s +source_profile = SharedConfigurationSourceProfile + +[profile SharedConfigurationSourceProfile] +aws_access_key_id = SharedConfigurationSourceAccessKey +aws_secret_access_key = SharedConfigurationSourceSecretKey +`, servicemocks.MockStsAssumeRoleArn, servicemocks.MockStsAssumeRoleSessionName), + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpoint, + }, + }, + + "config overrides shared configuration": { + Config: &Config{ + AssumeRole: &AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + SharedConfigurationFile: fmt.Sprintf(` +[default] +role_arn = %[1]s +role_session_name = %[2]s +source_profile = SharedConfigurationSourceProfile + +[profile SharedConfigurationSourceProfile] +aws_access_key_id = SharedConfigurationSourceAccessKey +aws_secret_access_key = SharedConfigurationSourceSecretKey +`, servicemocks.MockStsAssumeRoleArn, servicemocks.MockStsAssumeRoleSessionName), + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpoint, + }, + }, + + "with duration": { + Config: &Config{ + AssumeRole: &AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + Duration: 1 * time.Hour, + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpointWithOptions(map[string]string{"DurationSeconds": "3600"}), + }, + }, + + "with policy": { + Config: &Config{ + AssumeRole: &AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + Policy: "{}", + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpointWithOptions(map[string]string{"Policy": "{}"}), + }, + }, + + "invalid empty config": { + Config: &Config{ + AssumeRole: &AssumeRole{}, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + ExpectedError: func(err error) bool { + return strings.Contains(err.Error(), "role ARN not set") + }, + }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + closeSts, _, stsEndpoint := mockdata.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints) + defer closeSts() + + testCase.Config.StsEndpoint = stsEndpoint + + tempdir, err := ioutil.TempDir("", "temp") + if err != nil { + t.Fatalf("error creating temp dir: %s", err) + } + defer os.Remove(tempdir) + os.Setenv("TMPDIR", tempdir) + + if testCase.SharedConfigurationFile != "" { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := GetAwsConfig(context.Background(), testCase.Config) + + if err != nil { + if testCase.ExpectedError == nil { + t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err) + } + + if !testCase.ExpectedError(err) { + t.Fatalf("unexpected GetAwsConfig() '%[1]T' error: %[1]s", err) + } + + t.Logf("received expected '%[1]T' error: %[1]s", err) + return + } + + credentialsValue, err := awsConfig.Credentials.Retrieve(context.Background()) + + if err != nil { + t.Fatalf("unexpected credentials Retrieve() error: %s", err) + } + + if diff := cmp.Diff(credentialsValue, testCase.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { + t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) + } + }) + } +} + func TestAssumeRoleWithWebIdentity(t *testing.T) { testCases := map[string]struct { Config *Config @@ -2467,6 +2645,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) { SharedConfigurationFile string SetSharedConfigurationFile bool ExpectedCredentialsValue aws.Credentials + ExpectedError func(err error) bool MockStsEndpoints []*servicemocks.MockEndpoint }{ "config with inline token": { @@ -2606,6 +2785,28 @@ web_identity_token_file = no-such-file servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"Policy": "{}"}), }, }, + + "invalid empty config": { + Config: &Config{ + AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{}, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials, + ExpectedError: func(err error) bool { + return strings.Contains(err.Error(), "role ARN not set") + }, + }, + + "invalid no token": { + Config: &Config{ + AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{ + RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn, + }, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials, + ExpectedError: func(err error) bool { + return strings.Contains(err.Error(), "one of WebIdentityToken, WebIdentityTokenFile must be set") + }, + }, } for testName, testCase := range testCases { @@ -2689,8 +2890,18 @@ web_identity_token_file = no-such-file testCase.Config.SkipCredsValidation = true awsConfig, err := GetAwsConfig(context.Background(), testCase.Config) + if err != nil { - t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err) + if testCase.ExpectedError == nil { + t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err) + } + + if !testCase.ExpectedError(err) { + t.Fatalf("unexpected GetAwsConfig() '%[1]T' error: %[1]s", err) + } + + t.Logf("received expected '%[1]T' error: %[1]s", err) + return } credentialsValue, err := awsConfig.Credentials.Retrieve(context.Background()) diff --git a/credentials.go b/credentials.go index 4b051a87..2b3bac21 100644 --- a/credentials.go +++ b/credentials.go @@ -2,6 +2,7 @@ package awsbase import ( "context" + "errors" "fmt" "log" "os" @@ -136,8 +137,11 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv // This can probably be configured directly in commonLoadOptions() once // https://github.com/aws/aws-sdk-go-v2/pull/1682 is merged if c.AssumeRoleWithWebIdentity != nil { + if c.AssumeRoleWithWebIdentity.RoleARN == "" { + return nil, "", errors.New("Assume Role With Web Identity: role ARN not set") + } if c.AssumeRoleWithWebIdentity.WebIdentityToken == "" && c.AssumeRoleWithWebIdentity.WebIdentityTokenFile == "" { - return nil, "", c.NewCannotAssumeRoleWithWebIdentityError(fmt.Errorf("one of: WebIdentityToken, WebIdentityTokenFile must be set")) + return nil, "", errors.New("Assume Role With Web Identity: one of WebIdentityToken, WebIdentityTokenFile must be set") } provider, err := webIdentityCredentialsProvider(ctx, cfg, c) if err != nil { @@ -156,7 +160,7 @@ Error: %w`, err) return nil, "", c.NewNoValidCredentialSourcesError(err) } - if c.AssumeRole == nil || c.AssumeRole.RoleARN == "" { + if c.AssumeRole == nil { return cfg.Credentials, creds.Source, nil } @@ -192,6 +196,11 @@ func webIdentityCredentialsProvider(ctx context.Context, awsConfig aws.Config, c func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c *Config) (aws.CredentialsProvider, error) { ar := c.AssumeRole + + if ar.RoleARN == "" { + return nil, errors.New("Assume Role: role ARN not set") + } + // When assuming a role, we need to first authenticate the base credentials above, then assume the desired role log.Printf("[INFO] Assuming IAM Role %q (SessionName: %q, ExternalId: %q)", ar.RoleARN, ar.SessionName, ar.ExternalID) diff --git a/internal/endpoints/endpoints.go b/internal/endpoints/endpoints.go index fa1fc611..bfbece6a 100644 --- a/internal/endpoints/endpoints.go +++ b/internal/endpoints/endpoints.go @@ -19,9 +19,7 @@ type Partition struct { func (p Partition) Regions() []string { rs := make([]string, len(p.p.regions)) - for i, v := range p.p.regions { - rs[i] = v - } + copy(rs, p.p.regions) return rs } diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index c1ecdb46..0816d230 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" "time" @@ -1849,6 +1850,202 @@ ca_bundle = no-such-file } } +func TestAssumeRole(t *testing.T) { + testCases := map[string]struct { + Config *awsbase.Config + SharedConfigurationFile string + ExpectedCredentialsValue credentials.Value + ExpectedError func(err error) bool + MockStsEndpoints []*servicemocks.MockEndpoint + }{ + "config": { + Config: &awsbase.Config{ + AssumeRole: &awsbase.AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpoint, + }, + }, + + "shared configuration file": { + Config: &awsbase.Config{}, + SharedConfigurationFile: fmt.Sprintf(` +[default] +role_arn = %[1]s +role_session_name = %[2]s +source_profile = SharedConfigurationSourceProfile + +[profile SharedConfigurationSourceProfile] +aws_access_key_id = SharedConfigurationSourceAccessKey +aws_secret_access_key = SharedConfigurationSourceSecretKey +`, servicemocks.MockStsAssumeRoleArn, servicemocks.MockStsAssumeRoleSessionName), + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpoint, + }, + }, + + "config overrides shared configuration": { + Config: &awsbase.Config{ + AssumeRole: &awsbase.AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + SharedConfigurationFile: fmt.Sprintf(` +[default] +role_arn = %[1]s +role_session_name = %[2]s +source_profile = SharedConfigurationSourceProfile + +[profile SharedConfigurationSourceProfile] +aws_access_key_id = SharedConfigurationSourceAccessKey +aws_secret_access_key = SharedConfigurationSourceSecretKey +`, servicemocks.MockStsAssumeRoleArn, servicemocks.MockStsAssumeRoleSessionName), + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpoint, + }, + }, + + "with duration": { + Config: &awsbase.Config{ + AssumeRole: &awsbase.AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + Duration: 1 * time.Hour, + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpointWithOptions(map[string]string{"DurationSeconds": "3600"}), + }, + }, + + "with policy": { + Config: &awsbase.Config{ + AssumeRole: &awsbase.AssumeRole{ + RoleARN: servicemocks.MockStsAssumeRoleArn, + SessionName: servicemocks.MockStsAssumeRoleSessionName, + Policy: "{}", + }, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + MockStsEndpoints: []*servicemocks.MockEndpoint{ + servicemocks.MockStsAssumeRoleValidEndpointWithOptions(map[string]string{"Policy": "{}"}), + }, + }, + + "invalid empty config": { + Config: &awsbase.Config{ + AssumeRole: &awsbase.AssumeRole{}, + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials, + ExpectedError: func(err error) bool { + return strings.Contains(err.Error(), "role ARN not set") + }, + }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + closeSts, mockStsSession, err := mockdata.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints) + defer closeSts() + + if err != nil { + t.Fatalf("unexpected error creating mock STS server: %s", err) + } + + if mockStsSession != nil && mockStsSession.Config != nil { + testCase.Config.StsEndpoint = aws.StringValue(mockStsSession.Config.Endpoint) + } + + tempdir, err := ioutil.TempDir("", "temp") + if err != nil { + t.Fatalf("error creating temp dir: %s", err) + } + defer os.Remove(tempdir) + os.Setenv("TMPDIR", tempdir) + + if testCase.SharedConfigurationFile != "" { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + if testCase.ExpectedError == nil { + t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err) + } + + if !testCase.ExpectedError(err) { + t.Fatalf("unexpected GetAwsConfig() '%[1]T' error: %[1]s", err) + } + + t.Logf("received expected '%[1]T' error: %[1]s", err) + return + } + actualSession, err := GetSession(&awsConfig, testCase.Config) + if err != nil { + if testCase.ExpectedError == nil { + t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err) + } + + if !testCase.ExpectedError(err) { + t.Fatalf("unexpected GetSession() '%[1]T' error: %[1]s", err) + } + + t.Logf("received expected '%[1]T' error: %[1]s", err) + return + } + + credentialsValue, err := actualSession.Config.Credentials.Get() + + if err != nil { + t.Fatalf("unexpected credentials Get() error: %s", err) + } + + if diff := cmp.Diff(credentialsValue, testCase.ExpectedCredentialsValue, cmpopts.IgnoreFields(credentials.Value{}, "ProviderName")); diff != "" { + t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) + } + }) + } +} + func TestAssumeRoleWithWebIdentity(t *testing.T) { testCases := map[string]struct { Config *awsbase.Config @@ -1859,6 +2056,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) { SharedConfigurationFile string SetSharedConfigurationFile bool ExpectedCredentialsValue credentials.Value + ExpectedError func(err error) bool MockStsEndpoints []*servicemocks.MockEndpoint }{ "config with inline token": { @@ -1998,6 +2196,28 @@ web_identity_token_file = no-such-file servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"Policy": "{}"}), }, }, + + "invalid empty config": { + Config: &awsbase.Config{ + AssumeRoleWithWebIdentity: &awsbase.AssumeRoleWithWebIdentity{}, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials, + ExpectedError: func(err error) bool { + return strings.Contains(err.Error(), "role ARN not set") + }, + }, + + "invalid no token": { + Config: &awsbase.Config{ + AssumeRoleWithWebIdentity: &awsbase.AssumeRoleWithWebIdentity{ + RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn, + }, + }, + ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials, + ExpectedError: func(err error) bool { + return strings.Contains(err.Error(), "one of WebIdentityToken, WebIdentityTokenFile must be set") + }, + }, } for testName, testCase := range testCases { @@ -2088,11 +2308,29 @@ web_identity_token_file = no-such-file awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config) if err != nil { - t.Fatalf("GetAwsConfig() returned error: %s", err) + if testCase.ExpectedError == nil { + t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err) + } + + if !testCase.ExpectedError(err) { + t.Fatalf("unexpected GetAwsConfig() '%[1]T' error: %[1]s", err) + } + + t.Logf("received expected '%[1]T' error: %[1]s", err) + return } actualSession, err := GetSession(&awsConfig, testCase.Config) if err != nil { - t.Fatalf("error in GetSession() '%[1]T': %[1]s", err) + if testCase.ExpectedError == nil { + t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err) + } + + if !testCase.ExpectedError(err) { + t.Fatalf("unexpected GetSession() '%[1]T' error: %[1]s", err) + } + + t.Logf("received expected '%[1]T' error: %[1]s", err) + return } credentialsValue, err := actualSession.Config.Credentials.Get()