Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWS auth login with multi region STS support #21960

Merged
merged 7 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/go-secure-stdlib/strutil"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -61,6 +62,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "The region ID for the sts_endpoint, if set.",
},

"use_sts_region_from_client": {
Type: framework.TypeBool,
Default: false,
Description: "Uses the STS region from client requests for making AWS STS API calls.",
},

"iam_server_id_header_value": {
Type: framework.TypeString,
Default: "",
Expand Down Expand Up @@ -168,6 +175,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion,
"use_sts_region_from_client": clientConfig.UseSTSRegionFromClient,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries,
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
Expand Down Expand Up @@ -281,6 +289,14 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
}
}

useSTSRegionFromClientRaw, ok := data.GetOk("use_sts_region_from_client")
if ok {
if configEntry.UseSTSRegionFromClient != useSTSRegionFromClientRaw.(bool) {
changedCreds = true
configEntry.UseSTSRegionFromClient = useSTSRegionFromClientRaw.(bool)
}
}

headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
Expand Down Expand Up @@ -363,6 +379,7 @@ type clientConfig struct {
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
UseSTSRegionFromClient bool `json:"use_sts_region_from_client"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
MaxRetries int `json:"max_retries"`
Expand Down
58 changes: 58 additions & 0 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@ import (

"github.com/aws/aws-sdk-go/aws"
awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"

"github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
Expand Down Expand Up @@ -318,6 +321,24 @@ func (b *backend) pathLoginIamGetRoleNameCallerIdAndEntity(ctx context.Context,
}
}

// Extract and use a regional STS endpoint
// based on the region set in the Authorization header.
if config.UseSTSRegionFromClient {
clientSpecifiedRegion, err := awsRegionFromHeader(headers.Get("Authorization"))
if err != nil {
return "", nil, nil, logical.ErrorResponse("region missing from Authorization header"), nil
}

url, err := stsRegionalEndpoint(clientSpecifiedRegion)
if err != nil {
return "", nil, nil, logical.ErrorResponse(err.Error()), nil
}

b.Logger().Debug("use_sts_region_from_client set; using region specified from header", "region", clientSpecifiedRegion)
endpoint = url
}

b.Logger().Debug("submitting caller identity request", "endpoint", endpoint)
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
if err != nil {
return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
Expand Down Expand Up @@ -1884,6 +1905,43 @@ func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) {
return "", fmt.Errorf("%q not found in auth metadata", forKey)
}

func awsRegionFromHeader(authorizationHeader string) (string, error) {
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html
// The Authorization header takes the following form.
// Authorization: AWS4-HMAC-SHA256
// Credential=AKIAIOSFODNN7EXAMPLE/20230719/us-east-1/sts/aws4_request,
// SignedHeaders=content-length;content-type;host;x-amz-date,
// Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024
//
// The credential is in the form of "<your-access-key-id>/<date>/<aws-region>/<aws-service>/aws4_request"
Comment on lines +1909 to +1916
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this comment is helpful!

fields := strings.Split(authorizationHeader, " ")
for _, field := range fields {
if strings.HasPrefix(field, "Credential=") {
fields := strings.Split(field, "/")
if len(fields) < 3 {
return "", fmt.Errorf("invalid header format")
}

region := fields[2]
return region, nil
}
}

return "", fmt.Errorf("invalid header format")
}

func stsRegionalEndpoint(region string) (string, error) {
stsService := sts.EndpointsID
resolver := endpoints.DefaultResolver()
resolvedEndpoint, err := resolver.EndpointFor(stsService, region,
endpoints.STSRegionalEndpointOption,
endpoints.StrictMatchingOption)
if err != nil {
return "", fmt.Errorf("unable to get regional STS endpoint for region: %v", region)
}
return resolvedEndpoint.URL, nil
}

const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID"

const pathLoginSyn = `
Expand Down
54 changes: 54 additions & 0 deletions builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"

"github.com/hashicorp/vault/sdk/logical"
)

Expand Down Expand Up @@ -625,6 +627,58 @@ func TestBackend_defaultAliasMetadata(t *testing.T) {
}
}

func TestRegionFromHeader(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👍

tcs := map[string]struct {
header string
expectedRegion string
expectedSTSEndpoint string
}{
"us-east-1": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-east-1",
expectedSTSEndpoint: "https://sts.us-east-1.amazonaws.com",
},
"us-west-2": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-west-2/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-west-2",
expectedSTSEndpoint: "https://sts.us-west-2.amazonaws.com",
},
"ap-northeast-3": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/ap-northeast-3/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "ap-northeast-3",
expectedSTSEndpoint: "https://sts.ap-northeast-3.amazonaws.com",
},
"us-gov-east-1": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-gov-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-gov-east-1",
expectedSTSEndpoint: "https://sts.us-gov-east-1.amazonaws.com",
},
}
for name, tc := range tcs {
t.Run(name, func(t *testing.T) {
region, err := awsRegionFromHeader(tc.header)
assert.NoError(t, err)
assert.Equal(t, tc.expectedRegion, region)

stsEndpoint, err := stsRegionalEndpoint(region)
assert.NoError(t, err)
assert.Equal(t, tc.expectedSTSEndpoint, stsEndpoint)
})
}

t.Run("invalid-header", func(t *testing.T) {
region, err := awsRegionFromHeader("this-is-an-invalid-header/foobar")
assert.EqualError(t, err, "invalid header format")
assert.Empty(t, region)
})

t.Run("invalid-region", func(t *testing.T) {
endpoint, err := stsRegionalEndpoint("fake-region-1")
assert.EqualError(t, err, "unable to get regional STS endpoint for region: fake-region-1")
assert.Empty(t, endpoint)
})
}

func defaultLoginData() (map[string]interface{}, error) {
awsSession, err := session.NewSession()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions changelog/21960.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
aws/auth: Adds a new config field `use_sts_region_from_client` which allows for using dynamic regional sts endpoints based on Authorization header when using IAM-based authentication.
```
5 changes: 5 additions & 0 deletions website/content/api-docs/auth/aws.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ capabilities, the credentials are fetched automatically.
- `sts_region` `(string: "")` - Region to override the default region for making
AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should
be set to the region in which the custom `sts_endpoint` resides.
- `use_sts_region_from_client` `(boolean: false)` - If set, overrides both `sts_endpoint`
and `sts_region` to instead use the region specified in the client request headers for
IAM-based authentication . This can be useful when you have client requests coming from
different regions and want flexibility in which regional STS API is used.
maxcoulombe marked this conversation as resolved.
Show resolved Hide resolved
- `iam_server_id_header_value` `(string: "")` - The value to require in the
`X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that
are used in the iam auth method. If not set, then no value is required or
Expand Down Expand Up @@ -123,6 +127,7 @@ $ curl \
"iam_endpoint": "",
"sts_endpoint": "",
"sts_region": "",
"use_sts_region_from_client": false,
"iam_server_id_header_value": ""
}
}
Expand Down
Loading