Skip to content

Commit

Permalink
Add pushdown optimization by painless script
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <daichen@amazon.com>
  • Loading branch information
dai-chen committed Mar 7, 2024
1 parent d3cdb0e commit 72a4f32
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 3 deletions.
112 changes: 112 additions & 0 deletions flint-spark-integration/src/main/resources/bloom_filter_query.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
int hashLong(long input, int seed) {
int low = (int) input;
int high = (int) (input >>> 32);

int k1 = mixK1(low);
int h1 = mixH1(seed, k1);

k1 = mixK1(high);
h1 = mixH1(h1, k1);

return fmix(h1, 8);
}

int mixK1(int k1) {
k1 *= 0xcc9e2d51L;
k1 = Integer.rotateLeft(k1, 15);
k1 *= 0x1b873593L;
return k1;
}

int mixH1(int h1, int k1) {
h1 ^= k1;
h1 = Integer.rotateLeft(h1, 13);
h1 = h1 * 5 + (int) 0xe6546b64L;
return h1;
}

int fmix(int h1, int length) {
h1 ^= length;
h1 ^= h1 >>> 16;
h1 *= 0x85ebca6bL;
h1 ^= h1 >>> 13;
h1 *= 0xc2b2ae35L;
h1 ^= h1 >>> 16;
return h1;
}

BytesRef bfBytes = doc[params.fieldName].value;
byte[] buf = bfBytes.bytes;
int pos = 0;
int count = buf.length;
int ch1 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
int ch2 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
int ch3 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
int ch4 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
int version = ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0));
ch1 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
ch2 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
ch3 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
ch4 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
int numHashFunctions = ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0));
ch1 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
ch2 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
ch3 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
ch4 = (pos < count) ? (buf[pos++] & (int) 0xffL) : -1;
int numWords = ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0));

long[] data = new long[numWords];
byte[] readBuffer = new byte[8];
for (int i = 0; i < numWords; i++) {

int n = 0;
while (n < 8) {
int count2;
int off = n;
int len = 8 - n;
if (pos >= count) {
count2 = -1;
} else {
int avail = count - pos;
if (len > avail) {
len = avail;
}
if (len <= 0) {
count2 = 0;
} else {
System.arraycopy(buf, pos, readBuffer, off, len);
pos += len;
count2 = len;
}
}
n += count2;
}
data[i] = (((long) readBuffer[0] << 56) +
((long) (readBuffer[1] & 255) << 48) +
((long) (readBuffer[2] & 255) << 40) +
((long) (readBuffer[3] & 255) << 32) +
((long) (readBuffer[4] & 255) << 24) +
((readBuffer[5] & 255) << 16) +
((readBuffer[6] & 255) << 8) +
((readBuffer[7] & 255) << 0));
}
long bitCount = 0;
for (long word : data) {
bitCount += Long.bitCount(word);
}

long item = params.value;
int h1 = hashLong(item, 0);
int h2 = hashLong(item, h1);

long bitSize = (long) data.length * Long.SIZE;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
if ((data[(int) (combinedHash % bitSize >>> 6)] & (1L << combinedHash % bitSize)) == 0) {
return false;
}
}
return true;
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package org.apache.spark.sql.flint.datatype

import org.json4s.{Formats, JField, JValue, NoTypeHints}
import org.json4s.JsonAST.{JNothing, JObject, JString}
import org.json4s.JsonAST.JBool.True
import org.json4s.jackson.JsonMethods
import org.json4s.native.Serialization

Expand Down Expand Up @@ -156,7 +157,11 @@ object FlintDataType {
case ArrayType(elementType, _) => serializeField(elementType, Metadata.empty)

// binary
case BinaryType => JObject("type" -> JString("binary"))
case BinaryType =>
JObject(
"type" -> JString("binary"),
"doc_values" -> True // enable doc value required by painless script filtering
)

case unknown => throw new IllegalStateException(s"unsupported data type: ${unknown.sql}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import org.apache.spark.sql.flint.datatype.FlintDataType.STRICT_DATE_OPTIONAL_TI
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import scala.io.Source

/**
* Todo. find the right package.
*/
Expand Down Expand Up @@ -112,6 +114,26 @@ case class FlintQueryCompiler(schema: StructType) {
s"""{"wildcard":{"${compile(p.children()(0))}":{"value":"*${compile(
p.children()(1),
false)}"}}}"""
case "BLOOM_FILTER_MIGHT_CONTAIN" =>
val code = Source.fromResource("bloom_filter_query.txt").getLines().mkString(" ")
s"""
|{
| "bool": {
| "filter": {
| "script": {
| "script": {
| "lang": "painless",
| "source": "$code",
| "params": {
| "fieldName": "${compile(p.children()(0))}",
| "value": ${compile(p.children()(1))}
| }
| }
| }
| }
| }
|}
|""".stripMargin
case _ => ""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class FlintDataTypeSuite extends FlintSuite with Matchers {
| "type": "text"
| },
| "binaryField": {
| "type": "binary"
| "type": "binary",
| "doc_values": true
| }
| }
|}""".stripMargin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.apache.spark.sql.flint.storage

import scala.io.Source

import org.apache.spark.FlintSuite
import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralScalarExpression}
import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -137,6 +139,34 @@ class FlintQueryCompilerSuite extends FlintSuite {
assertResult("""{"exists":{"field":"aString"}}""")(query)
}

test("compile bloom_filter_might_contain(aInt, 1) successfully") {
val query =
FlintQueryCompiler(schema()).compile(
new Predicate(
"BLOOM_FILTER_MIGHT_CONTAIN",
Array(FieldReference("aInt"), LiteralValue(1, IntegerType))))

val code = Source.fromResource("bloom_filter_query.txt").getLines().mkString(" ")
assertResult(s"""
|{
| "bool": {
| "filter": {
| "script": {
| "script": {
| "lang": "painless",
| "source": "$code",
| "params": {
| "fieldName": "aInt",
| "value": 1
| }
| }
| }
| }
| }
|}
|""".stripMargin)(query)
}

protected def schema(): StructType = {
StructType(
Seq(
Expand Down

0 comments on commit 72a4f32

Please sign in to comment.