Commit

add AWS sigv4 support (#1663)
* add sigv4

Signed-off-by: Peng Huo <penghuo@gmail.com>

* remove coverage instrumentation code

Signed-off-by: Peng Huo <penghuo@gmail.com>

* address comments

Signed-off-by: Peng Huo <penghuo@gmail.com>

---------

Signed-off-by: Peng Huo <penghuo@gmail.com>
penghuo committed May 30, 2023
1 parent b97b91c commit e80cf9b
Showing 7 changed files with 289 additions and 36 deletions.
22 changes: 7 additions & 15 deletions flint/build.sbt
@@ -30,21 +30,7 @@ lazy val compileScalastyle = taskKey[Unit]("compileScalastyle")
// Run as part of test task.
lazy val testScalastyle = taskKey[Unit]("testScalastyle")

lazy val build = taskKey[Unit]("assemblyWithCoverage")

build := Def
.sequential(flintSparkIntegration / assembly, flintSparkIntegration / coverageReport)
.value

lazy val commonSettings = Seq(
// Coverage
// todo. for demo now, increase to 100.
coverageMinimumStmtTotal := 70,
// todo. for demo now, increase to 100.
coverageMinimumBranchTotal := 70,
coverageFailOnMinimum := true,
coverageEnabled := true,

// Scalastyle
scalastyleConfig := (ThisBuild / scalastyleConfig).value,
compileScalastyle := (Compile / scalastyle).toTask("").value,
@@ -64,7 +50,9 @@ lazy val flintCore = (project in file("flint-core"))
scalaVersion := scala212,
libraryDependencies ++= Seq(
"org.opensearch.client" % "opensearch-rest-client" % opensearchVersion,
"org.opensearch.client" % "opensearch-rest-high-level-client" % opensearchVersion))
"org.opensearch.client" % "opensearch-rest-high-level-client" % opensearchVersion,
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude("com.fasterxml.jackson.core", "jackson-databind") ))

lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
.dependsOn(flintCore)
@@ -74,6 +62,8 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
name := "flint-spark-integration",
scalaVersion := scala212,
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"org.scalactic" %% "scalactic" % "3.2.15",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
@@ -104,6 +94,8 @@ lazy val integtest = (project in file("integ-test"))
name := "integ-test",
scalaVersion := scala212,
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"org.scalactic" %% "scalactic" % "3.2.15",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
52 changes: 48 additions & 4 deletions flint/docs/index.md
@@ -72,7 +72,24 @@ flintClient.getIndexMetadata("alb_logs_skipping_index")
Index data read and write example:

```java
FlintClient flintClient = new FlintOpenSearchClient("localhost", 9200);

// Read example: scroll through every document in the index.
FlintReader reader = flintClient.createReader("indexName", null);
while (reader.hasNext()) {
  reader.next();
}
reader.close();

// Write example: the writer emits the OpenSearch bulk format, so each
// action line ({"create":{}}) is followed by a document source line.
FlintWriter writer = flintClient.createWriter("indexName");
writer.write("{\"create\":{}}");
writer.write("\n");
writer.write("{\"aInt\":1}");
writer.write("\n");
writer.flush();
writer.close();
```

### API
@@ -171,8 +188,16 @@ In the index mapping, the `_meta` and `properties` fields store meta and schema info

#### Configurations

- `spark.datasource.flint.location`: default is localhost.
- `spark.datasource.flint.port`: default is 9200.
- `spark.datasource.flint.scheme`: default is http. Valid values: [http, https].
- `spark.datasource.flint.auth`: default is false. Valid values: [false, sigv4].
- `spark.datasource.flint.region`: default is us-west-2. Only used when auth=sigv4.
- `spark.datasource.flint.write.id_name`: no default value.
- `spark.datasource.flint.write.batch_size`: default value is 1000.
- `spark.datasource.flint.write.refresh_policy`: default value is false. Valid values: [NONE(false), IMMEDIATE(true), WAIT_UNTIL(wait_for)].
- `spark.datasource.flint.read.scroll_size`: default value is 100.
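
The `spark.datasource.flint.*` naming suggests session-level defaults. A hedged sketch follows; it is not from this commit, and whether Flint resolves these keys from the session conf rather than only from per-read options is an assumption:

```scala
import org.apache.spark.sql.SparkSession

// Hedged sketch: setting Flint defaults once on the session. The app name
// is hypothetical; the values shown enable sigv4 against us-west-2.
val spark = SparkSession.builder()
  .appName("flint-sigv4-demo")
  .config("spark.datasource.flint.scheme", "https")
  .config("spark.datasource.flint.auth", "sigv4")
  .config("spark.datasource.flint.region", "us-west-2")
  .getOrCreate()
```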

#### API

@@ -200,6 +225,25 @@ trait FlintSparkSkippingStrategy {
}
```

#### Flint DataSource Read/Write

Here is an example of reading index data from an AWS OpenSearch domain.

```scala
val aos = Map(
"host" -> "yourdomain.us-west-2.es.amazonaws.com",
"port" -> "-1",
"scheme" -> "https",
"auth" -> "sigv4",
"region" -> "us-west-2")

val df = new SQLContext(sc).read
.format("flint")
.options(aos)
.schema("aInt int")
.load("t001")
```
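
Only the read path is shown above. Here is a hedged sketch of the corresponding write, reusing the same options map; that the data source supports `save` this way is an assumption, not confirmed by this diff:

```scala
// Hedged sketch: write a DataFrame back through the Flint data source,
// signing requests with the same sigv4 options. "t001" mirrors the read
// example; save() support is an assumption.
df.write
  .format("flint")
  .options(aos)
  .save("t001")
```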

## Benchmarks

TODO
@@ -16,7 +16,21 @@ public class FlintOptions implements Serializable {
private final Map<String, String> options;

public static final String HOST = "host";

public static final String PORT = "port";

public static final String REGION = "region";

public static final String DEFAULT_REGION = "us-west-2";

public static final String SCHEME = "scheme";

public static final String AUTH = "auth";

public static final String NONE_AUTH = "false";

public static final String SIGV4_AUTH = "sigv4";

/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
*/
@@ -50,4 +64,16 @@ public int getScrollSize() {
}

  public String getRefreshPolicy() {
    return options.getOrDefault(REFRESH_POLICY, DEFAULT_REFRESH_POLICY);
  }

public String getRegion() {
return options.getOrDefault(REGION, DEFAULT_REGION);
}

public String getScheme() {
return options.getOrDefault(SCHEME, "http");
}

public String getAuth() {
return options.getOrDefault(AUTH, NONE_AUTH);
}
}
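
As a hedged illustration of the defaulting behavior of the new getters (the `Map`-based constructor and the exact package are assumptions not visible in this diff):

```scala
import scala.collection.JavaConverters._

// Hedged sketch: keys that are unset fall back to the documented defaults.
val opts = new FlintOptions(Map("auth" -> "sigv4").asJava) // constructor assumed
opts.getAuth   // "sigv4"
opts.getRegion // "us-west-2" (DEFAULT_REGION)
opts.getScheme // "http"
```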
@@ -0,0 +1,175 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.auth;

import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import com.amazonaws.http.HttpMethodName;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.http.Header;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpException;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;

/**
* From https://github.com/opensearch-project/sql-jdbc/blob/main/src/main/java/org/opensearch/jdbc/transport/http/auth/aws/AWSRequestSigningApacheInterceptor.java
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer}
* and {@link AWSCredentialsProvider}.
*/
public class AWSRequestSigningApacheInterceptor implements HttpRequestInterceptor {
/**
* The service that we're connecting to. Technically not necessary.
* Could be used by a future Signer, though.
*/
private final String service;

/**
* The particular signer implementation.
*/
private final Signer signer;

/**
* The source of AWS credentials for signing.
*/
private final AWSCredentialsProvider awsCredentialsProvider;

/**
*
* @param service service that we're connecting to
* @param signer particular signer implementation
* @param awsCredentialsProvider source of AWS credentials for signing
*/
public AWSRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AWSCredentialsProvider awsCredentialsProvider) {
this.service = service;
this.signer = signer;
this.awsCredentialsProvider = awsCredentialsProvider;
}

/**
* {@inheritDoc}
*/
@Override
public void process(final HttpRequest request, final HttpContext context)
throws HttpException, IOException {
URIBuilder uriBuilder;
try {
uriBuilder = new URIBuilder(request.getRequestLine().getUri());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI", e);
}

// Copy Apache HttpRequest to AWS DefaultRequest
DefaultRequest<?> signableRequest = new DefaultRequest<>(service);

HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST);
if (host != null) {
signableRequest.setEndpoint(URI.create(host.toURI()));
}
final HttpMethodName httpMethod =
HttpMethodName.fromValue(request.getRequestLine().getMethod());
signableRequest.setHttpMethod(httpMethod);
try {
signableRequest.setResourcePath(uriBuilder.build().getRawPath());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI", e);
}

if (request instanceof HttpEntityEnclosingRequest) {
HttpEntityEnclosingRequest httpEntityEnclosingRequest =
(HttpEntityEnclosingRequest) request;
if (httpEntityEnclosingRequest.getEntity() != null) {
signableRequest.setContent(httpEntityEnclosingRequest.getEntity().getContent());
}
}
signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams()));
signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders()));

// Sign it
signer.sign(signableRequest, awsCredentialsProvider.getCredentials());

// Now copy everything back
request.setHeaders(mapToHeaderArray(signableRequest.getHeaders()));
if (request instanceof HttpEntityEnclosingRequest) {
HttpEntityEnclosingRequest httpEntityEnclosingRequest =
(HttpEntityEnclosingRequest) request;
if (httpEntityEnclosingRequest.getEntity() != null) {
BasicHttpEntity basicHttpEntity = new BasicHttpEntity();
basicHttpEntity.setContent(signableRequest.getContent());
httpEntityEnclosingRequest.setEntity(basicHttpEntity);
}
}
}

/**
*
* @param params list of HTTP query params as NameValuePairs
* @return a multimap of HTTP query params
*/
private static Map<String, List<String>> nvpToMapParams(final List<NameValuePair> params) {
Map<String, List<String>> parameterMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
for (NameValuePair nvp : params) {
List<String> argsList =
parameterMap.computeIfAbsent(nvp.getName(), k -> new ArrayList<>());
argsList.add(nvp.getValue());
}
return parameterMap;
}

/**
* @param headers modeled Header objects
* @return a Map of header entries
*/
private static Map<String, String> headerArrayToMap(final Header[] headers) {
Map<String, String> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
for (Header header : headers) {
if (!skipHeader(header)) {
headersMap.put(header.getName(), header.getValue());
}
}
return headersMap;
}

/**
* @param header header line to check
* @return true if the given header should be excluded when signing
*/
private static boolean skipHeader(final Header header) {
return ("content-length".equalsIgnoreCase(header.getName())
&& "0".equals(header.getValue())) // Strip Content-Length: 0
|| "host".equalsIgnoreCase(header.getName()); // Host comes from endpoint
}

/**
* @param mapHeaders Map of header entries
* @return modeled Header objects
*/
private static Header[] mapToHeaderArray(final Map<String, String> mapHeaders) {
Header[] headers = new Header[mapHeaders.size()];
int i = 0;
for (Map.Entry<String, String> headerEntry : mapHeaders.entrySet()) {
headers[i++] = new BasicHeader(headerEntry.getKey(), headerEntry.getValue());
}
return headers;
}
}
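
For context, here is a hedged sketch of how this interceptor might be wired into the OpenSearch low-level REST client. The service name "es" for OpenSearch domains and the builder wiring rely on the AWS SDK and opensearch-rest-client APIs; none of this wiring is part of the file above:

```scala
import com.amazonaws.auth.{AWS4Signer, DefaultAWSCredentialsProviderChain}
import org.apache.http.HttpHost
import org.opensearch.client.RestClient
import org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor

// Hedged sketch: sign every outgoing request to an AWS OpenSearch domain.
val signer = new AWS4Signer()
signer.setServiceName("es")       // assumed service name for OpenSearch domains
signer.setRegionName("us-west-2")

val client = RestClient
  .builder(new HttpHost("yourdomain.us-west-2.es.amazonaws.com", -1, "https"))
  .setHttpClientConfigCallback(builder =>
    builder.addInterceptorLast(
      new AWSRequestSigningApacheInterceptor(
        "es", signer, new DefaultAWSCredentialsProviderChain())))
  .build()
```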
