Skip to content

Commit

Permalink
Unit tests and minor code cleanup for GetTenantID. (Azure#3467)
Browse files Browse the repository at this point in the history
  • Loading branch information
tariq1890 authored and kkmsft committed Jul 20, 2018
1 parent bdebc29 commit 4805b4a
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 5 deletions.
6 changes: 2 additions & 4 deletions pkg/acsengine/tenantid.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ import (
"regexp"

"github.com/Azure/azure-sdk-for-go/arm/resources/subscriptions"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)

// GetTenantID figures out the AAD tenant ID of the subscription by making an
// unauthenticated request to the Get Subscription Details endpoint and parses
// the value from WWW-Authenticate header.
func GetTenantID(env azure.Environment, subscriptionID string) (string, error) {
func GetTenantID(resourceManagerEndpoint string, subscriptionID string) (string, error) {
const hdrKey = "WWW-Authenticate"
c := subscriptions.NewGroupClient()
c.BaseURI = env.ResourceManagerEndpoint
c := subscriptions.NewGroupClientWithBaseURI(resourceManagerEndpoint)

log.Debugf("Resolving tenantID for subscriptionID: %s", subscriptionID)

Expand Down
103 changes: 103 additions & 0 deletions pkg/acsengine/tenantid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package acsengine

import (
"net/http"
"net/http/httptest"
"testing"
)

var (
mux *http.ServeMux
server *httptest.Server
)

func setup() func() {
mux = http.NewServeMux()
server = httptest.NewServer(mux)

return func() {
server.Close()
}
}

func TestGetTenantID(t *testing.T) {

tearDown := setup()
defer tearDown()

expectedTenantID := "96fe9d1-6171-40aa-945b-4c64b63bf655"
mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `authorization_uri="https://login.windows.net/`+expectedTenantID+`"`)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorized"))
})

tenantID, err := GetTenantID(server.URL, "foobarsubscription")

if err != nil {
t.Error("Did not expect error")
}

if tenantID != expectedTenantID {
t.Errorf("expected tenant Id : %s, but got %s", expectedTenantID, tenantID)
}
}

func TestGetTenantID_UnexpectedResponse(t *testing.T) {

tearDown := setup()
defer tearDown()

mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
return
})

_, err := GetTenantID(server.URL, "foobarsubscription")

expectedMsg := "Unexpected response from Get Subscription: 400"

if err == nil || err.Error() != expectedMsg {
t.Errorf("expected error with msg : %s to be thrown", expectedMsg)
}
}

func TestGetTenantID_InvalidHeader(t *testing.T) {

tearDown := setup()
defer tearDown()

mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Header().Set("fookey", "bazvalue")
return
})

_, err := GetTenantID(server.URL, "foobarsubscription")

expectedMsg := "Header WWW-Authenticate not found in Get Subscription response"

if err == nil || err.Error() != expectedMsg {
t.Errorf("expected error with msg : %s to be thrown", expectedMsg)
}
}

func TestGetTenantID_InvalidHeaderValue(t *testing.T) {

tearDown := setup()
defer tearDown()

mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `sample_invalid_auth_uri`)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorized"))
})

_, err := GetTenantID(server.URL, "foobarsubscription")

expectedMsg := "Could not find the tenant ID in header: WWW-Authenticate \"sample_invalid_auth_uri\""

if err == nil || err.Error() != expectedMsg {
t.Errorf("expected error with msg : %s to be thrown", expectedMsg)
}
}
2 changes: 1 addition & 1 deletion pkg/armhelpers/azureclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func tryLoadCachedToken(cachePath string) (*adal.Token, error) {
}

func getOAuthConfig(env azure.Environment, subscriptionID string) (*adal.OAuthConfig, string, error) {
tenantID, err := acsengine.GetTenantID(env, subscriptionID)
tenantID, err := acsengine.GetTenantID(env.ResourceManagerEndpoint, subscriptionID)
if err != nil {
return nil, "", err
}
Expand Down

0 comments on commit 4805b4a

Please sign in to comment.