Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add accountID resolving logic for identity resolvers #2448

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ type Credentials struct {
// The time the credentials will expire at. Should be ignored if CanExpire
// is false.
Expires time.Time

// The ID of the account for the credentials.
AccountID string
}

// Expired returns if the credentials have expired.
Expand Down
3 changes: 3 additions & 0 deletions config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ const (
awsRequestMinCompressionSizeBytes = "AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES"

awsS3DisableExpressSessionAuthEnv = "AWS_S3_DISABLE_EXPRESS_SESSION_AUTH"

awsAccountIDEnv = "AWS_ACCOUNT_ID"
)

var (
Expand Down Expand Up @@ -309,6 +311,7 @@ func NewEnvConfig() (EnvConfig, error) {
setStringFromEnvVal(&creds.AccessKeyID, credAccessEnvKeys)
setStringFromEnvVal(&creds.SecretAccessKey, credSecretEnvKeys)
if creds.HasKeys() {
creds.AccountID = os.Getenv(awsAccountIDEnv)
creds.SessionToken = os.Getenv(awsSessionTokenEnvVar)
cfg.Credentials = creds
}
Expand Down
11 changes: 11 additions & 0 deletions config/env_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ func TestNewEnvConfig_Creds(t *testing.T) {
Source: CredentialsSourceName,
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
"AWS_SECRET_ACCESS_KEY": "SECRET",
"AWS_ACCOUNT_ID": "012345678901",
},
Val: aws.Credentials{
AccessKeyID: "AKID", SecretAccessKey: "SECRET", AccountID: "012345678901",
Source: CredentialsSourceName,
},
},
}

for i, c := range cases {
Expand Down
3 changes: 3 additions & 0 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ const (
requestMinCompressionSizeBytes = "request_min_compression_size_bytes"

s3DisableExpressSessionAuthKey = "s3_disable_express_session_auth"

accountIDKey = "aws_account_id"
)

// defaultSharedConfigProfile allows for swapping the default profile for testing
Expand Down Expand Up @@ -1130,6 +1132,7 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er
SecretAccessKey: section.String(secretAccessKey),
SessionToken: section.String(sessionTokenKey),
Source: fmt.Sprintf("SharedConfigCredentials: %s", section.SourceFile[accessKeyIDKey]),
AccountID: section.String(accountIDKey),
}

if creds.HasKeys() {
Expand Down
13 changes: 13 additions & 0 deletions config/shared_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,19 @@ func TestNewSharedConfig(t *testing.T) {
},
},
},
"profile with aws account ID": {
ConfigFilenames: []string{testConfigFilename},
Profile: "account_id",
Expected: SharedConfig{
Profile: "account_id",
Credentials: aws.Credentials{
AccessKeyID: "account_id_akid",
SecretAccessKey: "account_id_secret",
Source: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
AccountID: "012345678901",
},
},
},
}

for name, c := range cases {
Expand Down
5 changes: 5 additions & 0 deletions config/testdata/shared_config
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,8 @@ s3 =
other = foo
ec2 =
endpoint_url = http://127.0.0.1:81

[profile account_id]
aws_access_key_id = account_id_akid
aws_secret_access_key = account_id_secret
aws_account_id = 012345678901
1 change: 1 addition & 0 deletions credentials/endpointcreds/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ type GetCredentialsOutput struct {
AccessKeyID string
SecretAccessKey string
Token string
AccountID string
}

// EndpointError is an error returned from the endpoint service
Expand Down
1 change: 1 addition & 0 deletions credentials/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.Token,
Source: ProviderName,
AccountID: resp.AccountID,
}

if resp.Expiration != nil {
Expand Down
6 changes: 5 additions & 1 deletion credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func TestRetrieveStaticCredentials(t *testing.T) {
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET"
"SecretAccessKey": "SECRET",
"AccountID": "012345678901"
}`))),
}, nil
})
Expand All @@ -96,6 +97,9 @@ func TestRetrieveStaticCredentials(t *testing.T) {
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "012345678901", creds.AccountID; e != a {
t.Errorf("expect account ID to be %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
Expand Down
4 changes: 4 additions & 0 deletions credentials/processcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ type CredentialProcessResponse struct {

// The date on which the current credentials expire.
Expiration *time.Time

// The ID of the account for credentials
AccountID string `json:"AccountId"`
}

// Retrieve executes the credential process command and returns the
Expand Down Expand Up @@ -208,6 +211,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
AccountID: resp.AccountID,
}

// Handle expiration
Expand Down
18 changes: 18 additions & 0 deletions credentials/processcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ type credentialTest struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
Expiration string
AccountID string `json:"AccountId"`
}

func TestProviderStatic(t *testing.T) {
Expand Down Expand Up @@ -330,6 +331,23 @@ func BenchmarkProcessProvider(b *testing.B) {
}
}

func TestProviderWithAccountID(t *testing.T) {
provider := NewProvider(
fmt.Sprintf(
"%s %s",
getOSCat(),
filepath.Join("testdata", "accountid.json"),
))
v, err := provider.Retrieve(context.Background())
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}

if e, a := "012345678901", v.AccountID; e != a {
t.Errorf("expect retrieved accountID to be %v, got %v", e, a)
}
}

func getOSCat() string {
if runtime.GOOS == "windows" {
return "type"
Expand Down
6 changes: 6 additions & 0 deletions credentials/processcreds/testdata/accountid.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"Version":1,
"AccessKeyId":"accesskey",
"SecretAccessKey":"secretkey",
"AccountId": "012345678901"
}
1 change: 1 addition & 0 deletions credentials/ssocreds/sso_credentials_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
CanExpire: true,
Expires: time.Unix(0, output.RoleCredentials.Expiration*int64(time.Millisecond)).UTC(),
Source: ProviderName,
AccountID: p.options.AccountID,
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions credentials/ssocreds/sso_credentials_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func TestProvider(t *testing.T) {
CanExpire: true,
Expires: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
Source: ProviderName,
AccountID: "012345678901",
},
},
"custom cached token file": {
Expand Down Expand Up @@ -144,6 +145,7 @@ func TestProvider(t *testing.T) {
CanExpire: true,
Expires: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
Source: ProviderName,
AccountID: "012345678901",
},
},
"expired access token": {
Expand Down
6 changes: 6 additions & 0 deletions credentials/stscreds/assume_role_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, err
return aws.Credentials{Source: ProviderName}, err
}

var accountID string
if resp.AssumedRoleUser != nil {
accountID = getAccountID(resp.AssumedRoleUser)
}

return aws.Credentials{
AccessKeyID: *resp.Credentials.AccessKeyId,
SecretAccessKey: *resp.Credentials.SecretAccessKey,
Expand All @@ -316,5 +321,6 @@ func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, err

CanExpire: true,
Expires: *resp.Credentials.Expiration,
AccountID: accountID,
}, nil
}
6 changes: 6 additions & 0 deletions credentials/stscreds/assume_role_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *sts.AssumeRoleI
expiry := time.Now().Add(60 * time.Minute)

return &sts.AssumeRoleOutput{
AssumedRoleUser: &types.AssumedRoleUser{
Arn: aws.String("arn:aws:sts::131990247566:assumed-role/assume-role-integration-test-role/Name"),
},
Credentials: &types.Credentials{
// Just reflect the role arn to the provider.
AccessKeyId: params.RoleArn,
Expand Down Expand Up @@ -54,6 +57,9 @@ func TestAssumeRoleProvider(t *testing.T) {
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("Expect session token to match")
}
if e, a := "131990247566", creds.AccountID; e != a {
t.Error("Expect account id to match")
}
}

func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
Expand Down
19 changes: 19 additions & 0 deletions credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io/ioutil"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -135,6 +136,11 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
return aws.Credentials{}, fmt.Errorf("failed to retrieve credentials, %w", err)
}

var accountID string
if resp.AssumedRoleUser != nil {
accountID = getAccountID(resp.AssumedRoleUser)
}

// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.

Expand All @@ -145,6 +151,19 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
Source: WebIdentityProviderName,
CanExpire: true,
Expires: *resp.Credentials.Expiration,
AccountID: accountID,
}
return value, nil
}

// extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
func getAccountID(u *types.AssumedRoleUser) string {
if u.Arn == nil {
return ""
}
parts := strings.Split(*u.Arn, ":")
if len(parts) < 5 {
return ""
}
return parts[4]
}
44 changes: 44 additions & 0 deletions credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,50 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
Expires: sdk.NowTime(),
},
},
"success with accountID": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
options: func(o *stscreds.WebIdentityRoleOptions) {
o.RoleSessionName = "foo"
},
mockClient: func(
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := "foo", *params.RoleSessionName; e != a {
return nil, fmt.Errorf("expected %v, but received %v", e, a)
}
if params.DurationSeconds != nil {
return nil, fmt.Errorf("expect no duration seconds, got %v",
*params.DurationSeconds)
}
if params.Policy != nil {
return nil, fmt.Errorf("expect no policy, got %v",
*params.Policy)
}
return &sts.AssumeRoleWithWebIdentityOutput{
AssumedRoleUser: &types.AssumedRoleUser{
Arn: aws.String("arn:aws:sts::131990247566:assumed-role/assume-role-integration-test-role/Name"),
},
Credentials: &types.Credentials{
Expiration: aws.Time(sdk.NowTime()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
expectedCredValue: aws.Credentials{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
Source: stscreds.WebIdentityProviderName,
CanExpire: true,
Expires: sdk.NowTime(),
AccountID: "131990247566",
},
},
}

for name, c := range cases {
Expand Down
Loading