Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add authentication support for ACR #19926

Merged
merged 1 commit into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ the main ServiceBusClientBuilder. -->
<suppress checks="[a-zA-Z0-9]*" files="[/\\](generated-test-sources|generatedtestsources)[/\\]"/>

<!-- Allows the HttpPipelinePolicy derived class in Implementation folder -->
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.containers.containerregistry.implementation.authentication.ContainerRegistryCredentialsPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.messaging.servicebus.implementation.ServiceBusTokenCredentialHttpPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.messaging.eventgrid.implementation.CloudEventTracingPipelinePolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.storage.common.implementation.policy.SasTokenCredentialPolicy.java"/>
Expand All @@ -287,6 +288,7 @@ the main ServiceBusClientBuilder. -->
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.security.keyvault.keys.implementation.KeyVaultCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.security.keyvault.secrets.implementation.KeyVaultCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.storage.blob.implementation.util.BlobUserAgentModificationPolicy.java"/>


<!-- Fields tenantId, clientId and clientSecret are not set in all constructors. -->
<suppress checks="com.azure.tools.checkstyle.checks.EnforceFinalFieldsCheck" files="com.azure.security.keyvault.jca.KeyVaultClient"/>
Expand Down
20 changes: 19 additions & 1 deletion sdk/containerregistry/azure-containers-containerregistry/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
<jacoco.min.branchcoverage>0.10</jacoco.min.branchcoverage>

<!-- If skipping test coverage is absolutely necessary, for example resources cannot be provisioned uncomment this -->
<!--<jacoco.skip.coverage.check>false</jacoco.skip.coverage.check>-->
<jacoco.skip.coverage.check>false</jacoco.skip.coverage.check>
</properties>

<dependencies>
Expand All @@ -61,5 +61,23 @@
<version>5.7.1</version> <!-- {x-version-update;org.junit.jupiter:junit-jupiter-params;external_dependency} -->
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.6.28</version> <!-- {x-version-update;org.mockito:mockito-core;external_dependency} -->
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-test</artifactId>
<version>3.4.3</version> <!-- {x-version-update;io.projectreactor:reactor-test;external_dependency} -->
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-core-test</artifactId>
<version>1.6.0</version> <!-- {x-version-update;com.azure:azure-core-test;dependency} -->
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.containers.containerregistry.implementation.authentication;

import com.azure.core.credential.AccessToken;
import com.azure.core.util.logging.ClientLogger;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.core.publisher.Sinks;

import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

/**
* A token cache that supports caching a token and refreshing it.
*/
public class AccessTokenCacheImpl {
// The delay after a refresh to attempt another token refresh
private static final Duration REFRESH_DELAY = Duration.ofSeconds(30);
// the offset before token expiry to attempt proactive token refresh
private static final Duration REFRESH_OFFSET = Duration.ofMinutes(5);
private volatile AccessToken cache;
private volatile OffsetDateTime nextTokenRefresh = OffsetDateTime.now();
private final AtomicReference<Sinks.One<AccessToken>> wip;
private final ContainerRegistryTokenCredential tokenCredential;
private ContainerRegistryTokenRequestContext tokenRequestContext;
private final Predicate<AccessToken> shouldRefresh;
private final ClientLogger logger = new ClientLogger(AccessTokenCacheImpl.class);

/**
* Creates an instance of AccessTokenCacheImpl with default scheme "Bearer".
*
* @param tokenCredential the credential to be used to acquire token from.
*/
public AccessTokenCacheImpl(ContainerRegistryTokenCredential tokenCredential) {
Objects.requireNonNull(tokenCredential, "The token credential cannot be null");
this.wip = new AtomicReference<>();
this.tokenCredential = tokenCredential;
this.shouldRefresh = accessToken -> OffsetDateTime.now()
.isAfter(accessToken.getExpiresAt().minus(REFRESH_OFFSET));
}

/**
* Asynchronously get a token from either the cache or replenish the cache with a new token.
*
* @param tokenRequestContext The request context for token acquisition.
* @return The Publisher that emits an AccessToken
*/
public Mono<AccessToken> getToken(ContainerRegistryTokenRequestContext tokenRequestContext) {
return Mono.defer(retrieveToken(tokenRequestContext))
// Keep resubscribing as long as Mono.defer [token acquisition] emits empty().
.repeatWhenEmpty((Flux<Long> longFlux) -> longFlux.concatMap(ignored -> Flux.just(true)));
}

private Supplier<Mono<? extends AccessToken>> retrieveToken(ContainerRegistryTokenRequestContext tokenRequestContext) {
return () -> {
try {
Copy link
Member

@srnagar srnagar Mar 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole supplier should be wrapped with Mono.defer() to evaluate the need to refresh at subscription-time instead of at assembly-time. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caller actually does that - retrieveToken call just above this method. Since this is a private method chose not to do it.


In reply to: 597950212 [](ancestors = 597950212)

if (wip.compareAndSet(null, Sinks.one())) {
final Sinks.One<AccessToken> sinksOne = wip.get();
OffsetDateTime now = OffsetDateTime.now();
Mono<AccessToken> tokenRefresh;
Mono<AccessToken> fallback;

Supplier<Mono<AccessToken>> tokenSupplier = () ->
tokenCredential.getToken(this.tokenRequestContext);

boolean forceRefresh = checkIfWeShouldForceRefresh(tokenRequestContext);

if (forceRefresh) {
this.tokenRequestContext = tokenRequestContext;
tokenRefresh = Mono.defer(() -> tokenCredential.getToken(this.tokenRequestContext));
fallback = Mono.empty();
} else if (cache != null && !shouldRefresh.test(cache)) {
// fresh cache & no need to refresh
tokenRefresh = Mono.empty();
fallback = Mono.just(cache);
} else if (cache == null || cache.isExpired()) {
// no token to use
if (now.isAfter(nextTokenRefresh)) {
// refresh immediately
tokenRefresh = Mono.defer(tokenSupplier);
} else {
// wait for timeout, then refresh
tokenRefresh = Mono.defer(tokenSupplier)
.delaySubscription(Duration.between(now, nextTokenRefresh));
}
// cache doesn't exist or expired, no fallback
fallback = Mono.empty();
} else {
// token available, but close to expiry
if (now.isAfter(nextTokenRefresh)) {
// refresh immediately
tokenRefresh = Mono.defer(tokenSupplier);
} else {
// still in timeout, do not refresh
tokenRefresh = Mono.empty();
}
// cache hasn't expired, ignore refresh error this time
fallback = Mono.just(cache);
}
return tokenRefresh
.materialize()
.flatMap(processTokenRefreshResult(sinksOne, now, fallback))
.doOnError(sinksOne::tryEmitError)
.doFinally(ignored -> wip.set(null));
} else {
return Mono.empty();
}
} catch (Throwable t) {
return Mono.error(t);
}
};
}

private boolean checkIfWeShouldForceRefresh(ContainerRegistryTokenRequestContext tokenRequestContext) {
return !(this.tokenRequestContext != null
&& (this.tokenRequestContext.getScope() == null ? tokenRequestContext.getScope() == null
: this.tokenRequestContext.getScope().equals(tokenRequestContext.getScope()))
&& (this.tokenRequestContext.getServiceName() == null ? tokenRequestContext.getServiceName() == null
: this.tokenRequestContext.getServiceName().equals(tokenRequestContext.getServiceName())));
}

private Function<Signal<AccessToken>, Mono<? extends AccessToken>> processTokenRefreshResult(
Sinks.One<AccessToken> sinksOne, OffsetDateTime now, Mono<AccessToken> fallback) {
return signal -> {
AccessToken accessToken = signal.get();
Throwable error = signal.getThrowable();
if (signal.isOnNext() && accessToken != null) { // SUCCESS
logger.info(refreshLog(cache, now, "Acquired a new access token"));
cache = accessToken;
sinksOne.tryEmitValue(accessToken);
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
return Mono.just(accessToken);
} else if (signal.isOnError() && error != null) { // ERROR
logger.error(refreshLog(cache, now, "Failed to acquire a new access token"));
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
return fallback.switchIfEmpty(Mono.error(error));
} else { // NO REFRESH
sinksOne.tryEmitEmpty();
return fallback;
}
};
}

private static String refreshLog(AccessToken cache, OffsetDateTime now, String log) {
StringBuilder info = new StringBuilder(log);
if (cache == null) {
info.append(".");
} else {
Duration tte = Duration.between(now, cache.getExpiresAt());
info.append(" at ").append(tte.abs().getSeconds()).append(" seconds ")
.append(tte.isNegative() ? "after" : "before").append(" expiry. ")
.append("Retry may be attempted after ").append(REFRESH_DELAY.getSeconds()).append(" seconds.");
if (!tte.isNegative()) {
info.append(" The token currently cached will be used.");
}
}
return info.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.containers.containerregistry.implementation.authentication;

import com.azure.containers.containerregistry.implementation.models.AcrAccessToken;
import com.azure.containers.containerregistry.implementation.models.AcrErrorsException;
import com.azure.core.annotation.ExpectedResponses;
import com.azure.core.annotation.Host;
import com.azure.core.annotation.Post;
import com.azure.core.annotation.ServiceInterface;
import com.azure.core.annotation.FormParam;
import com.azure.core.annotation.HeaderParam;
import com.azure.core.annotation.HostParam;
import com.azure.core.annotation.UnexpectedResponseExceptionType;
import com.azure.core.annotation.ReturnType;
import com.azure.core.annotation.ServiceMethod;
import com.azure.core.http.HttpPipeline;
import com.azure.core.http.rest.Response;
import com.azure.core.http.rest.RestProxy;
import com.azure.core.util.Context;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.serializer.SerializerAdapter;
import reactor.core.publisher.Mono;

/** An instance of this class provides access to all the operations defined in AccessTokensService. */
public final class AccessTokensImpl {
Copy link
Contributor Author

@pallavit pallavit Mar 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 0, length = 6)

Why do we generate these classes as public? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to hand author the generated files due to the discrepancy between www-encoded form not correctly supported. Also, I could not use the builder pattern since this class is a hierarchial pattern and has only one builder. And it needed the policy that had to be added to the builder to consume this class. Unfortunately ACR's STS goes through the same rest endpoints. Let me know if there is a better way to do it.


In reply to: 596415910 [](ancestors = 596415910)

Copy link
Member

@alzimmermsft alzimmermsft Mar 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These classes are generated as public since Java classes cannot be accessible across package boundaries without either being protected or public. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that they were in a different package.


In reply to: 597119525 [](ancestors = 597119525)

Copy link
Member

@srnagar srnagar Mar 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pallavit could you please file a bug in autorest.java repo with steps to reproduce and the expected code that should be generated? Ideally, we should not modify any generated code manually as it would be overwritten when we regenerate this code. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There already is a bug and @jianghaolu is aware of it.


In reply to: 597954987 [](ancestors = 597954987)


/** The proxy service used to perform REST calls. */
private final AccessTokensServiceImpl service;

/** Registry login URL. */
private final String url;

/**
* Gets Registry login URL.
*
* @return the url value.
*/
public String getUrl() {
return this.url;
}

/**
* Initializes an instance of AccessTokensImpl.
*
* @param url the service endpoint.
* @param httpPipeline the pipeline to use to make the call.
* @param serializerAdapter the serializer adapter for the rest client.
*
*/
public AccessTokensImpl(String url, HttpPipeline httpPipeline, SerializerAdapter serializerAdapter) {
this.service =
RestProxy.create(AccessTokensServiceImpl.class, httpPipeline, serializerAdapter);
this.url = url;
}

/**
* Exchange ACR Refresh token for an ACR Access Token.
*
* @param refreshToken The refreshToken parameter.
* @throws IllegalArgumentException thrown if parameters fail the validation.
* @throws AcrErrorsException thrown if the request is rejected by server.
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return the response.
*/
@ServiceMethod(returns = ReturnType.SINGLE)
public Mono<Response<AcrAccessToken>> getAccessTokenWithResponseAsync(
String grantType,
String serviceName,
String scope,
String refreshToken) {
final String accept = "application/json";
return FluxUtil.withContext(
context -> service.getAccessToken(getUrl(), grantType, serviceName, scope, refreshToken, accept, context));
}
/**
* Exchange ACR Refresh token for an ACR Access Token.
*
* @param refreshToken The refreshToken parameter.
* @throws IllegalArgumentException thrown if parameters fail the validation.
* @throws AcrErrorsException thrown if the request is rejected by server.
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return the response.
*/
@ServiceMethod(returns = ReturnType.SINGLE)
public Mono<AcrAccessToken> getAccessTokenAsync(
String grantType,
String serviceName,
String scope,
String refreshToken) {
return getAccessTokenWithResponseAsync(grantType, serviceName, scope, refreshToken)
.flatMap(
(Response<AcrAccessToken> res) -> {
if (res.getValue() != null) {
return Mono.just(res.getValue());
} else {
return Mono.empty();
}
});
}

/**
* The interface defining all the services for AccessTokens to be used by the proxy service to
* perform REST calls.
*/
@Host("{url}")
@ServiceInterface(name = "ContainerRegistryAcc")
interface AccessTokensServiceImpl {
@Post("/oauth2/token")
@ExpectedResponses({200})
@UnexpectedResponseExceptionType(AcrErrorsException.class)
Mono<Response<AcrAccessToken>> getAccessToken(
@HostParam("url") String url,
@FormParam(value = "grant_type") String grantType,
@FormParam(value = "service") String service,
@FormParam(value = "scope") String scope,
@FormParam(value = "refresh_token") String refreshToken,
@HeaderParam("Accept") String accept,
Context context);
}
}
Loading