Skip to content

Commit

Permalink
OCM-9613 | feat: Allow billing account update via the cluster edit co…
Browse files Browse the repository at this point in the history
…mmand
  • Loading branch information
cristianoveiga committed Jul 31, 2024
1 parent 4e4a45f commit 5281a3b
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 72 deletions.
61 changes: 7 additions & 54 deletions cmd/create/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
passwordValidator "github.com/openshift-online/ocm-common/pkg/idp/validations"
diskValidator "github.com/openshift-online/ocm-common/pkg/machinepool/validations"
kmsArnRegexpValidator "github.com/openshift-online/ocm-common/pkg/resource/validations"
accountsv1 "github.com/openshift-online/ocm-sdk-go/accountsmgmt/v1"
v1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -721,7 +720,7 @@ func initFlags(cmd *cobra.Command) {
&args.billingAccount,
"billing-account",
"",
"Account used for billing subscriptions purchased via the AWS marketplace",
"Account used for billing subscriptions purchased via the AWS console for ROSA",
)

flags.BoolVar(
Expand Down Expand Up @@ -1106,7 +1105,7 @@ func run(cmd *cobra.Command, _ []string) {
if !isHcpBillingTechPreview {

if billingAccount != "" && !ocm.IsValidAWSAccount(billingAccount) {
r.Reporter.Errorf("Billing account is invalid. Run the command again with a valid billing account. %s",
r.Reporter.Errorf("Billing account number is not valid. Rerun the command with a valid billing account number. %s",
listBillingAccountMessage)
os.Exit(1)
}
Expand Down Expand Up @@ -1150,20 +1149,20 @@ func run(cmd *cobra.Command, _ []string) {
billingAccount = aws.ParseOption(billingAccount)
}

err := validateBillingAccount(billingAccount)
err := ocm.ValidateBillingAccount(billingAccount)
if err != nil {
r.Reporter.Errorf("%v", err)
os.Exit(1)
}

// Get contract info
contracts, isContractEnabled := GetBillingAccountContracts(cloudAccounts, billingAccount)
contracts, isContractEnabled := ocm.GetBillingAccountContracts(cloudAccounts, billingAccount)

if billingAccount != awsCreator.AccountID {
r.Reporter.Infof(
"The selected AWS billing account is a different account than your AWS infrastructure account." +
"The selected AWS billing account is a different account than your AWS infrastructure account. " +
"The AWS billing account will be charged for subscription usage. " +
"The AWS infrastructure account will be used for managing the cluster.",
"The AWS infrastructure account contains the ROSA infrastructure.",
)
} else {
r.Reporter.Infof("Using '%s' as billing account.",
Expand All @@ -1172,7 +1171,7 @@ func run(cmd *cobra.Command, _ []string) {

if isContractEnabled && len(contracts) > 0 {
//currently, an AWS account will have only one ROSA HCP active contract at a time
contractDisplay := GenerateContractDisplay(contracts[0])
contractDisplay := ocm.GenerateContractDisplay(contracts[0])
r.Reporter.Infof(contractDisplay)
}
}
Expand Down Expand Up @@ -3341,14 +3340,6 @@ func clusterConfigFor(
return clusterConfig, nil
}

func validateBillingAccount(billingAccount string) error {
if billingAccount == "" || !ocm.IsValidAWSAccount(billingAccount) {
return fmt.Errorf("billing account is invalid. Run the command again with a valid billing account. %s",
listBillingAccountMessage)
}
return nil
}

// validateNetworkType ensure user passes a valid network type parameter at creation
func validateNetworkType(networkType string) error {
if networkType == "" {
Expand All @@ -3361,44 +3352,6 @@ func validateNetworkType(networkType string) error {
return nil
}

func GetBillingAccountContracts(cloudAccounts []*accountsv1.CloudAccount,
billingAccount string) ([]*accountsv1.Contract, bool) {
var contracts []*accountsv1.Contract
for _, account := range cloudAccounts {
if account.CloudAccountID() == billingAccount {
contracts = account.Contracts()
if ocm.HasValidContracts(account) {
return contracts, true
}
}
}
return contracts, false
}

func GenerateContractDisplay(contract *accountsv1.Contract) string {
format := "Jan 02, 2006"
dimensions := contract.Dimensions()

numberOfVCPUs, numberOfClusters := ocm.GetNumsOfVCPUsAndClusters(dimensions)

contractDisplay := fmt.Sprintf(`
+---------------------+----------------+
| Start Date |%s |
| End Date |%s |
| Number of vCPUs: |'%s' |
| Number of clusters: |'%s' |
+---------------------+----------------+
`,
contract.StartDate().Format(format),
contract.EndDate().Format(format),
strconv.Itoa(numberOfVCPUs),
strconv.Itoa(numberOfClusters),
)

return contractDisplay

}

func validateOperatorRolesAvailabilityUnderUserAwsAccount(awsClient aws.Client,
operatorIAMRoleList []ocm.OperatorIAMRole) error {
for _, role := range operatorIAMRoleList {
Expand Down
16 changes: 7 additions & 9 deletions cmd/create/cluster/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ var _ = Describe("Validate cloud accounts", func() {
Dimensions(v1.NewContractDimension().Name("control_plane").Value("4")))
cloudAccount, err := mockCloudAccount.Build()
Expect(err).NotTo(HaveOccurred())
_, isContractEnabled := GetBillingAccountContracts([]*v1.CloudAccount{cloudAccount}, "1234567")
_, isContractEnabled := ocm.GetBillingAccountContracts([]*v1.CloudAccount{cloudAccount}, "1234567")
Expect(isContractEnabled).To(Equal(true))
})

Expand All @@ -287,7 +287,7 @@ var _ = Describe("Validate cloud accounts", func() {
" | Number of clusters: |'4' | \n" +
" +---------------------+----------------+ \n"

contractDisplay := GenerateContractDisplay(mockContract)
contractDisplay := ocm.GenerateContractDisplay(mockContract)

Expect(contractDisplay).To(Equal(expected))
})
Expand Down Expand Up @@ -423,24 +423,22 @@ var _ = Describe("validateBillingAccount()", func() {

It("OK: valid billing account", func() {
validBillingAccount := "123456789012"
err := validateBillingAccount(validBillingAccount)
err := ocm.ValidateBillingAccount(validBillingAccount)
Expect(err).NotTo(HaveOccurred())
})

It("KO: fails to validate a wrong billing account", func() {
wrongBillingAccount := "123"
err := validateBillingAccount(wrongBillingAccount)
err := ocm.ValidateBillingAccount(wrongBillingAccount)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("billing account is invalid. Run the command again with a valid billing account." +
" To see the list of billing account options, you can use interactive mode by passing '-i'."))
Expect(err.Error()).To(Equal("Billing account number is not valid. Rerun the command with a valid billing account number"))
})

It("KO: fails to validate an empty billing account", func() {
wrongBillingAccount := ""
err := validateBillingAccount(wrongBillingAccount)
err := ocm.ValidateBillingAccount(wrongBillingAccount)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("billing account is invalid. Run the command again with a valid billing account." +
" To see the list of billing account options, you can use interactive mode by passing '-i'."))
Expect(err.Error()).To(Equal("Billing account number is not valid. Rerun the command with a valid billing account number"))
})

})
Expand Down
83 changes: 82 additions & 1 deletion cmd/edit/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ var args struct {
// Audit log forwarding
auditLogRoleARN string

// HCP options:
billingAccount string

// Other options
additionalAllowedPrincipals []string
}
Expand Down Expand Up @@ -160,6 +163,13 @@ func init() {
"to be added to the Hosted Control Plane's VPC Endpoint Service to enable additional "+
"VPC Endpoint connection requests to be automatically accepted.",
)

flags.StringVar(
&args.billingAccount,
"billing-account",
"",
"Account used for billing subscriptions purchased via the AWS console for ROSA",
)
}

func run(cmd *cobra.Command, _ []string) {
Expand All @@ -173,7 +183,7 @@ func run(cmd *cobra.Command, _ []string) {
changedFlags := false
for _, flag := range []string{"expiration-time", "expiration", "private",
"disable-workload-monitoring", "http-proxy", "https-proxy", "no-proxy",
"additional-trust-bundle-file", "additional-allowed-principals", "audit-log-arn"} {
"additional-trust-bundle-file", "additional-allowed-principals", "audit-log-arn", "billing-account"} {
if cmd.Flags().Changed(flag) {
changedFlags = true
}
Expand Down Expand Up @@ -651,6 +661,77 @@ func run(cmd *cobra.Command, _ []string) {
}
}

var billingAccount string
if cmd.Flags().Changed("billing-account") {
billingAccount = args.billingAccount

if billingAccount != "" && !aws.IsHostedCP(cluster) {
r.Reporter.Errorf("Billing accounts are only supported for Hosted Control Plane clusters")
os.Exit(1)
}
if billingAccount != "" && !ocm.IsValidAWSAccount(billingAccount) {
r.Reporter.Errorf("Billing account number is not valid. Rerun the command with a valid billing account number")
os.Exit(1)
}
} else {
billingAccount = cluster.AWS().BillingAccountID()
}

if interactive.Enabled() && aws.IsHostedCP(cluster) {
cloudAccounts, err := r.OCMClient.GetBillingAccounts()
if err != nil {
r.Reporter.Errorf("%s", err)
os.Exit(1)
}

billingAccounts := ocm.GenerateBillingAccountsList(cloudAccounts)
if len(billingAccounts) > 0 {
billingAccount, err = interactive.GetOption(interactive.Input{
Question: "Update billing account",
Help: cmd.Flags().Lookup("billing-account").Usage,
Default: billingAccount,
DefaultMessage: fmt.Sprintf("current = '%s'", billingAccount),
Required: true,
Options: billingAccounts,
})

if err != nil {
r.Reporter.Errorf("Expected a valid billing account: '%s'", err)
os.Exit(1)
}

billingAccount = aws.ParseOption(billingAccount)
}

err = ocm.ValidateBillingAccount(billingAccount)
if err != nil {
r.Reporter.Errorf("%v", err)
os.Exit(1)
}

// Get contract info
contracts, isContractEnabled := ocm.GetBillingAccountContracts(cloudAccounts, billingAccount)

if billingAccount != r.Creator.AccountID {
r.Reporter.Infof(
"The selected AWS billing account is a different account than your AWS infrastructure account. " +
"The AWS billing account will be charged for subscription usage. " +
"The AWS infrastructure account contains the ROSA infrastructure.",
)
}

if isContractEnabled && len(contracts) > 0 {
//currently, an AWS account will have only one ROSA HCP active contract at a time
contractDisplay := ocm.GenerateContractDisplay(contracts[0])
r.Reporter.Infof(contractDisplay)
}
}

// sets the billing account only if it has changed
if billingAccount != "" && billingAccount != cluster.AWS().BillingAccountID() {
clusterConfig.BillingAccount = billingAccount
}

r.Reporter.Debugf("Updating cluster '%s'", clusterKey)
err = r.OCMClient.UpdateCluster(cluster.ID(), r.Creator, clusterConfig)
if err != nil {
Expand Down
19 changes: 12 additions & 7 deletions pkg/interactive/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ import (
)

type Input struct {
Question string
Help string
Options []string
Default interface{}
Required bool
Validators []Validator
Question string
Help string
Options []string
Default interface{}
DefaultMessage string
Required bool
Validators []Validator
}

// Gets string input from the command line
Expand Down Expand Up @@ -179,7 +180,11 @@ func GetOption(input Input) (a string, err error) {
}
defaultMessage := ""
if dflt != "" {
defaultMessage = fmt.Sprintf("default = '%s'", dflt)
if input.DefaultMessage != "" {
defaultMessage = input.DefaultMessage
} else {
defaultMessage = fmt.Sprintf("default = '%s'", dflt)
}
}
question := input.Question
optionalMessage := ""
Expand Down
44 changes: 44 additions & 0 deletions pkg/ocm/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,47 @@ func HasValidContracts(cloudAccount *v1.CloudAccount) bool {
func IsValidAWSAccount(account string) bool {
return awsAccountRegexp.MatchString(account)
}

func ValidateBillingAccount(billingAccount string) error {
if billingAccount == "" || !IsValidAWSAccount(billingAccount) {
return fmt.Errorf("Billing account number is not valid. Rerun the command with a valid billing account number")
}
return nil
}

func GenerateContractDisplay(contract *v1.Contract) string {
format := "Jan 02, 2006"
dimensions := contract.Dimensions()

numberOfVCPUs, numberOfClusters := GetNumsOfVCPUsAndClusters(dimensions)

contractDisplay := fmt.Sprintf(`
+---------------------+----------------+
| Start Date |%s |
| End Date |%s |
| Number of vCPUs: |'%s' |
| Number of clusters: |'%s' |
+---------------------+----------------+
`,
contract.StartDate().Format(format),
contract.EndDate().Format(format),
strconv.Itoa(numberOfVCPUs),
strconv.Itoa(numberOfClusters),
)

return contractDisplay
}

func GetBillingAccountContracts(cloudAccounts []*v1.CloudAccount,
billingAccount string) ([]*v1.Contract, bool) {
var contracts []*v1.Contract
for _, account := range cloudAccounts {
if account.CloudAccountID() == billingAccount {
contracts = account.Contracts()
if HasValidContracts(account) {
return contracts, true
}
}
}
return contracts, false
}
5 changes: 4 additions & 1 deletion pkg/ocm/clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ func (c *Client) UpdateCluster(clusterKey string, creator *aws.Creator, config S
clusterBuilder.Hypershift(hyperShiftBuilder)
}

if config.AuditLogRoleARN != nil || config.AdditionalAllowedPrincipals != nil {
if config.AuditLogRoleARN != nil || config.AdditionalAllowedPrincipals != nil || config.BillingAccount != "" {
awsBuilder := cmv1.NewAWS()
if config.AdditionalAllowedPrincipals != nil {
awsBuilder = awsBuilder.AdditionalAllowedPrincipals(config.AdditionalAllowedPrincipals...)
Expand All @@ -639,6 +639,9 @@ func (c *Client) UpdateCluster(clusterKey string, creator *aws.Creator, config S
auditLogBuiler := cmv1.NewAuditLog().RoleArn(*config.AuditLogRoleARN)
awsBuilder = awsBuilder.AuditLog(auditLogBuiler)
}
if config.BillingAccount != "" {
awsBuilder.BillingAccountID(config.BillingAccount)
}
clusterBuilder.AWS(awsBuilder)
}

Expand Down

0 comments on commit 5281a3b

Please sign in to comment.