Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameter validation for AssumeRole and AssumeRoleWithWebIdentity #277

Merged
merged 3 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 212 additions & 1 deletion aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,184 @@ ca_bundle = no-such-file
}
}

func TestAssumeRole(t *testing.T) {
testCases := map[string]struct {
Config *Config
SharedConfigurationFile string
ExpectedCredentialsValue aws.Credentials
ExpectedError func(err error) bool
MockStsEndpoints []*servicemocks.MockEndpoint
}{
"config": {
Config: &Config{
AssumeRole: &AssumeRole{
RoleARN: servicemocks.MockStsAssumeRoleArn,
SessionName: servicemocks.MockStsAssumeRoleSessionName,
},
AccessKey: servicemocks.MockStaticAccessKey,
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleValidEndpoint,
},
},

"shared configuration file": {
Config: &Config{},
SharedConfigurationFile: fmt.Sprintf(`
[default]
role_arn = %[1]s
role_session_name = %[2]s
source_profile = SharedConfigurationSourceProfile

[profile SharedConfigurationSourceProfile]
aws_access_key_id = SharedConfigurationSourceAccessKey
aws_secret_access_key = SharedConfigurationSourceSecretKey
`, servicemocks.MockStsAssumeRoleArn, servicemocks.MockStsAssumeRoleSessionName),
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleValidEndpoint,
},
},

"config overrides shared configuration": {
Config: &Config{
AssumeRole: &AssumeRole{
RoleARN: servicemocks.MockStsAssumeRoleArn,
SessionName: servicemocks.MockStsAssumeRoleSessionName,
},
AccessKey: servicemocks.MockStaticAccessKey,
SecretKey: servicemocks.MockStaticSecretKey,
},
SharedConfigurationFile: fmt.Sprintf(`
[default]
role_arn = %[1]s
role_session_name = %[2]s
source_profile = SharedConfigurationSourceProfile

[profile SharedConfigurationSourceProfile]
aws_access_key_id = SharedConfigurationSourceAccessKey
aws_secret_access_key = SharedConfigurationSourceSecretKey
`, servicemocks.MockStsAssumeRoleArn, servicemocks.MockStsAssumeRoleSessionName),
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleValidEndpoint,
},
},

"with duration": {
Config: &Config{
AssumeRole: &AssumeRole{
RoleARN: servicemocks.MockStsAssumeRoleArn,
SessionName: servicemocks.MockStsAssumeRoleSessionName,
Duration: 1 * time.Hour,
},
AccessKey: servicemocks.MockStaticAccessKey,
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleValidEndpointWithOptions(map[string]string{"DurationSeconds": "3600"}),
},
},

"with policy": {
Config: &Config{
AssumeRole: &AssumeRole{
RoleARN: servicemocks.MockStsAssumeRoleArn,
SessionName: servicemocks.MockStsAssumeRoleSessionName,
Policy: "{}",
},
AccessKey: servicemocks.MockStaticAccessKey,
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleValidEndpointWithOptions(map[string]string{"Policy": "{}"}),
},
},

"invalid empty config": {
Config: &Config{
AssumeRole: &AssumeRole{},
AccessKey: servicemocks.MockStaticAccessKey,
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
ExpectedError: func(err error) bool {
return strings.Contains(err.Error(), "role ARN not set")
},
},
}

for testName, testCase := range testCases {
testCase := testCase

t.Run(testName, func(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)

closeSts, _, stsEndpoint := mockdata.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints)
defer closeSts()

testCase.Config.StsEndpoint = stsEndpoint

tempdir, err := ioutil.TempDir("", "temp")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
}
defer os.Remove(tempdir)
os.Setenv("TMPDIR", tempdir)

if testCase.SharedConfigurationFile != "" {
file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file")

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}

defer os.Remove(file.Name())

err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600)

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}

testCase.Config.SharedConfigFiles = []string{file.Name()}
}

testCase.Config.SkipCredsValidation = true

awsConfig, err := GetAwsConfig(context.Background(), testCase.Config)

if err != nil {
if testCase.ExpectedError == nil {
t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err)
}

if !testCase.ExpectedError(err) {
t.Fatalf("unexpected GetAwsConfig() '%[1]T' error: %[1]s", err)
}

t.Logf("received expected '%[1]T' error: %[1]s", err)
return
}

credentialsValue, err := awsConfig.Credentials.Retrieve(context.Background())

if err != nil {
t.Fatalf("unexpected credentials Retrieve() error: %s", err)
}

if diff := cmp.Diff(credentialsValue, testCase.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" {
t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff)
}
})
}
}

func TestAssumeRoleWithWebIdentity(t *testing.T) {
testCases := map[string]struct {
Config *Config
Expand All @@ -2467,6 +2645,7 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
SharedConfigurationFile string
SetSharedConfigurationFile bool
ExpectedCredentialsValue aws.Credentials
ExpectedError func(err error) bool
MockStsEndpoints []*servicemocks.MockEndpoint
}{
"config with inline token": {
Expand Down Expand Up @@ -2606,6 +2785,28 @@ web_identity_token_file = no-such-file
servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"Policy": "{}"}),
},
},

"invalid empty config": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{},
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
ExpectedError: func(err error) bool {
return strings.Contains(err.Error(), "role ARN not set")
},
},

"invalid no token": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
},
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
ExpectedError: func(err error) bool {
return strings.Contains(err.Error(), "one of WebIdentityToken, WebIdentityTokenFile must be set")
},
},
}

for testName, testCase := range testCases {
Expand Down Expand Up @@ -2689,8 +2890,18 @@ web_identity_token_file = no-such-file
testCase.Config.SkipCredsValidation = true

awsConfig, err := GetAwsConfig(context.Background(), testCase.Config)

if err != nil {
t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err)
if testCase.ExpectedError == nil {
t.Fatalf("expected no error, got '%[1]T' error: %[1]s", err)
}

if !testCase.ExpectedError(err) {
t.Fatalf("unexpected GetAwsConfig() '%[1]T' error: %[1]s", err)
}

t.Logf("received expected '%[1]T' error: %[1]s", err)
return
}

credentialsValue, err := awsConfig.Credentials.Retrieve(context.Background())
Expand Down
13 changes: 11 additions & 2 deletions credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package awsbase

import (
"context"
"errors"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -136,8 +137,11 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv
// This can probably be configured directly in commonLoadOptions() once
// https://github.com/aws/aws-sdk-go-v2/pull/1682 is merged
if c.AssumeRoleWithWebIdentity != nil {
if c.AssumeRoleWithWebIdentity.RoleARN == "" {
return nil, "", errors.New("Assume Role With Web Identity: role ARN not set")
}
if c.AssumeRoleWithWebIdentity.WebIdentityToken == "" && c.AssumeRoleWithWebIdentity.WebIdentityTokenFile == "" {
return nil, "", c.NewCannotAssumeRoleWithWebIdentityError(fmt.Errorf("one of: WebIdentityToken, WebIdentityTokenFile must be set"))
return nil, "", errors.New("Assume Role With Web Identity: one of WebIdentityToken, WebIdentityTokenFile must be set")
}
provider, err := webIdentityCredentialsProvider(ctx, cfg, c)
if err != nil {
Expand All @@ -156,7 +160,7 @@ Error: %w`, err)
return nil, "", c.NewNoValidCredentialSourcesError(err)
}

if c.AssumeRole == nil || c.AssumeRole.RoleARN == "" {
if c.AssumeRole == nil {
return cfg.Credentials, creds.Source, nil
}

Expand Down Expand Up @@ -192,6 +196,11 @@ func webIdentityCredentialsProvider(ctx context.Context, awsConfig aws.Config, c

func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c *Config) (aws.CredentialsProvider, error) {
ar := c.AssumeRole

if ar.RoleARN == "" {
return nil, errors.New("Assume Role: role ARN not set")
}

// When assuming a role, we need to first authenticate the base credentials above, then assume the desired role
log.Printf("[INFO] Assuming IAM Role %q (SessionName: %q, ExternalId: %q)", ar.RoleARN, ar.SessionName, ar.ExternalID)

Expand Down
4 changes: 1 addition & 3 deletions internal/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ type Partition struct {

func (p Partition) Regions() []string {
rs := make([]string, len(p.p.regions))
for i, v := range p.p.regions {
rs[i] = v
}
copy(rs, p.p.regions)
return rs
}

Expand Down
Loading