Skip to content

Commit

Permalink
Incorporate host in request after authenticate() call (#282)
Browse files Browse the repository at this point in the history
## Changes

Notebook native auth was not working because the API client accessed the
host before the credential provider had a chance to set it. This change
postpones reading the host until after the `config.authenticate()` call.

Fixes #230.

## Tests

* The new API client test adds coverage for this case.
* I manually confirmed that notebook native auth now works out of the
box on DBR.
  • Loading branch information
pietern committed May 13, 2024
1 parent 77a625c commit 989ab77
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,10 @@ private <T> T execute(Request in, Class<T> target) throws IOException {
}

private Response getResponse(Request in) {
in.withUrl(config.getHost() + in.getUrl());
return executeInner(in);
return executeInner(in, in.getUrl());
}

private Response executeInner(Request in) {
private Response executeInner(Request in, String path) {
RetryStrategy retryStrategy = retryStrategyPicker.getRetryStrategy(in);
int attemptNumber = 0;
while (true) {
Expand All @@ -247,6 +246,10 @@ private Response executeInner(Request in) {
// Authenticate the request. Failures should not be retried.
in.withHeaders(config.authenticate());

// Prepend host to URL only after config.authenticate().
// This call may configure the host (e.g. in case of notebook native auth).
in.withUrl(config.getHost() + path);

// Set User-Agent with auth type info, which is available only
// after the first invocation to config.authenticate()
String userAgent = String.format("%s auth/%s", UserAgent.asString(), config.getAuthType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,24 @@ public boolean equals(Object o) {
}
}

private ApiClient getApiClient(Request request, List<ResponseProvider> responses) {
private ApiClient getApiClient(
DatabricksConfig config, Request request, List<ResponseProvider> responses) {
DummyHttpClient hc = new DummyHttpClient();
for (ResponseProvider response : responses) {
hc.with(request, response);
}
return new ApiClient(config.setHttpClient(hc), new FakeTimer());
}

private ApiClient getApiClient(Request request, List<ResponseProvider> responses) {
String host = request.getUri().getScheme() + "://" + request.getUri().getHost();
DatabricksConfig config =
new DatabricksConfig()
.setHttpClient(hc)
.setHost(host)
.setCredentialsProvider(new DummyCredentialsProvider());
return new ApiClient(config, new FakeTimer());
new DatabricksConfig().setHost(host).setCredentialsProvider(new DummyCredentialsProvider());
return getApiClient(config, request, responses);
}

private <T> void runApiClientTest(
Request request,
List<ResponseProvider> responses,
Class<? extends T> clazz,
T expectedResponse) {
ApiClient client = getApiClient(request, responses);
ApiClient client, Request request, Class<? extends T> clazz, T expectedResponse) {
T response;
if (request.getMethod().equals(Request.GET)) {
response = client.GET(request.getUri().getPath(), clazz, Collections.emptyMap());
Expand All @@ -73,6 +71,15 @@ private <T> void runApiClientTest(
assertEquals(response, expectedResponse);
}

private <T> void runApiClientTest(
Request request,
List<ResponseProvider> responses,
Class<? extends T> clazz,
T expectedResponse) {
ApiClient client = getApiClient(request, responses);
runApiClientTest(client, request, clazz, expectedResponse);
}

private void runFailingApiClientTest(
Request request, List<ResponseProvider> responses, Class<?> clazz, String expectedMessage) {
DatabricksException exception =
Expand Down Expand Up @@ -347,6 +354,39 @@ void retryUnknownHostException() {
new MyEndpointResponse().setKey("value"));
}

class HostPopulatingCredentialsProvider implements CredentialsProvider {
private final String host;
private final CredentialsProvider parent;

public HostPopulatingCredentialsProvider(String host) {
this.host = host;
this.parent = new DummyCredentialsProvider();
}

@Override
public String authType() {
return parent.authType();
}

@Override
public HeaderFactory configure(DatabricksConfig config) {
config.setHost(this.host);
return parent.configure(config);
}
}

@Test
void populateHostFromCredentialProvider() {
Request req = getBasicRequest();
DatabricksConfig config =
new DatabricksConfig()
.setCredentialsProvider(new HostPopulatingCredentialsProvider("http://my.host"));
ApiClient client =
getApiClient(config, req, Collections.singletonList(getSuccessResponse(req)));
runApiClientTest(
client, req, MyEndpointResponse.class, new MyEndpointResponse().setKey("value"));
}

@Test
void testGetBackoffFromRetryAfterHeader() {
Request req = getBasicRequest();
Expand Down

0 comments on commit 989ab77

Please sign in to comment.