Skip to content

Commit

Permalink
Use BytesRestResponse constructor with contentType in asRestResponse (#…
Browse files Browse the repository at this point in the history
…3717)

### Description

Modifies `SecurityResponse.asRestResponse` to use the corresponding
constructor for `BytesRestResponse` if the `SecurityResponse` contains
the content-type header. This adds a test that fails before the change
and demonstrates how using the constructor of BytesRestResponse without
content-type defaults to text/plain

---------

Signed-off-by: Craig Perkins <cwperx@amazon.com>
Signed-off-by: Stephen Crawford <steecraw@amazon.com>
Signed-off-by: Craig Perkins <craig5008@gmail.com>
Co-authored-by: Stephen Crawford <steecraw@amazon.com>
Co-authored-by: Stephen Crawford <65832608+scrawfor99@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 16, 2023
1 parent 5b75319 commit b7f47b1
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.env.Environment;
import org.opensearch.security.auth.HTTPAuthenticator;
Expand Down Expand Up @@ -284,10 +285,12 @@ public GSSCredential run() throws GSSException {
public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) {
final Map<String, String> headers = new HashMap<>();
String responseBody = "";
String contentType = null;
SecurityResponse response;
final String negotiateResponseBody = getNegotiateResponseBody();
if (negotiateResponseBody != null) {
responseBody = negotiateResponseBody;
headers.putAll(SecurityResponse.CONTENT_TYPE_APP_JSON);
contentType = XContentType.JSON.mediaType();
}

if (creds == null || creds.getNativeCredentials() == null) {
Expand All @@ -296,7 +299,12 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
headers.put("WWW-Authenticate", "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials()));
}

return Optional.of(new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody));
if (contentType != null) {
response = new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody, contentType);
} else {
response = new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody);
}
return Optional.of(response);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ private Optional<SecurityResponse> handleLowLevel(RestRequest restRequest) throw

String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody);

return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString));
return Optional.of(new SecurityResponse(HttpStatus.SC_OK, null, responseBodyString, XContentType.JSON.mediaType()));
} catch (JsonProcessingException e) {
log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e);
return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, new Exception("JSON could not be parsed")));
return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, "JSON could not be parsed"));
}
}

Expand Down
12 changes: 4 additions & 8 deletions src/main/java/org/opensearch/security/auth/BackendRegistry.java
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public boolean authenticate(final SecurityRequestChannel request) {
log.debug("Rejecting REST request because of blocked address: {}", request.getRemoteAddress().orElse(null));
}

request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, new Exception("Authentication finally failed")));
request.queueForSending(new SecurityResponse(SC_UNAUTHORIZED, "Authentication finally failed"));
return false;
}

Expand All @@ -224,7 +224,7 @@ public boolean authenticate(final SecurityRequestChannel request) {

if (!isInitialized()) {
log.error("Not yet initialized (you may need to run securityadmin)");
request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, new Exception("OpenSearch Security not initialized.")));
request.queueForSending(new SecurityResponse(SC_SERVICE_UNAVAILABLE, "OpenSearch Security not initialized."));
return false;
}

Expand Down Expand Up @@ -354,11 +354,7 @@ public boolean authenticate(final SecurityRequestChannel request) {
log.error("Cannot authenticate rest user because admin user is not permitted to login via HTTP");
auditLog.logFailedLogin(authenticatedUser.getName(), true, null, request);
request.queueForSending(
new SecurityResponse(
SC_FORBIDDEN,
null,
"Cannot authenticate user because admin user is not permitted to login via HTTP"
)
new SecurityResponse(SC_FORBIDDEN, "Cannot authenticate user because admin user is not permitted to login via HTTP")
);
return false;
}
Expand Down Expand Up @@ -429,7 +425,7 @@ public boolean authenticate(final SecurityRequestChannel request) {
notifyIpAuthFailureListeners(request, authCredentials);

request.queueForSending(
challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed"))
challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, "Authentication finally failed"))
);
return false;
}
Expand Down
58 changes: 52 additions & 6 deletions src/main/java/org/opensearch/security/filter/SecurityResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
package org.opensearch.security.filter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.http.HttpHeaders;

import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestResponse;
Expand All @@ -26,26 +30,63 @@ public class SecurityResponse {
public static final Map<String, String> CONTENT_TYPE_APP_JSON = Map.of(HttpHeaders.CONTENT_TYPE, "application/json");

private final int status;
private final Map<String, String> headers;
private Map<String, List<String>> headers;
private final String body;
private final String contentType;

public SecurityResponse(final int status, final Exception e) {
this.status = status;
this.headers = CONTENT_TYPE_APP_JSON;
populateHeaders(CONTENT_TYPE_APP_JSON);
this.body = generateFailureMessage(e);
this.contentType = XContentType.JSON.mediaType();
}

public SecurityResponse(final int status, String body) {
this.status = status;
this.body = body;
this.contentType = null;
}

public SecurityResponse(final int status, final Map<String, String> headers, final String body) {
this.status = status;
this.headers = headers;
populateHeaders(headers);
this.body = body;
this.contentType = null;
}

public SecurityResponse(final int status, final Map<String, String> headers, final String body, String contentType) {
this.status = status;
this.body = body;
this.contentType = contentType;
populateHeaders(headers);
}

private void populateHeaders(Map<String, String> headers) {
if (headers != null) {
headers.entrySet().forEach(entry -> addHeader(entry.getKey(), entry.getValue()));
}
}

/**
* Add a custom header.
*/
public void addHeader(String name, String value) {
if (headers == null) {
headers = new HashMap<>(2);
}
List<String> header = headers.get(name);
if (header == null) {
header = new ArrayList<>();
headers.put(name, header);
}
header.add(value);
}

public int getStatus() {
return status;
}

public Map<String, String> getHeaders() {
public Map<String, List<String>> getHeaders() {
return headers;
}

Expand All @@ -54,9 +95,14 @@ public String getBody() {
}

public RestResponse asRestResponse() {
final RestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody());
final RestResponse restResponse;
if (this.contentType != null) {
restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), this.contentType, getBody());
} else {
restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody());
}
if (getHeaders() != null) {
getHeaders().forEach(restResponse::addHeader);
getHeaders().entrySet().forEach(entry -> { entry.getValue().forEach(value -> restResponse.addHeader(entry.getKey(), value)); });
}
return restResponse;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ void authorizeRequest(RestHandler original, SecurityRequestChannel request, User
}
log.debug(err);

request.queueForSending(new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, null, err));
request.queueForSending(new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, err));
return;
}
}
Expand Down Expand Up @@ -288,7 +288,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t
} catch (SSLPeerUnverifiedException e) {
log.error("No ssl info", e);
auditLog.logSSLException(requestChannel, e);
requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, new Exception("No ssl info")));
requestChannel.queueForSending(new SecurityResponse(HttpStatus.SC_FORBIDDEN, e));
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;

import com.google.common.collect.Lists;
import org.apache.hc.client5.http.config.RequestConfig;
import org.apache.hc.client5.http.impl.async.HttpAsyncClientBuilder;
import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder;
Expand All @@ -52,6 +53,7 @@
import org.apache.hc.core5.ssl.SSLContexts;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest.RefreshPolicy;
Expand All @@ -62,8 +64,6 @@
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.common.xcontent.XContentType;

import com.google.common.collect.Lists;

public class HttpClient implements Closeable {

public static class HttpClientBuilder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.http.HttpStatus;

import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.filter.SecurityRequest;
import org.opensearch.security.filter.SecurityResponse;
Expand Down Expand Up @@ -113,7 +114,7 @@ public Optional<SecurityResponse> checkRequestIsAllowed(final SecurityRequest re
// if allowlisting is enabled but the request is not allowlisted, then return false, otherwise true.
if (this.enabled && !requestIsAllowlisted(request)) {
return Optional.of(
new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request))
new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, generateFailureMessage(request), XContentType.JSON.mediaType())
);
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.apache.http.HttpStatus;

import org.opensearch.common.xcontent.XContentType;
import org.opensearch.security.filter.SecurityRequest;
import org.opensearch.security.filter.SecurityResponse;

Expand Down Expand Up @@ -111,7 +112,7 @@ public Optional<SecurityResponse> checkRequestIsAllowed(final SecurityRequest re
// if whitelisting is enabled but the request is not whitelisted, then return false, otherwise true.
if (this.enabled && !requestIsWhitelisted(request)) {
return Optional.of(
new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request))
new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, generateFailureMessage(request), XContentType.JSON.mediaType())
);
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth
RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap<String, String>());
SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow();

String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate");
String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate").get(0);

Assert.assertNotNull(wwwAuthenticateHeader);

Expand Down
Loading

0 comments on commit b7f47b1

Please sign in to comment.