Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ACR challenge-based authentication and remove basic auth policy #19696

Merged
merged 10 commits into from
Mar 24, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,21 @@ public ContentProperties() { }
public partial class RepositoryProperties
{
internal RepositoryProperties() { }
public System.DateTimeOffset? CreatedOn { get { throw null; } }
public System.DateTimeOffset CreatedOn { get { throw null; } }
public System.DateTimeOffset? LastUpdatedOn { get { throw null; } }
public string Name { get { throw null; } }
public string Registry { get { throw null; } }
public int? RegistryArtifactCount { get { throw null; } }
public int? TagCount { get { throw null; } }
public int RegistryArtifactCount { get { throw null; } }
public int TagCount { get { throw null; } }
public Azure.Containers.ContainerRegistry.ContentProperties WriteableProperties { get { throw null; } }
}
public partial class TagProperties
{
internal TagProperties() { }
public System.DateTimeOffset? CreatedOn { get { throw null; } }
public System.DateTimeOffset CreatedOn { get { throw null; } }
public string Digest { get { throw null; } }
public System.DateTimeOffset? LastUpdatedOn { get { throw null; } }
public Azure.Containers.ContainerRegistry.ContentProperties ModifiableProperties { get { throw null; } }
public System.DateTimeOffset LastUpdatedOn { get { throw null; } }
public string Name { get { throw null; } }
public string Registry { get { throw null; } }
public string Repository { get { throw null; } }
public Azure.Containers.ContainerRegistry.ContentProperties WriteableProperties { get { throw null; } }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace Azure.Containers.ContainerRegistry
{
internal partial class AuthenticationRestClient : IContainerRegistryAuthenticationClient
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
Expand Down Expand Up @@ -31,19 +30,17 @@ namespace Azure.Containers.ContainerRegistry
/// Step 5: GET /api/v1/acr/repositories
/// Request Header: { Bearer acrTokenAccess }
/// </summary>
internal class ContainerRegistryCredentialsPolicy : BearerTokenChallengeAuthenticationPolicy
internal class ContainerRegistryChallengeAuthenticationPolicy : BearerTokenChallengeAuthenticationPolicy
{
private readonly RefreshTokensRestClient _exchangeRestClient;
private readonly AccessTokensRestClient _tokenRestClient;
private readonly IContainerRegistryAuthenticationClient _authenticationClient;

public ContainerRegistryCredentialsPolicy(TokenCredential credential, string aadScope, RefreshTokensRestClient exchangeRestClient, AccessTokensRestClient tokenRestClient)
public ContainerRegistryChallengeAuthenticationPolicy(TokenCredential credential, string aadScope, IContainerRegistryAuthenticationClient authenticationClient)
: base(credential, aadScope)
{
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNull(aadScope, nameof(aadScope));

_exchangeRestClient = exchangeRestClient;
_tokenRestClient = tokenRestClient;
_authenticationClient = authenticationClient;
}

protected override async ValueTask<bool> AuthenticateRequestOnChallengeAsync(HttpMessage message, bool async)
Expand All @@ -61,15 +58,15 @@ protected override async ValueTask<bool> AuthenticateRequestOnChallengeAsync(Htt
string acrRefreshToken = await ExchangeAadAccessTokenForAcrRefreshTokenAsync(message, service, true).ConfigureAwait(false);

// Step 4: Send in acrRefreshToken and get back acrAccessToken
acrAccessToken = await ExchangeAcrRefreshTokenForAcrAccessTokenAsync(acrRefreshToken, service, scope, true).ConfigureAwait(false);
acrAccessToken = await ExchangeAcrRefreshTokenForAcrAccessTokenAsync(message, acrRefreshToken, service, scope, true).ConfigureAwait(false);
}
else
{
// Step 3: Exchange AAD Access Token for ACR Refresh Token
string acrRefreshToken = ExchangeAadAccessTokenForAcrRefreshTokenAsync(message, service, false).EnsureCompleted();

// Step 4: Send in acrRefreshToken and get back acrAccessToken
acrAccessToken = ExchangeAcrRefreshTokenForAcrAccessTokenAsync(acrRefreshToken, service, scope, false).EnsureCompleted();
acrAccessToken = ExchangeAcrRefreshTokenForAcrAccessTokenAsync(message, acrRefreshToken, service, scope, false).EnsureCompleted();
}

// Step 5 - Authorize Request. Note, we don't use SetAuthorizationHeader here, because it
Expand All @@ -81,43 +78,37 @@ protected override async ValueTask<bool> AuthenticateRequestOnChallengeAsync(Htt

private async Task<string> ExchangeAadAccessTokenForAcrRefreshTokenAsync(HttpMessage message, string service, bool async)
{
string aadAccessToken = GetAuthorizationHeader(message);
string aadAccessToken = GetAuthorizationToken(message);

Response<RefreshToken> acrRefreshToken = null;
Response<AcrRefreshToken> acrRefreshToken = null;
if (async)
{
acrRefreshToken = await _exchangeRestClient.GetFromExchangeAsync(
PostContentSchemaGrantType.AccessToken,
service,
accessToken: aadAccessToken).ConfigureAwait(false);
acrRefreshToken = await _authenticationClient.ExchangeAadAccessTokenForAcrRefreshTokenAsync(service, aadAccessToken, message.CancellationToken).ConfigureAwait(false);
}
else
{
acrRefreshToken = _exchangeRestClient.GetFromExchange(
PostContentSchemaGrantType.AccessToken,
service,
accessToken: aadAccessToken);
acrRefreshToken = _authenticationClient.ExchangeAadAccessTokenForAcrRefreshToken(service, aadAccessToken, message.CancellationToken);
}

return acrRefreshToken.Value.RefreshTokenValue;
return acrRefreshToken.Value.RefreshToken;
}

private async Task<string> ExchangeAcrRefreshTokenForAcrAccessTokenAsync(string acrRefreshToken, string service, string scope, bool async)
private async Task<string> ExchangeAcrRefreshTokenForAcrAccessTokenAsync(HttpMessage message, string acrRefreshToken, string service, string scope, bool async)
annelo-msft marked this conversation as resolved.
Show resolved Hide resolved
{
Response<AccessToken> acrAccessToken = null;
Response<AcrAccessToken> acrAccessToken = null;
if (async)
{
acrAccessToken = await _tokenRestClient.GetAsync(service, scope, acrRefreshToken).ConfigureAwait(false);
acrAccessToken = await _authenticationClient.ExchangeAcrRefreshTokenForAcrAccessTokenAsync(service, scope, acrRefreshToken, message.CancellationToken).ConfigureAwait(false);
}
else
{
acrAccessToken = _tokenRestClient.Get(service, scope, acrRefreshToken);
acrAccessToken = _authenticationClient.ExchangeAcrRefreshTokenForAcrAccessToken(service, scope, acrRefreshToken, message.CancellationToken);
}

return acrAccessToken.Value.AccessTokenValue;
return acrAccessToken.Value.AccessToken;
}

private static string GetAuthorizationHeader(HttpMessage message)
private static string GetAuthorizationToken(HttpMessage message)
{
string aadAuthHeader;
if (!message.Request.Headers.TryGetValue(HttpHeader.Names.Authorization, out aadAuthHeader))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Threading;
using System.Threading.Tasks;

namespace Azure.Containers.ContainerRegistry
{
internal interface IContainerRegistryAuthenticationClient
{
Task<Response<AcrRefreshToken>> ExchangeAadAccessTokenForAcrRefreshTokenAsync(string service, string aadAccessToken, CancellationToken token = default);
Response<AcrRefreshToken> ExchangeAadAccessTokenForAcrRefreshToken(string service, string aadAccessToken, CancellationToken token = default);

Task<Response<AcrAccessToken>> ExchangeAcrRefreshTokenForAcrAccessTokenAsync(string service, string scope, string acrRefreshToken, CancellationToken token = default);
Response<AcrAccessToken> ExchangeAcrRefreshTokenForAcrAccessToken(string service, string scope, string acrRefreshToken, CancellationToken token = default);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ public partial class ContainerRegistryClient
private readonly ClientDiagnostics _clientDiagnostics;
private readonly ContainerRegistryRestClient _restClient;

private readonly RefreshTokensRestClient _tokenExchangeClient;
private readonly AccessTokensRestClient _acrTokenClient;
private readonly AuthenticationRestClient _acrAuthClient;
private readonly string AcrAadScope = "https://management.core.windows.net/.default";

/// <summary>
Expand All @@ -42,10 +41,9 @@ public ContainerRegistryClient(Uri endpoint, TokenCredential credential, Contain
_clientDiagnostics = new ClientDiagnostics(options);

_acrAuthPipeline = HttpPipelineBuilder.Build(options);
_tokenExchangeClient = new RefreshTokensRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);
_acrTokenClient = new AccessTokensRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);
_acrAuthClient = new AuthenticationRestClient(_clientDiagnostics, _acrAuthPipeline, endpoint.AbsoluteUri);

_pipeline = HttpPipelineBuilder.Build(options, new ContainerRegistryCredentialsPolicy(credential, AcrAadScope, _tokenExchangeClient, _acrTokenClient));
_pipeline = HttpPipelineBuilder.Build(options, new ContainerRegistryChallengeAuthenticationPolicy(credential, AcrAadScope, _acrAuthClient));
_restClient = new ContainerRegistryRestClient(_clientDiagnostics, _pipeline, _endpoint.AbsoluteUri);
}

Expand Down
Loading