Skip to content

Commit

Permalink
optimize some accountID resolving code
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianyi Wang committed Jan 11, 2024
1 parent a45a311 commit f2a0627
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 24 deletions.
11 changes: 0 additions & 11 deletions .changelog/9166aec5123a472ab3d80afbcae6de9b.json

This file was deleted.

2 changes: 1 addition & 1 deletion aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ type Credentials struct {
// is false.
Expires time.Time

// AWS Account ID resolved from identity and used for optional endpoint2.0 routing
// The ID of the account for the credentials.
AccountID string
}

Expand Down
4 changes: 2 additions & 2 deletions config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const (

awsS3DisableExpressSessionAuthEnv = "AWS_S3_DISABLE_EXPRESS_SESSION_AUTH"

awsAccountID = "AWS_ACCOUNT_ID"
awsAccountIDEnv = "AWS_ACCOUNT_ID"
)

var (
Expand Down Expand Up @@ -311,7 +311,7 @@ func NewEnvConfig() (EnvConfig, error) {
setStringFromEnvVal(&creds.AccessKeyID, credAccessEnvKeys)
setStringFromEnvVal(&creds.SecretAccessKey, credSecretEnvKeys)
if creds.HasKeys() {
creds.AccountID = os.Getenv(awsAccountID)
creds.AccountID = os.Getenv(awsAccountIDEnv)
creds.SessionToken = os.Getenv(awsSessionTokenEnvVar)
cfg.Credentials = creds
}
Expand Down
4 changes: 2 additions & 2 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ const (

s3DisableExpressSessionAuthKey = "s3_disable_express_session_auth"

accountID = "aws_account_id"
accountIDKey = "aws_account_id"
)

// defaultSharedConfigProfile allows for swapping the default profile for testing
Expand Down Expand Up @@ -1132,7 +1132,7 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er
SecretAccessKey: section.String(secretAccessKey),
SessionToken: section.String(sessionTokenKey),
Source: fmt.Sprintf("SharedConfigCredentials: %s", section.SourceFile[accessKeyIDKey]),
AccountID: section.String(accountID),
AccountID: section.String(accountIDKey),
}

if creds.HasKeys() {
Expand Down
2 changes: 1 addition & 1 deletion credentials/processcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ type CredentialProcessResponse struct {
// The date on which the current credentials expire.
Expiration *time.Time

// The aws account ID for this op and could be used for endpoint2.0 routing
// The ID of the account for credentials
AccountID string `json:"AccountId"`
}

Expand Down
1 change: 0 additions & 1 deletion credentials/stscreds/assume_role_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, err
return aws.Credentials{Source: ProviderName}, err
}

// extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
var accountID string
if resp.AssumedRoleUser != nil {
accountID = getAccountID(resp.AssumedRoleUser)
Expand Down
2 changes: 1 addition & 1 deletion credentials/stscreds/assume_role_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestAssumeRoleProvider(t *testing.T) {
t.Errorf("Expect session token to match")
}
if e, a := "131990247566", creds.AccountID; e != a {
t.Errorf("Expect account id to match")
t.Error("Expect account id to match")
}
}

Expand Down
14 changes: 9 additions & 5 deletions credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
return aws.Credentials{}, fmt.Errorf("failed to retrieve credentials, %w", err)
}

// extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
var accountID string
if resp.AssumedRoleUser != nil {
accountID = getAccountID(resp.AssumedRoleUser)
Expand All @@ -157,9 +156,14 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
return value, nil
}

func getAccountID(assumedRoleUser *types.AssumedRoleUser) string {
if arn := assumedRoleUser.Arn; arn != nil && len(*arn) > 0 {
return strings.Split(*arn, ":")[4]
// extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
func getAccountID(u *types.AssumedRoleUser) string {
if u.Arn == nil {
return ""
}
return ""
parts := strings.Split(*u.Arn, ":")
if len(parts) < 5 {
return ""
}
return parts[4]
}

0 comments on commit f2a0627

Please sign in to comment.