Skip to content

Commit

Permalink
[Event Hubs] Support SAS token in connnection string (#14912)
Browse files Browse the repository at this point in the history
* Move load balancing options changes to main branch

* Support SAS token in connnection string
  • Loading branch information
srnagar committed Sep 9, 2020
1 parent 3f42cdb commit 1ea7f17
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,31 @@ public class ConnectionStringProperties {
private static final String ENDPOINT = "Endpoint";
private static final String SHARED_ACCESS_KEY_NAME = "SharedAccessKeyName";
private static final String SHARED_ACCESS_KEY = "SharedAccessKey";
private static final String SHARED_ACCESS_SIGNATURE = "SharedAccessSignature";
private static final String SAS_VALUE_PREFIX = "sharedaccesssignature ";
private static final String ENTITY_PATH = "EntityPath";
private static final String CONNECTION_STRING_WITH_ACCESS_KEY = "Endpoint={endpoint};"
+ "SharedAccessKeyName={sharedAccessKeyName};SharedAccessKey={sharedAccessKey};EntityPath={entityPath}";
private static final String CONNECTION_STRING_WITH_SAS = "Endpoint={endpoint};SharedAccessSignature="
+ "SharedAccessSignature {sharedAccessSignature};EntityPath={entityPath}";
private static final String ERROR_MESSAGE_FORMAT = "Could not parse 'connectionString'. Expected format: "
+ "'Endpoint={endpoint};SharedAccessKeyName={sharedAccessKeyName};"
+ "SharedAccessKey={sharedAccessKey};EntityPath={entityPath}'. Actual: %s";
+ CONNECTION_STRING_WITH_ACCESS_KEY + " or " + CONNECTION_STRING_WITH_SAS + ". Actual: %s";
private static final String ERROR_MESSAGE_ENDPOINT_FORMAT = "'Endpoint' must be provided in 'connectionString'."
+ " Actual: %s";

private final URI endpoint;
private final String entityPath;
private final String sharedAccessKeyName;
private final String sharedAccessKey;
private final String sharedAccessSignature;

/**
* Creates a new instance by parsing the {@code connectionString} into its components.
*
* @param connectionString The connection string to the Event Hub instance.
*
* @throws NullPointerException if {@code connectionString} is null.
* @throws IllegalArgumentException if {@code connectionString} is an empty string or the connection string has
* an invalid format.
* an invalid format.
*/
public ConnectionStringProperties(String connectionString) {
Objects.requireNonNull(connectionString, "'connectionString' cannot be null.");
Expand All @@ -56,6 +62,7 @@ public ConnectionStringProperties(String connectionString) {
String entityPath = null;
String sharedAccessKeyName = null;
String sharedAccessKeyValue = null;
String sharedAccessSignature = null;

for (String tokenValuePair : tokenValuePairs) {
final String[] pair = tokenValuePair.split(TOKEN_VALUE_SEPARATOR, 2);
Expand Down Expand Up @@ -83,25 +90,34 @@ public ConnectionStringProperties(String connectionString) {
sharedAccessKeyValue = value;
} else if (key.equalsIgnoreCase(ENTITY_PATH)) {
entityPath = value;
} else if (key.equalsIgnoreCase(SHARED_ACCESS_SIGNATURE)
&& value.toLowerCase(Locale.ROOT).startsWith(SAS_VALUE_PREFIX)) {
sharedAccessSignature = value;
} else {
throw new IllegalArgumentException(
String.format(Locale.US, "Illegal connection string parameter name: %s", key));
}
}

if (endpoint == null || sharedAccessKeyName == null || sharedAccessKeyValue == null) {
// connection string should have an endpoint and either shared access signature or shared access key and value
boolean includesSharedKey = sharedAccessKeyName != null || sharedAccessKeyValue != null;
boolean hasSharedKeyAndValue = sharedAccessKeyName != null && sharedAccessKeyValue != null;
boolean includesSharedAccessSignature = sharedAccessSignature != null;
if (endpoint == null
|| (includesSharedKey && includesSharedAccessSignature) // includes both SAS and key or value
|| (!hasSharedKeyAndValue && !includesSharedAccessSignature)) { // invalid key, value and SAS
throw new IllegalArgumentException(String.format(Locale.US, ERROR_MESSAGE_FORMAT, connectionString));
}

this.endpoint = endpoint;
this.entityPath = entityPath;
this.sharedAccessKeyName = sharedAccessKeyName;
this.sharedAccessKey = sharedAccessKeyValue;
this.sharedAccessSignature = sharedAccessSignature;
}

/**
* Gets the endpoint to be used for connecting to the AMQP message broker.
*
* @return The endpoint address, including protocol, from the connection string.
*/
public URI getEndpoint() {
Expand All @@ -110,7 +126,6 @@ public URI getEndpoint() {

/**
* Gets the entity path to connect to in the message broker.
*
* @return The entity path to connect to in the message broker.
*/
public String getEntityPath() {
Expand All @@ -119,7 +134,6 @@ public String getEntityPath() {

/**
* Gets the name of the shared access key, either for the Event Hubs namespace or the Event Hub instance.
*
* @return The name of the shared access key.
*/
public String getSharedAccessKeyName() {
Expand All @@ -128,13 +142,21 @@ public String getSharedAccessKeyName() {

/**
* The value of the shared access key, either for the Event Hubs namespace or the Event Hub.
*
* @return The value of the shared access key.
*/
public String getSharedAccessKey() {
return sharedAccessKey;
}

/**
* The value of the shared access signature, if the connection string used to create this instance included the
* shared access signature component.
* @return The shared access signature value, if included in the connection string.
*/
public String getSharedAccessSignature() {
return sharedAccessSignature;
}

/*
* The function checks for pre existing scheme of "sb://" , "http://" or "https://". If the scheme is not provided
* in endpoint, it will set the default scheme to "sb://".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.Locale;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertThrows;

Expand All @@ -16,6 +19,11 @@ public class ConnectionStringPropertiesTest {
private static final String EVENT_HUB = "event-hub-instance";
private static final String SAS_KEY = "test-sas-key";
private static final String SAS_VALUE = "some-secret-value";
private static final String SHARED_ACCESS_SIGNATURE = "SharedAccessSignature "
+ "sr=https%3A%2F%2Fentity-name.servicebus.windows.net%2F"
+ "&sig=encodedsignature%3D"
+ "&se=100000"
+ "&skn=test-sas-key";

@Test
public void nullConnectionString() {
Expand Down Expand Up @@ -130,7 +138,54 @@ public void parseConnectionString() {
Assertions.assertEquals(EVENT_HUB, properties.getEntityPath());
}

private static String getConnectionString(String hostname, String eventHubName, String sasKeyName, String sasKeyValue) {
@ParameterizedTest
@MethodSource("getInvalidConnectionString")
public void testConnectionStringWithSas(String invalidConnectionString) {
assertThrows(IllegalArgumentException.class, () -> new ConnectionStringProperties(invalidConnectionString));
}

@ParameterizedTest
@MethodSource("getSharedAccessSignature")
public void testInvalidSharedAccessSignature(String sas) {
assertThrows(IllegalArgumentException.class, () ->
new ConnectionStringProperties(getConnectionString(HOSTNAME_URI, null, null, null, sas)));
}

private static Stream<String> getInvalidConnectionString() {
String keyNameWithSas = getConnectionString(HOSTNAME_URI, EVENT_HUB, SAS_KEY, null, SHARED_ACCESS_SIGNATURE);
String keyValueWithSas = getConnectionString(HOSTNAME_URI, EVENT_HUB, null, SAS_VALUE, SHARED_ACCESS_SIGNATURE);
String keyNameAndValueWithSas = getConnectionString(HOSTNAME_URI, EVENT_HUB, SAS_KEY, SAS_VALUE,
SHARED_ACCESS_SIGNATURE);
String nullHostName = getConnectionString(null, EVENT_HUB, SAS_KEY, SAS_VALUE, SHARED_ACCESS_SIGNATURE);
String nullHostNameValidSas = getConnectionString(null, EVENT_HUB, null, null, SHARED_ACCESS_SIGNATURE);
String nullHostNameValidKey = getConnectionString(null, EVENT_HUB, SAS_KEY, SAS_VALUE, null);
return Stream.of(keyNameWithSas, keyValueWithSas, keyNameAndValueWithSas, nullHostName, nullHostNameValidSas,
nullHostNameValidKey);
}

private static Stream<String> getSharedAccessSignature() {
String nullSas = null;
String sasInvalidPrefix = "AccessSignature " // invalid prefix
+ "sr=https%3A%2F%2Fentity-name.servicebus.windows.net%2F"
+ "&sig=encodedsignature%3D"
+ "&se=100000"
+ "&skn=test-sas-key";
String sasWithoutSpace = "SharedAccessSignature" // no space after prefix
+ "sr=https%3A%2F%2Fentity-name.servicebus.windows.net%2F"
+ "&sig=encodedsignature%3D"
+ "&se=100000"
+ "&skn=test-sas-key";

return Stream.of(nullSas, sasInvalidPrefix, sasWithoutSpace);
}

private static String getConnectionString(String hostname, String eventHubName, String sasKeyName,
String sasKeyValue) {
return getConnectionString(hostname, eventHubName, sasKeyName, sasKeyValue, null);
}

private static String getConnectionString(String hostname, String eventHubName, String sasKeyName,
String sasKeyValue, String sharedAccessSignature) {
final StringBuilder builder = new StringBuilder();
if (hostname != null) {
builder.append(String.format(Locale.US, "Endpoint=%s;", hostname));
Expand All @@ -144,6 +199,9 @@ private static String getConnectionString(String hostname, String eventHubName,
if (sasKeyValue != null) {
builder.append(String.format(Locale.US, "SharedAccessKey=%s;", sasKeyValue));
}
if (sharedAccessSignature != null) {
builder.append(String.format(Locale.US, "SharedAccessSignature=%s;", sharedAccessSignature));
}

return builder.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,22 @@ public EventHubClientBuilder() {
* connection string.
*/
public EventHubClientBuilder connectionString(String connectionString) {
final ConnectionStringProperties properties = new ConnectionStringProperties(connectionString);
final TokenCredential tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessKeyName(),
properties.getSharedAccessKey(), ClientConstants.TOKEN_VALIDITY);

ConnectionStringProperties properties = new ConnectionStringProperties(connectionString);
TokenCredential tokenCredential = getTokenCredential(properties);
return credential(properties.getEndpoint().getHost(), properties.getEntityPath(), tokenCredential);
}

private TokenCredential getTokenCredential(ConnectionStringProperties properties) {
TokenCredential tokenCredential;
if (properties.getSharedAccessSignature() == null) {
tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessKeyName(),
properties.getSharedAccessKey(), ClientConstants.TOKEN_VALIDITY);
} else {
tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessSignature());
}
return tokenCredential;
}

/**
* Sets the credential information given a connection string to the Event Hubs namespace and name to a specific
* Event Hub instance.
Expand Down Expand Up @@ -213,8 +222,7 @@ public EventHubClientBuilder connectionString(String connectionString, String ev
}

final ConnectionStringProperties properties = new ConnectionStringProperties(connectionString);
final TokenCredential tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessKeyName(),
properties.getSharedAccessKey(), ClientConstants.TOKEN_VALIDITY);
TokenCredential tokenCredential = getTokenCredential(properties);

if (!CoreUtils.isNullOrEmpty(properties.getEntityPath())
&& !eventHubName.equals(properties.getEntityPath())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.Base64;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -51,6 +53,7 @@ public class EventHubSharedKeyCredential implements TokenCredential {
private final String policyName;
private final Duration tokenValidity;
private final SecretKeySpec secretKeySpec;
private final String sharedAccessSignature;

/**
* Creates an instance that authorizes using the {@code policyName} and {@code sharedAccessKey}.
Expand Down Expand Up @@ -98,6 +101,26 @@ public EventHubSharedKeyCredential(String policyName, String sharedAccessKey, Du

final byte[] sasKeyBytes = sharedAccessKey.getBytes(UTF_8);
secretKeySpec = new SecretKeySpec(sasKeyBytes, HASH_ALGORITHM);
sharedAccessSignature = null;
}

/**
* Creates an instance using the provided Shared Access Signature (SAS) string. The credential created using this
* constructor will not be refreshed. The expiration time is set to the time defined in "se={
* tokenValidationSeconds}`. If the SAS string does not contain this or is in invalid format, then the token
* expiration will be set to {@link OffsetDateTime#MAX max duration}.
* <p><a href="https://docs.microsoft.com/rest/api/eventhub/generate-sas-token">See how to generate SAS
* programmatically.</a></p>
*
* @param sharedAccessSignature The base64 encoded shared access signature string.
* @throws NullPointerException if {@code sharedAccessSignature} is null.
*/
public EventHubSharedKeyCredential(String sharedAccessSignature) {
this.sharedAccessSignature = Objects.requireNonNull(sharedAccessSignature,
"'sharedAccessSignature' cannot be null");
this.policyName = null;
this.secretKeySpec = null;
this.tokenValidity = null;
}

/**
Expand All @@ -124,6 +147,10 @@ private AccessToken generateSharedAccessSignature(final String resource) throws
throw logger.logExceptionAsError(new IllegalArgumentException("resource cannot be empty"));
}

if (sharedAccessSignature != null) {
return new AccessToken(sharedAccessSignature, getExpirationTime(sharedAccessSignature));
}

final Mac hmac;
try {
hmac = Mac.getInstance(HASH_ALGORITHM);
Expand Down Expand Up @@ -153,4 +180,24 @@ private AccessToken generateSharedAccessSignature(final String resource) throws

return new AccessToken(token, expiresOn);
}

private OffsetDateTime getExpirationTime(String sharedAccessSignature) {
String[] parts = sharedAccessSignature.split("&");
return Arrays.stream(parts)
.map(part -> part.split("="))
.filter(pair -> pair.length == 2 && pair[0].equalsIgnoreCase("se"))
.findFirst()
.map(pair -> pair[1])
.map(expirationTimeStr -> {
try {
long epochSeconds = Long.parseLong(expirationTimeStr);
return Instant.ofEpochSecond(epochSeconds).atOffset(ZoneOffset.UTC);
} catch (NumberFormatException exception) {
logger.verbose("Invalid expiration time format in the SAS token: {}. Falling back to max "
+ "expiration time.", expirationTimeStr);
return OffsetDateTime.MAX;
}
})
.orElse(OffsetDateTime.MAX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,23 @@ public void throwsWithProxyWhenTransportTypeNotChanged() {
assertNotNull(builder.buildAsyncClient());
});
}
@Test
public void testConnectionStringWithSas() {

String connectionStringWithNoEntityPath = "Endpoint=sb://eh-name.servicebus.windows.net/;"
+ "SharedAccessSignature=SharedAccessSignature test-value";
String connectionStringWithEntityPath = "Endpoint=sb://eh-name.servicebus.windows.net/;"
+ "SharedAccessSignature=SharedAccessSignature test-value;EntityPath=eh-name";

assertNotNull(new EventHubClientBuilder()
.connectionString(connectionStringWithNoEntityPath, "eh-name"));
assertNotNull(new EventHubClientBuilder()
.connectionString(connectionStringWithEntityPath));
assertThrows(NullPointerException.class, () -> new EventHubClientBuilder()
.connectionString(connectionStringWithNoEntityPath));
assertThrows(IllegalArgumentException.class, () -> new EventHubClientBuilder()
.connectionString(connectionStringWithEntityPath, "eh-name-mismatch"));
}

@MethodSource("getProxyConfigurations")
@ParameterizedTest
Expand Down
Loading

0 comments on commit 1ea7f17

Please sign in to comment.