Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ACR challenge-based authentication and remove basic auth policy #19696

Merged
merged 10 commits into from
Mar 24, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ namespace Azure.Containers.ContainerRegistry
public partial class ContainerRegistryClient
{
protected ContainerRegistryClient() { }
public ContainerRegistryClient(System.Uri endpoint, string username, string password) { }
public ContainerRegistryClient(System.Uri endpoint, string username, string password, Azure.Containers.ContainerRegistry.ContainerRegistryClientOptions options) { }
public ContainerRegistryClient(System.Uri endpoint, Azure.Core.TokenCredential credential) { }
public ContainerRegistryClient(System.Uri endpoint, Azure.Core.TokenCredential credential, Azure.Containers.ContainerRegistry.ContainerRegistryClientOptions options) { }
public virtual Azure.Pageable<string> GetRepositories(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.AsyncPageable<string> GetRepositoriesAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
}
Expand All @@ -19,8 +19,8 @@ public enum ServiceVersion
public partial class ContainerRepositoryClient
{
protected ContainerRepositoryClient() { }
public ContainerRepositoryClient(System.Uri endpoint, string repository, string username, string password) { }
public ContainerRepositoryClient(System.Uri endpoint, string repository, string username, string password, Azure.Containers.ContainerRegistry.ContainerRegistryClientOptions options) { }
public ContainerRepositoryClient(System.Uri endpoint, string repository, Azure.Core.TokenCredential credential) { }
public ContainerRepositoryClient(System.Uri endpoint, string repository, Azure.Core.TokenCredential credential, Azure.Containers.ContainerRegistry.ContainerRegistryClientOptions options) { }
public virtual System.Uri Endpoint { get { throw null; } }
public virtual Azure.Response DeleteTag(string tag, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response> DeleteTagAsync(string tag, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;

namespace Azure.Containers.ContainerRegistry
{
/// <summary>
/// Challenge-based authentication policy for Container Registry Service.
///
/// The challenge-based authorization flow for ACR is illustrated in the following steps.
/// For example, GET /api/v1/acr/repositories translates into the following calls.
///
/// Step 1: GET /api/v1/acr/repositories
/// Return Header: 401: www-authenticate header - Bearer realm="{url}",service="{serviceName}",scope="{scope}",error="invalid_token"
///
/// Step 2: Retrieve the serviceName, scope from the WWW-Authenticate header. (Parse the string.)
///
/// Step 3: POST /api/oauth2/exchange
/// Request Body : { service, scope, grant-type, aadToken with ARM scope }
/// Response Body: { acrRefreshToken }
///
/// Step 4: POST /api/oauth2/token
/// Request Body: { acrRefreshToken, scope, grant-type }
/// Response Body: { acrAccessToken }
///
/// Step 5: GET /api/v1/acr/repositories
/// Request Header: { Bearer acrTokenAccess }
/// </summary>
internal class ContainerRegistryCredentialsPolicy : BearerTokenChallengeAuthenticationPolicy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have tests for this policy?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! see: sdk/containerregistry/Azure.Containers.ContainerRegistry/tests/Authentication/ContainerRegistryChallengeAuthenticationPolicyTest.cs

{
private readonly RefreshTokensRestClient _exchangeRestClient;
private readonly AccessTokensRestClient _tokenRestClient;

public ContainerRegistryCredentialsPolicy(TokenCredential credential, string aadScope, RefreshTokensRestClient exchangeRestClient, AccessTokensRestClient tokenRestClient)
: base(credential, aadScope)
{
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNull(aadScope, nameof(aadScope));

_exchangeRestClient = exchangeRestClient;
_tokenRestClient = tokenRestClient;
}

protected override async ValueTask<bool> AuthenticateRequestOnChallengeAsync(HttpMessage message, bool async)
{
// Once we're here, we've completed Step 1.

// Step 2: Parse challenge string to retrieve serviceName and scope, where scope is the ACR Scope
var service = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "service");
var scope = AuthorizationChallengeParser.GetChallengeParameterFromResponse(message.Response, "Bearer", "scope");

string acrAccessToken = string.Empty;
if (async)
{
// Step 3: Exchange AAD Access Token for ACR Refresh Token
string acrRefreshToken = await ExchangeAadAccessTokenForAcrRefreshTokenAsync(message, service, true).ConfigureAwait(false);

// Step 4: Send in acrRefreshToken and get back acrAccessToken
acrAccessToken = await ExchangeAcrRefreshTokenForAcrAccessTokenAsync(acrRefreshToken, service, scope, true).ConfigureAwait(false);
}
else
{
// Step 3: Exchange AAD Access Token for ACR Refresh Token
string acrRefreshToken = ExchangeAadAccessTokenForAcrRefreshTokenAsync(message, service, false).EnsureCompleted();

// Step 4: Send in acrRefreshToken and get back acrAccessToken
acrAccessToken = ExchangeAcrRefreshTokenForAcrAccessTokenAsync(acrRefreshToken, service, scope, false).EnsureCompleted();
}

// Step 5 - Authorize Request. Note, we don't use SetAuthorizationHeader here, because it
// sets an AAD access token header, and at this point we're done with AAD and using an ACR access token.
message.Request.Headers.SetValue(HttpHeader.Names.Authorization, $"Bearer {acrAccessToken}");

return true;
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
}

private async Task<string> ExchangeAadAccessTokenForAcrRefreshTokenAsync(HttpMessage message, string service, bool async)
{
string aadAccessToken = GetAuthorizationHeader(message);

Response<RefreshToken> acrRefreshToken = null;
if (async)
{
acrRefreshToken = await _exchangeRestClient.GetFromExchangeAsync(
PostContentSchemaGrantType.AccessToken,
service,
accessToken: aadAccessToken).ConfigureAwait(false);
}
else
{
acrRefreshToken = _exchangeRestClient.GetFromExchange(
PostContentSchemaGrantType.AccessToken,
service,
accessToken: aadAccessToken);
}

return acrRefreshToken.Value.RefreshTokenValue;
}

private async Task<string> ExchangeAcrRefreshTokenForAcrAccessTokenAsync(string acrRefreshToken, string service, string scope, bool async)
{
Response<AccessToken> acrAccessToken = null;
if (async)
{
acrAccessToken = await _tokenRestClient.GetAsync(service, scope, acrRefreshToken).ConfigureAwait(false);
}
else
{
acrAccessToken = _tokenRestClient.Get(service, scope, acrRefreshToken);
}

return acrAccessToken.Value.AccessTokenValue;
}

private static string GetAuthorizationHeader(HttpMessage message)
{
string aadAuthHeader;
if (!message.Request.Headers.TryGetValue(HttpHeader.Names.Authorization, out aadAuthHeader))
{
throw new InvalidOperationException("Failed to retrieve Authentication header from message request.");
}

return aadAuthHeader.Remove(0, "Bearer ".Length);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<Description></Description>
<AssemblyTitle>Microsoft Azure.Containers.ContainerRegistry client library</AssemblyTitle>
<Version>1.0.0-beta.1</Version>
<PackageTags>Azure Container Registry;$(PackageCommonTags)</PackageTags>
<TargetFrameworks>$(RequiredTargetFrameworks)</TargetFrameworks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="System.Text.Json" />
</ItemGroup>
Expand All @@ -15,7 +15,9 @@
<ItemGroup>
<Compile Include="$(AzureCoreSharedSources)Argument.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)ArrayBufferWriter.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)AuthorizationChallengeParser.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)AzureResourceProviderNamespaceAttribute.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)BearerTokenChallengeAuthenticationPolicy.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)ClientDiagnostics.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)ContentTypeUtilities.cs" LinkBase="Shared" />
<Compile Include="$(AzureCoreSharedSources)DiagnosticScope.cs" LinkBase="Shared" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,36 @@ public partial class ContainerRegistryClient
{
private readonly Uri _endpoint;
private readonly HttpPipeline _pipeline;
private readonly HttpPipeline _acrAuthPipeline;
private readonly ClientDiagnostics _clientDiagnostics;
private readonly ContainerRegistryRestClient _restClient;

private readonly RefreshTokensRestClient _tokenExchangeClient;
private readonly AccessTokensRestClient _acrTokenClient;
private readonly string AcrAadScope = "https://management.core.windows.net/.default";

/// <summary>
/// <paramref name="endpoint"/>
/// </summary>
public ContainerRegistryClient(Uri endpoint, string username, string password) : this(endpoint, username, password, new ContainerRegistryClientOptions())
public ContainerRegistryClient(Uri endpoint, TokenCredential credential) : this(endpoint, credential, new ContainerRegistryClientOptions())
{
}

/// <summary>
/// </summary>
/// <param name="endpoint"></param>
/// <param name="username"></param>
/// <param name="password"></param>
/// <param name="options"></param>
public ContainerRegistryClient(Uri endpoint, string username, string password, ContainerRegistryClientOptions options)
public ContainerRegistryClient(Uri endpoint, TokenCredential credential, ContainerRegistryClientOptions options)
{
Argument.AssertNotNull(endpoint, nameof(endpoint));
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNull(options, nameof(options));

_pipeline = HttpPipelineBuilder.Build(options, new BasicAuthenticationPolicy(username, password));

_endpoint = endpoint;
_clientDiagnostics = new ClientDiagnostics(options);

_endpoint = endpoint;
_acrAuthPipeline = HttpPipelineBuilder.Build(options);
_tokenExchangeClient = new RefreshTokensRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);
_acrTokenClient = new AccessTokensRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);

_pipeline = HttpPipelineBuilder.Build(options, new ContainerRegistryCredentialsPolicy(credential, AcrAadScope, _tokenExchangeClient, _acrTokenClient));
_restClient = new ContainerRegistryRestClient(_clientDiagnostics, _pipeline, _endpoint.AbsoluteUri);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,45 @@ namespace Azure.Containers.ContainerRegistry
public partial class ContainerRepositoryClient
{
private readonly HttpPipeline _pipeline;
private readonly HttpPipeline _acrAuthPipeline;
private readonly ClientDiagnostics _clientDiagnostics;
private readonly ContainerRegistryRepositoryRestClient _restClient;

private readonly RefreshTokensRestClient _tokenExchangeClient;
private readonly AccessTokensRestClient _acrTokenClient;
private readonly string AcrAadScope = "https://management.core.windows.net/.default";

private readonly string _repository;

/// <summary>
/// </summary>
public virtual Uri Endpoint { get; }

/// <summary>
/// <param name="endpoint"></param>
/// <param name="repository"> Name of the image (including the namespace). </param>
/// <param name="username"></param>
/// <param name="password"></param>
/// </summary>
public ContainerRepositoryClient(Uri endpoint, string repository, string username, string password) : this(endpoint, repository, username, password, new ContainerRegistryClientOptions())
public ContainerRepositoryClient(Uri endpoint, string repository, TokenCredential credential) : this(endpoint, repository, credential, new ContainerRegistryClientOptions())
{
}

/// <summary>
/// <param name="endpoint"></param>
/// <param name="repository"> Name of the image (including the namespace). </param>
/// <param name="username"></param>
/// <param name="password"></param>
/// <param name="options"></param>
/// </summary>
public ContainerRepositoryClient(Uri endpoint, string repository, string username, string password, ContainerRegistryClientOptions options)
public ContainerRepositoryClient(Uri endpoint, string repository, TokenCredential credential, ContainerRegistryClientOptions options)
{
Argument.AssertNotNull(endpoint, nameof(endpoint));
Argument.AssertNotNull(repository, nameof(repository));
Argument.AssertNotNull(username, nameof(username));
Argument.AssertNotNull(password, nameof(password));
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNull(options, nameof(options));

_pipeline = HttpPipelineBuilder.Build(options, new BasicAuthenticationPolicy(username, password));
Endpoint = endpoint;
_repository = repository;

_clientDiagnostics = new ClientDiagnostics(options);

Endpoint = endpoint;
_repository = repository;
_acrAuthPipeline = HttpPipelineBuilder.Build(options);
_tokenExchangeClient = new RefreshTokensRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);
_acrTokenClient = new AccessTokensRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);

_pipeline = HttpPipelineBuilder.Build(options, new ContainerRegistryCredentialsPolicy(credential, AcrAadScope, _tokenExchangeClient, _acrTokenClient));
_restClient = new ContainerRegistryRepositoryRestClient(_clientDiagnostics, _pipeline, Endpoint.AbsoluteUri);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@
<ProjectReference Include="$(AzureCoreTestFramework)" />
<ProjectReference Include="..\src\Azure.Containers.ContainerRegistry.csproj" />
</ItemGroup>

<ItemGroup>
<Reference Include="System.Web">
<HintPath>..\..\..\..\..\..\Program Files (x86)\Reference Assemblies\Microsoft\Framework\.NETFramework\v4.7\System.Web.dll</HintPath>
</Reference>
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ public class ContainerRegistryClientLiveTests : RecordedTestBase<ContainerRegist
{
public ContainerRegistryClientLiveTests(bool isAsync) : base(isAsync)
{
Sanitizer = new ContainerRegistryRecordedTestSanitizer();
}

private ContainerRegistryClient CreateClient()
{
return InstrumentClient(new ContainerRegistryClient(
new Uri(TestEnvironment.Endpoint),
TestEnvironment.UserName,
TestEnvironment.Password,
TestEnvironment.Credential,
InstrumentClientOptions(new ContainerRegistryClientOptions())
));
}
Expand Down
Loading