Skip to content

Commit

Permalink
Merge pull request #2448 from aws/feat-identity-resolve-aid
Browse files Browse the repository at this point in the history
Add accountID resolving logic for identity resolvers
  • Loading branch information
wty-Bryant committed Jan 12, 2024
2 parents 0c7b7e4 + f2a0627 commit bd213b5
Show file tree
Hide file tree
Showing 18 changed files with 151 additions and 1 deletion.
3 changes: 3 additions & 0 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ type Credentials struct {
// The time the credentials will expire at. Should be ignored if CanExpire
// is false.
Expires time.Time

// The ID of the account for the credentials.
AccountID string
}

// Expired returns if the credentials have expired.
Expand Down
3 changes: 3 additions & 0 deletions config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ const (
awsRequestMinCompressionSizeBytes = "AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES"

awsS3DisableExpressSessionAuthEnv = "AWS_S3_DISABLE_EXPRESS_SESSION_AUTH"

awsAccountIDEnv = "AWS_ACCOUNT_ID"
)

var (
Expand Down Expand Up @@ -309,6 +311,7 @@ func NewEnvConfig() (EnvConfig, error) {
setStringFromEnvVal(&creds.AccessKeyID, credAccessEnvKeys)
setStringFromEnvVal(&creds.SecretAccessKey, credSecretEnvKeys)
if creds.HasKeys() {
creds.AccountID = os.Getenv(awsAccountIDEnv)
creds.SessionToken = os.Getenv(awsSessionTokenEnvVar)
cfg.Credentials = creds
}
Expand Down
11 changes: 11 additions & 0 deletions config/env_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ func TestNewEnvConfig_Creds(t *testing.T) {
Source: CredentialsSourceName,
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
"AWS_SECRET_ACCESS_KEY": "SECRET",
"AWS_ACCOUNT_ID": "012345678901",
},
Val: aws.Credentials{
AccessKeyID: "AKID", SecretAccessKey: "SECRET", AccountID: "012345678901",
Source: CredentialsSourceName,
},
},
}

for i, c := range cases {
Expand Down
3 changes: 3 additions & 0 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ const (
requestMinCompressionSizeBytes = "request_min_compression_size_bytes"

s3DisableExpressSessionAuthKey = "s3_disable_express_session_auth"

accountIDKey = "aws_account_id"
)

// defaultSharedConfigProfile allows for swapping the default profile for testing
Expand Down Expand Up @@ -1130,6 +1132,7 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er
SecretAccessKey: section.String(secretAccessKey),
SessionToken: section.String(sessionTokenKey),
Source: fmt.Sprintf("SharedConfigCredentials: %s", section.SourceFile[accessKeyIDKey]),
AccountID: section.String(accountIDKey),
}

if creds.HasKeys() {
Expand Down
13 changes: 13 additions & 0 deletions config/shared_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,19 @@ func TestNewSharedConfig(t *testing.T) {
},
},
},
"profile with aws account ID": {
ConfigFilenames: []string{testConfigFilename},
Profile: "account_id",
Expected: SharedConfig{
Profile: "account_id",
Credentials: aws.Credentials{
AccessKeyID: "account_id_akid",
SecretAccessKey: "account_id_secret",
Source: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
AccountID: "012345678901",
},
},
},
}

for name, c := range cases {
Expand Down
5 changes: 5 additions & 0 deletions config/testdata/shared_config
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,8 @@ s3 =
other = foo
ec2 =
endpoint_url = http://127.0.0.1:81

[profile account_id]
aws_access_key_id = account_id_akid
aws_secret_access_key = account_id_secret
aws_account_id = 012345678901
1 change: 1 addition & 0 deletions credentials/endpointcreds/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ type GetCredentialsOutput struct {
AccessKeyID string
SecretAccessKey string
Token string
AccountID string
}

// EndpointError is an error returned from the endpoint service
Expand Down
1 change: 1 addition & 0 deletions credentials/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.Token,
Source: ProviderName,
AccountID: resp.AccountID,
}

if resp.Expiration != nil {
Expand Down
6 changes: 5 additions & 1 deletion credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func TestRetrieveStaticCredentials(t *testing.T) {
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET"
"SecretAccessKey": "SECRET",
"AccountID": "012345678901"
}`))),
}, nil
})
Expand All @@ -96,6 +97,9 @@ func TestRetrieveStaticCredentials(t *testing.T) {
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "012345678901", creds.AccountID; e != a {
t.Errorf("expect account ID to be %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
Expand Down
4 changes: 4 additions & 0 deletions credentials/processcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ type CredentialProcessResponse struct {

// The date on which the current credentials expire.
Expiration *time.Time

// The ID of the account for credentials
AccountID string `json:"AccountId"`
}

// Retrieve executes the credential process command and returns the
Expand Down Expand Up @@ -208,6 +211,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
AccountID: resp.AccountID,
}

// Handle expiration
Expand Down
18 changes: 18 additions & 0 deletions credentials/processcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ type credentialTest struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
Expiration string
AccountID string `json:"AccountId"`
}

func TestProviderStatic(t *testing.T) {
Expand Down Expand Up @@ -330,6 +331,23 @@ func BenchmarkProcessProvider(b *testing.B) {
}
}

func TestProviderWithAccountID(t *testing.T) {
provider := NewProvider(
fmt.Sprintf(
"%s %s",
getOSCat(),
filepath.Join("testdata", "accountid.json"),
))
v, err := provider.Retrieve(context.Background())
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}

if e, a := "012345678901", v.AccountID; e != a {
t.Errorf("expect retrieved accountID to be %v, got %v", e, a)
}
}

func getOSCat() string {
if runtime.GOOS == "windows" {
return "type"
Expand Down
6 changes: 6 additions & 0 deletions credentials/processcreds/testdata/accountid.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"Version":1,
"AccessKeyId":"accesskey",
"SecretAccessKey":"secretkey",
"AccountId": "012345678901"
}
1 change: 1 addition & 0 deletions credentials/ssocreds/sso_credentials_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
CanExpire: true,
Expires: time.Unix(0, output.RoleCredentials.Expiration*int64(time.Millisecond)).UTC(),
Source: ProviderName,
AccountID: p.options.AccountID,
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions credentials/ssocreds/sso_credentials_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func TestProvider(t *testing.T) {
CanExpire: true,
Expires: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
Source: ProviderName,
AccountID: "012345678901",
},
},
"custom cached token file": {
Expand Down Expand Up @@ -144,6 +145,7 @@ func TestProvider(t *testing.T) {
CanExpire: true,
Expires: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
Source: ProviderName,
AccountID: "012345678901",
},
},
"expired access token": {
Expand Down
6 changes: 6 additions & 0 deletions credentials/stscreds/assume_role_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, err
return aws.Credentials{Source: ProviderName}, err
}

var accountID string
if resp.AssumedRoleUser != nil {
accountID = getAccountID(resp.AssumedRoleUser)
}

return aws.Credentials{
AccessKeyID: *resp.Credentials.AccessKeyId,
SecretAccessKey: *resp.Credentials.SecretAccessKey,
Expand All @@ -316,5 +321,6 @@ func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, err

CanExpire: true,
Expires: *resp.Credentials.Expiration,
AccountID: accountID,
}, nil
}
6 changes: 6 additions & 0 deletions credentials/stscreds/assume_role_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *sts.AssumeRoleI
expiry := time.Now().Add(60 * time.Minute)

return &sts.AssumeRoleOutput{
AssumedRoleUser: &types.AssumedRoleUser{
Arn: aws.String("arn:aws:sts::131990247566:assumed-role/assume-role-integration-test-role/Name"),
},
Credentials: &types.Credentials{
// Just reflect the role arn to the provider.
AccessKeyId: params.RoleArn,
Expand Down Expand Up @@ -54,6 +57,9 @@ func TestAssumeRoleProvider(t *testing.T) {
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("Expect session token to match")
}
if e, a := "131990247566", creds.AccountID; e != a {
t.Error("Expect account id to match")
}
}

func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
Expand Down
19 changes: 19 additions & 0 deletions credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io/ioutil"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -135,6 +136,11 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
return aws.Credentials{}, fmt.Errorf("failed to retrieve credentials, %w", err)
}

var accountID string
if resp.AssumedRoleUser != nil {
accountID = getAccountID(resp.AssumedRoleUser)
}

// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.

Expand All @@ -145,6 +151,19 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
Source: WebIdentityProviderName,
CanExpire: true,
Expires: *resp.Credentials.Expiration,
AccountID: accountID,
}
return value, nil
}

// extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
func getAccountID(u *types.AssumedRoleUser) string {
if u.Arn == nil {
return ""
}
parts := strings.Split(*u.Arn, ":")
if len(parts) < 5 {
return ""
}
return parts[4]
}
44 changes: 44 additions & 0 deletions credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,50 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
Expires: sdk.NowTime(),
},
},
"success with accountID": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
options: func(o *stscreds.WebIdentityRoleOptions) {
o.RoleSessionName = "foo"
},
mockClient: func(
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := "foo", *params.RoleSessionName; e != a {
return nil, fmt.Errorf("expected %v, but received %v", e, a)
}
if params.DurationSeconds != nil {
return nil, fmt.Errorf("expect no duration seconds, got %v",
*params.DurationSeconds)
}
if params.Policy != nil {
return nil, fmt.Errorf("expect no policy, got %v",
*params.Policy)
}
return &sts.AssumeRoleWithWebIdentityOutput{
AssumedRoleUser: &types.AssumedRoleUser{
Arn: aws.String("arn:aws:sts::131990247566:assumed-role/assume-role-integration-test-role/Name"),
},
Credentials: &types.Credentials{
Expiration: aws.Time(sdk.NowTime()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
expectedCredValue: aws.Credentials{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
Source: stscreds.WebIdentityProviderName,
CanExpire: true,
Expires: sdk.NowTime(),
AccountID: "131990247566",
},
},
}

for name, c := range cases {
Expand Down

0 comments on commit bd213b5

Please sign in to comment.