diff --git a/pkg/acsengine/tenantid.go b/pkg/acsengine/tenantid.go index 15e0e26939..aa807ab0ee 100644 --- a/pkg/acsengine/tenantid.go +++ b/pkg/acsengine/tenantid.go @@ -5,7 +5,6 @@ import ( "regexp" "github.com/Azure/azure-sdk-for-go/arm/resources/subscriptions" - "github.com/Azure/go-autorest/autorest/azure" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -13,10 +12,9 @@ import ( // GetTenantID figures out the AAD tenant ID of the subscription by making an // unauthenticated request to the Get Subscription Details endpoint and parses // the value from WWW-Authenticate header. -func GetTenantID(env azure.Environment, subscriptionID string) (string, error) { +func GetTenantID(resourceManagerEndpoint string, subscriptionID string) (string, error) { const hdrKey = "WWW-Authenticate" - c := subscriptions.NewGroupClient() - c.BaseURI = env.ResourceManagerEndpoint + c := subscriptions.NewGroupClientWithBaseURI(resourceManagerEndpoint) log.Debugf("Resolving tenantID for subscriptionID: %s", subscriptionID) diff --git a/pkg/acsengine/tenantid_test.go b/pkg/acsengine/tenantid_test.go new file mode 100644 index 0000000000..45af91d348 --- /dev/null +++ b/pkg/acsengine/tenantid_test.go @@ -0,0 +1,103 @@ +package acsengine + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var ( + mux *http.ServeMux + server *httptest.Server +) + +func setup() func() { + mux = http.NewServeMux() + server = httptest.NewServer(mux) + + return func() { + server.Close() + } +} + +func TestGetTenantID(t *testing.T) { + + tearDown := setup() + defer tearDown() + + expectedTenantID := "96fe9d1-6171-40aa-945b-4c64b63bf655" + mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `authorization_uri="https://login.windows.net/`+expectedTenantID+`"`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) + }) + + tenantID, err := GetTenantID(server.URL, "foobarsubscription") + + if err != nil { + t.Error("Did not expect error") + } + + if tenantID != expectedTenantID { + t.Errorf("expected tenant Id : %s, but got %s", expectedTenantID, tenantID) + } +} + +func TestGetTenantID_UnexpectedResponse(t *testing.T) { + + tearDown := setup() + defer tearDown() + + mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + return + }) + + _, err := GetTenantID(server.URL, "foobarsubscription") + + expectedMsg := "Unexpected response from Get Subscription: 400" + + if err == nil || err.Error() != expectedMsg { + t.Errorf("expected error with msg : %s to be thrown", expectedMsg) + } +} + +func TestGetTenantID_InvalidHeader(t *testing.T) { + + tearDown := setup() + defer tearDown() + + mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Header().Set("fookey", "bazvalue") + return + }) + + _, err := GetTenantID(server.URL, "foobarsubscription") + + expectedMsg := "Header WWW-Authenticate not found in Get Subscription response" + + if err == nil || err.Error() != expectedMsg { + t.Errorf("expected error with msg : %s to be thrown", expectedMsg) + } +} + +func TestGetTenantID_InvalidHeaderValue(t *testing.T) { + + tearDown := setup() + defer tearDown() + + mux.HandleFunc("/subscriptions/foobarsubscription", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `sample_invalid_auth_uri`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) + }) + + _, err := GetTenantID(server.URL, "foobarsubscription") + + expectedMsg := "Could not find the tenant ID in header: WWW-Authenticate \"sample_invalid_auth_uri\"" + + if err == nil || err.Error() != expectedMsg { + t.Errorf("expected error with msg : %s to be thrown", expectedMsg) + } +} diff --git a/pkg/armhelpers/azureclient.go b/pkg/armhelpers/azureclient.go index 275ea1b222..d4d88cf040 100644 --- a/pkg/armhelpers/azureclient.go +++ b/pkg/armhelpers/azureclient.go @@ -239,7 +239,7 @@ func tryLoadCachedToken(cachePath string) (*adal.Token, error) { } func getOAuthConfig(env azure.Environment, subscriptionID string) (*adal.OAuthConfig, string, error) { - tenantID, err := acsengine.GetTenantID(env, subscriptionID) + tenantID, err := acsengine.GetTenantID(env.ResourceManagerEndpoint, subscriptionID) if err != nil { return nil, "", err }