Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use local aws config in cli to get account and regions #7758

38 changes: 23 additions & 15 deletions pkg/cli/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package aws
import (
"context"

"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
Expand All @@ -28,9 +28,9 @@ import (
// Client is an interface that abstracts `rad init`'s interactions with AWS. This is for testing purposes. This is only exported because mockgen requires it.
type Client interface {
// GetCallerIdentity gets information about the provided credentials.
GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error)
GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error)
// ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API).
ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error)
ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error)
}

// NewClient returns a new Client.
Expand All @@ -43,12 +43,16 @@ type client struct{}
var _ Client = &client{}

// GetCallerIdentity gets information about the provided credentials.
func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error) {
credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
stsClient := sts.New(sts.Options{
Region: region,
Credentials: credentialsProvider,
})
func (c *client) GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error) {
// Load the AWS SDK config and credentials
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should "default" be a const?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the config.WithSharedConfigProfile

return nil, err
}

// Create an STS client
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is an STS Client?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security Token Service Client and it is generally used when short-term access is needed to privileged AWS resources

stsClient := sts.NewFromConfig(cfg)

result, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, err
Expand All @@ -58,12 +62,16 @@ func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKey
}

// ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API).
func (c *client) ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error) {
credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
ec2Client := ec2.New(ec2.Options{
Region: region,
Credentials: credentialsProvider,
})
func (c *client) ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error) {
// Load the AWS SDK config and credentials
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}

// Create an EC2 client
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we need these comments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed it

ec2Client := ec2.NewFromConfig(cfg)

result, err := ec2Client.DescribeRegions(ctx, &ec2.DescribeRegionsInput{})
if err != nil {
return nil, err
Expand Down
24 changes: 12 additions & 12 deletions pkg/cli/aws/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 39 additions & 11 deletions pkg/cli/cmd/radinit/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package radinit

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/charmbracelet/bubbles/textinput"
Expand All @@ -27,15 +28,15 @@ import (
)

const (
// QueryRegion is the region used for querying AWS before the user selects a region.
QueryRegion = "us-east-1"

selectAWSRegionPrompt = "Select the region you would like to deploy AWS resources to:"
enterAWSIAMAcessKeyIDPrompt = "Enter the IAM access key id:"
enterAWSIAMAcessKeyIDPlaceholder = "Enter IAM access key id..."
enterAWSIAMSecretAccessKeyPrompt = "Enter your IAM Secret Access Key:"
enterAWSIAMSecretAccessKeyPlaceholder = "Enter IAM secret access key..."
errNotEmptyTemplate = "%s cannot be empty"
confirmAWSAccountIDPromptFmt = "Use account id '%v'?"
enterAWSAccountIDPrompt = "Enter the account ID:"
enterAWSAccountIDPlaceholder = "Enter the account ID you want to use..."

awsAccessKeysCreateInstructionFmt = "\nAWS IAM Access keys (Access key ID and Secret access key) are required to access and create AWS resources.\n\nFor example, you can create one using the following command:\n\033[36maws iam create-access-key\033[0m\n\nFor more information refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html.\n\n"
)
Expand All @@ -53,12 +54,24 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro
return nil, err
}

accountId, err := r.getAccountId(ctx, accessKeyID, secretAccessKey)
accountId, err := r.getAccountId(ctx)
if err != nil {
return nil, err
}

region, err := r.selectAWSRegion(ctx, QueryRegion, accessKeyID, secretAccessKey)
// addAccountID, err := prompt.YesOrNoPrompt(fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountId), prompt.ConfirmYes, r.Prompter)
// if err != nil {
// return nil, err
// }

// if !addAccountID {
// accountId, err = r.Prompter.GetTextInput(enterAWSAccountIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSAccountIDPlaceholder})
// if err != nil {
// return nil, err
// }
// }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to be uncommented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missed to removed thess, updated in the pr


region, err := r.selectAWSRegion(ctx)
if err != nil {
return nil, err
}
Expand All @@ -71,21 +84,36 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro
}, nil
}

func (r *Runner) getAccountId(ctx context.Context, accessKeyID, secretAccessKey string) (string, error) {
callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx, QueryRegion, accessKeyID, secretAccessKey)
func (r *Runner) getAccountId(ctx context.Context) (string, error) {
callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx)
if err != nil {
return "", clierrors.MessageWithCause(err, "AWS credential verification failed.")
return "", clierrors.MessageWithCause(err, "AWS Cloud Provider setup failed, please use aws configure to set up the configuration. More information :https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html")
}
Copy link
Contributor

@nithyatsu nithyatsu Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here, would it help user if we add some detail, like "AWS credential verification failed. Please use aws configure to configure credentials and then try again " ? cc @Reshrahim


if callerIdentityOutput.Account == nil {
return "", clierrors.MessageWithCause(err, "AWS credential verification failed: Account ID is nil.")
}

return *callerIdentityOutput.Account, nil
accountID := *callerIdentityOutput.Account
addAccountID, err := prompt.YesOrNoPrompt(fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountID), prompt.ConfirmYes, r.Prompter)
if err != nil {
return "", err
}

if !addAccountID {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be true so that the user can enter account id? Like the logic here is that if addAccountID is false then ask the user to enter AWS account id? Am I missing something here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we retrieve the account id using the local config , and in the yes or no prompt ask the user if they want to use the account id from local config, if user selects no then we propmpt to provide the account id.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a more meaningful name instead of addAccountID then I guess.

accountID, err = r.Prompter.GetTextInput(enterAWSAccountIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSAccountIDPlaceholder})
if err != nil {
return "", err
}
}

return accountID, nil
}

func (r *Runner) selectAWSRegion(ctx context.Context, region, accessKeyID, secretAccessKey string) (string, error) {
listRegionsOutput, err := r.awsClient.ListRegions(ctx, region, accessKeyID, secretAccessKey)
// selectAWSRegion prompts the user to select an AWS region from a list of available regions.
// regions list is retrieved using the locally configured AWS account.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// regions list is retrieved using the locally configured AWS account.
// Region list is retrieved using the locally configured AWS account.

func (r *Runner) selectAWSRegion(ctx context.Context) (string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just for easier context, can you please add a comment here that it uses local aws config to list regions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added it

listRegionsOutput, err := r.awsClient.ListRegions(ctx)
if err != nil {
return "", clierrors.MessageWithCause(err, "Listing AWS regions failed.")
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/cli/cmd/radinit/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ func Test_enterAWSCloudProvider(t *testing.T) {

setAWSAccessKeyIDPrompt(prompter, "access-key-id")
setAWSSecretAccessKeyPrompt(prompter, "secret-access-key")
setAWSCallerIdentity(client, QueryRegion, "access-key-id", "secret-access-key", &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSListRegions(client, QueryRegion, "access-key-id", "secret-access-key", &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSRegionPrompt(prompter, regions, "region")

provider, err := runner.enterAWSCloudProvider(context.Background())
Expand Down
20 changes: 14 additions & 6 deletions pkg/cli/cmd/radinit/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,16 +1037,23 @@ func setAWSSecretAccessKeyPrompt(prompter *prompt.MockInterface, secretAccessKey
Return(secretAccessKey, nil).Times(1)
}

func setAWSCallerIdentity(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, callerIdentityOutput *sts.GetCallerIdentityOutput) {
func setAWSCallerIdentity(client *aws.MockClient, callerIdentityOutput *sts.GetCallerIdentityOutput) {
client.EXPECT().
GetCallerIdentity(gomock.Any(), region, accessKeyID, secretAccessKey).
GetCallerIdentity(gomock.Any()).
Return(callerIdentityOutput, nil).
Times(1)
}

func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) {
func setAWSAccountIDConfirmPrompt(prompter *prompt.MockInterface, accountName string, choice string) {
prompter.EXPECT().
GetListInput([]string{prompt.ConfirmYes, prompt.ConfirmNo}, fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountName)).
Return(choice, nil).
Times(1)
}

func setAWSListRegions(client *aws.MockClient, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) {
client.EXPECT().
ListRegions(gomock.Any(), region, accessKeyID, secretAccessKey).
ListRegions(gomock.Any()).
Return(ec2DescribeRegionsOutput, nil).
Times(1)
}
Expand All @@ -1055,8 +1062,9 @@ func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string
func setAWSCloudProvider(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) {
setAWSAccessKeyIDPrompt(prompter, provider.AccessKeyID)
setAWSSecretAccessKeyPrompt(prompter, provider.SecretAccessKey)
setAWSCallerIdentity(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &sts.GetCallerIdentityOutput{Account: &provider.AccountID})
setAWSListRegions(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()})
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: &provider.AccountID})
setAWSAccountIDConfirmPrompt(prompter, provider.AccountID, prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()})
setAWSRegionPrompt(prompter, getMockAWSRegionsString(), provider.Region)
}

Expand Down
Loading