Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
feat: return http 503 when IMDS healthcheck fails (#1206)
Browse files Browse the repository at this point in the history
* feat: return http 503 when IMDS healthcheck fails

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>

* chore: update token request calls in demo and identityvalidator

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>

* Review feedback

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
  • Loading branch information
aramase committed Dec 15, 2021
1 parent ba1c76d commit 10b8b0e
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 109 deletions.
23 changes: 10 additions & 13 deletions cmd/demo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,23 @@ func main() {
flag.StringVar(&identityClientID, "identity-client-id", "", "The user-assigned identity client ID")
flag.Parse()

imdsTokenEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
klog.Fatalf("failed to get IMDS token endpoint, error: %+v", err)
}

ticker := time.NewTicker(period)
defer ticker.Stop()

for ; true; <-ticker.C {
curlIMDSMetadataInstanceEndpoint()
t1 := getTokenFromIMDS(imdsTokenEndpoint)
t2 := getTokenFromIMDSWithUserAssignedID(imdsTokenEndpoint)
t1 := getTokenFromIMDS()
t2 := getTokenFromIMDSWithUserAssignedID()
if t1 == nil || t2 == nil || !strings.EqualFold(t1.AccessToken, t2.AccessToken) {
klog.Error("Tokens acquired from IMDS with and without identity client ID do not match")
}
klog.Infof("Try decoding your token %s at https://jwt.io", t1.AccessToken)
}
}

func getTokenFromIMDS(imdsTokenEndpoint string) *adal.Token {
spt, err := adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(imdsTokenEndpoint, resourceName, identityClientID)
func getTokenFromIMDS() *adal.Token {
managedIdentityOpts := &adal.ManagedIdentityOptions{ClientID: identityClientID}
spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(resourceName, managedIdentityOpts)
if err != nil {
klog.Errorf("failed to acquire a token from IMDS using user-assigned identity, error: %+v", err)
return nil
Expand All @@ -72,12 +68,13 @@ func getTokenFromIMDS(imdsTokenEndpoint string) *adal.Token {
return nil
}

klog.Infof("successfully acquired a service principal token from %s", imdsTokenEndpoint)
klog.Infof("successfully acquired a service principal token from IMDS")
return &token
}

func getTokenFromIMDSWithUserAssignedID(imdsTokenEndpoint string) *adal.Token {
spt, err := adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(imdsTokenEndpoint, resourceName, identityClientID)
func getTokenFromIMDSWithUserAssignedID() *adal.Token {
managedIdentityOpts := &adal.ManagedIdentityOptions{ClientID: identityClientID}
spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(resourceName, managedIdentityOpts)
if err != nil {
klog.Errorf("failed to acquire a token from IMDS using user-assigned identity, error: %+v", err)
return nil
Expand All @@ -97,7 +94,7 @@ func getTokenFromIMDSWithUserAssignedID(imdsTokenEndpoint string) *adal.Token {
return nil
}

klog.Infof("successfully acquired a service principal token from %s using a user-assigned identity (%s)", imdsTokenEndpoint, identityClientID)
klog.Infof("successfully acquired a service principal token from IMDS using a user-assigned identity (%s)", identityClientID)
return &token
}

Expand Down
41 changes: 14 additions & 27 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ package auth
import (
"context"
"crypto/rsa"
"fmt"
"time"

"github.com/Azure/aad-pod-identity/pkg/metrics"
"github.com/Azure/aad-pod-identity/version"

"github.com/Azure/go-autorest/autorest/adal"

"golang.org/x/crypto/pkcs12"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -38,19 +36,14 @@ func GetServicePrincipalTokenFromMSI(resource string) (_ *adal.Token, err error)
}
}()

msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, fmt.Errorf("failed to get the MSI endpoint, error: %+v", err)
}
// Set up the configuration of the service principal
spt, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource)
spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
if err != nil {
return nil, fmt.Errorf("failed to acquire a token for MSI, error: %+v", err)
return nil, err
}
// obtain a fresh token
err = spt.Refresh()
if err != nil {
return nil, fmt.Errorf("failed to refresh token, error: %+v", err)
return nil, err
}
token := spt.Token()
return &token, nil
Expand All @@ -73,22 +66,16 @@ func GetServicePrincipalTokenFromMSIWithUserAssignedID(clientID, resource string
}
}()

msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, fmt.Errorf("failed to get the MSI endpoint, error: %+v", err)
}
// The ID of the user for whom the token is requested
userAssignedID := clientID
// Set up the configuration of the service principal
spt, err := adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, userAssignedID)
managedIdentityOptions := &adal.ManagedIdentityOptions{ClientID: clientID}
spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, managedIdentityOptions)
if err != nil {
return nil, fmt.Errorf("failed to acquire a token using the MSI VM extension, error: %+v", err)
return nil, err
}

// obtain a fresh token
err = spt.Refresh()
if err != nil {
return nil, fmt.Errorf("failed to refresh token, error: %+v", err)
return nil, err
}
token := spt.Token()
return &token, nil
Expand Down Expand Up @@ -127,7 +114,7 @@ func GetServicePrincipalToken(adEndpointFromSpec, tenantID, clientID, secret, re
func newServicePrincipalToken(activeDirectoryEndpoint, tenantID, clientID, secret, resource string) ([]*adal.Token, error) {
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
if err != nil {
return nil, fmt.Errorf("failed to create OAuth config, error: %+v", err)
return nil, err
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, secret, resource)
if err != nil {
Expand All @@ -136,7 +123,7 @@ func newServicePrincipalToken(activeDirectoryEndpoint, tenantID, clientID, secre
// obtain a fresh token
err = spt.Refresh()
if err != nil {
return nil, fmt.Errorf("failed to refresh token, error: %+v", err)
return nil, err
}
token := spt.Token()
return []*adal.Token{&token}, nil
Expand All @@ -148,15 +135,15 @@ func newServicePrincipalToken(activeDirectoryEndpoint, tenantID, clientID, secre
func newMultiTenantServicePrincipalToken(activeDirectoryEndpoint, primaryTenantID, clientID, secret, resource string, auxiliaryTenantIDs []string) ([]*adal.Token, error) {
oauthConfig, err := adal.NewMultiTenantOAuthConfig(activeDirectoryEndpoint, primaryTenantID, auxiliaryTenantIDs, adal.OAuthOptions{})
if err != nil {
return nil, fmt.Errorf("failed to create MultiTenantOAuth config, error: %+v", err)
return nil, err
}
spt, err := adal.NewMultiTenantServicePrincipalToken(oauthConfig, clientID, secret, resource)
if err != nil {
return nil, err
}
err = spt.RefreshWithContext(context.TODO())
if err != nil {
return nil, fmt.Errorf("failed to refresh token, error: %+v", err)
return nil, err
}

var tokens []*adal.Token
Expand Down Expand Up @@ -195,12 +182,12 @@ func GetServicePrincipalTokenWithCertificate(adEndpointFromSpec, tenantID, clien
}
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
if err != nil {
return nil, fmt.Errorf("failed to create OAuth config, error: %+v", err)
return nil, err
}

privateKey, cert, err := pkcs12.Decode(certificate, password)
if err != nil {
return nil, fmt.Errorf("failed to decode certificate, error: %+v", err)
return nil, err
}

spt, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, cert, privateKey.(*rsa.PrivateKey), resource)
Expand All @@ -210,7 +197,7 @@ func GetServicePrincipalTokenWithCertificate(adEndpointFromSpec, tenantID, clien
// obtain a fresh token
err = spt.Refresh()
if err != nil {
return nil, fmt.Errorf("failed to refresh token, error: %+v", err)
return nil, err
}
token := spt.Token()
return &token, nil
Expand Down
15 changes: 15 additions & 0 deletions pkg/auth/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package auth

import "github.com/Azure/go-autorest/autorest/adal"

// IsTokenRefreshError returns true if the error is a TokenRefreshError.
// This method can be used to distinguish health check errors from token refresh errors.
func IsTokenRefreshError(err error) bool {
_, ok := err.(adal.TokenRefreshError)
return ok
}

// IsHealthCheckError returns true if the error is not a token refresh error.
func IsHealthCheckError(err error) bool {
return !IsTokenRefreshError(err)
}
67 changes: 67 additions & 0 deletions pkg/auth/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package auth

import (
"errors"
"testing"

"github.com/Azure/go-autorest/autorest/adal"
)

type testError struct {
adal.TokenRefreshError
}

func TestIsTokenRefreshError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "not a token refresh error",
err: errors.New("some error"),
},
{
name: "token refresh error",
err: testError{},
want: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsTokenRefreshError(tt.err)
if got != tt.want {
t.Errorf("IsTokenRefreshError() = %v, want %v", got, tt.want)
}
})
}
}

func TestIsHealthCheckError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "health check error",
err: errors.New("some error"),
want: true,
},
{
name: "not health check error",
err: testError{},
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsHealthCheckError(tt.err)
if got != tt.want {
t.Errorf("IsHealthCheckError() = %v, want %v", got, tt.want)
}
})
}
}
9 changes: 2 additions & 7 deletions pkg/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,16 @@ func (c *Client) Init() error {

var spt *adal.ServicePrincipalToken
if c.Config.UseManagedIdentityExtension {
// MSI endpoint is required for both types of MSI - system assigned and user assigned.
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return fmt.Errorf("failed to get MSI endpoint, error: %+v", err)
}
// UserAssignedIdentityID is empty, so we are going to use system assigned MSI
if c.Config.UserAssignedIdentityID == "" {
klog.Infof("MIC using system assigned identity for authentication.")
spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, azureEnv.ResourceManagerEndpoint)
spt, err = adal.NewServicePrincipalTokenFromMSI("", azureEnv.ResourceManagerEndpoint)
if err != nil {
return fmt.Errorf("failed to get token from system-assigned identity, error: %+v", err)
}
} else { // User assigned identity usage.
klog.Infof("MIC using user assigned identity: %s for authentication.", utils.RedactClientID(c.Config.UserAssignedIdentityID))
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, azureEnv.ResourceManagerEndpoint, c.Config.UserAssignedIdentityID)
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID("", azureEnv.ResourceManagerEndpoint, c.Config.UserAssignedIdentityID)
if err != nil {
return fmt.Errorf("failed to get token from user-assigned identity, error: %+v", err)
}
Expand Down
18 changes: 16 additions & 2 deletions pkg/nmi/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,14 @@ func (s *Server) hostHandler(w http.ResponseWriter, r *http.Request) (ns string)
tokens, err := s.TokenClient.GetTokens(r.Context(), tokenRequest.ClientID, tokenRequest.Resource, *podID)
if err != nil {
klog.Errorf("failed to get service principal token for pod:%s/%s, error: %+v", podns, podname, err)
http.Error(w, err.Error(), http.StatusForbidden)
httpErrorCode := http.StatusForbidden
if auth.IsHealthCheckError(err) {
// the adal library performs a health check prior to making the token request
// if the health check fails, we want to return a 503 instead of 403
// for health check failures, the error is not a token refresh error
httpErrorCode = http.StatusServiceUnavailable
}
http.Error(w, err.Error(), httpErrorCode)
return
}
nmiResp := NMIResponse{
Expand Down Expand Up @@ -390,7 +397,14 @@ func (s *Server) msiHandler(w http.ResponseWriter, r *http.Request) (ns string)
tokens, err := s.TokenClient.GetTokens(r.Context(), tokenRequest.ClientID, tokenRequest.Resource, *podID)
if err != nil {
klog.Errorf("failed to get service principal token for pod: %s/%s, error: %+v", podns, podname, err)
http.Error(w, err.Error(), http.StatusForbidden)
httpErrorCode := http.StatusForbidden
if auth.IsHealthCheckError(err) {
// the adal library performs a health check prior to making the token request
// if the health check fails, we want to return a 503 instead of 403
// for health check failures, the error is not a token refresh error
httpErrorCode = http.StatusServiceUnavailable
}
http.Error(w, err.Error(), httpErrorCode)
return
}

Expand Down
8 changes: 1 addition & 7 deletions test/image/identityvalidator/identityvalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault"
"github.com/Azure/go-autorest/autorest/adal"
"k8s.io/klog/v2"
)

Expand Down Expand Up @@ -56,7 +55,6 @@ func main() {

klog.Infof("starting identity validator pod %s/%s with pod IP %s", podnamespace, podname, podip)

imdsTokenEndpoint, _ := adal.GetMSIVMEndpoint()
kvt := &keyvaultTester{
client: keyvault.New(),
subscriptionID: subscriptionID,
Expand All @@ -66,10 +64,6 @@ func main() {
secretName: keyvaultSecretName,
secretVersion: keyvaultSecretVersion,
secretValue: keyvaultSecretValue,
imdsTokenEndpoint: imdsTokenEndpoint,
}
spt := &servicePrincipalTester{
imdsTokenEndpoint: imdsTokenEndpoint,
}

var wg sync.WaitGroup
Expand All @@ -78,7 +72,7 @@ func main() {
for _, assert := range []assertFunction{
kvt.assertWithIdentityClientID,
kvt.assertWithIdentityResourceID,
spt.assertWithSystemAssignedIdentity,
assertWithSystemAssignedIdentity,
} {
wg.Add(1)
go func(assert assertFunction) {
Expand Down
Loading

0 comments on commit 10b8b0e

Please sign in to comment.