
Query rewrite for partition skipping index #1690

1 change: 1 addition & 0 deletions flint/docs/index.md
@@ -198,6 +198,7 @@ In the index mapping, the `_meta` and `properties` field stores meta and schema i
- `spark.datasource.flint.write.refresh_policy`: default value is false. Valid values are [NONE(false), IMMEDIATE(true), WAIT_UNTIL(wait_for)].
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.flint.optimizer.enabled`: default value is true. Enables the Flint optimizer rule that rewrites queries to use applicable Flint indexes (see the sketch below).
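
The following sketch is not part of the docs change itself; it shows one way these settings could be supplied when building a session, with a placeholder application name:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: register the Flint extensions and set the new optimizer flag
// explicitly (it already defaults to true). The app name is a placeholder.
val spark = SparkSession.builder()
  .appName("flint-example")
  .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
  .config("spark.flint.optimizer.enabled", "true")
  .getOrCreate()
```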

#### API

@@ -82,6 +82,10 @@ object FlintSparkConf {
val SCROLL_SIZE = FlintConfig("read.scroll_size")
.doc("scroll read size")
.createWithDefault("100")

val OPTIMIZER_RULE_ENABLED = FlintConfig("spark.flint.optimizer.enabled")
.doc("Enable Flint optimizer rule for query rewrite with Flint index")
.createWithDefault("true")
}

class FlintSparkConf(properties: JMap[String, String]) extends Serializable {
@@ -97,6 +101,8 @@ class FlintSparkConf(properties: JMap[String, String]) extends Serializable {
else throw new NoSuchElementException("index or path not found")
}

def isOptimizerEnabled: Boolean = OPTIMIZER_RULE_ENABLED.readFrom(reader).toBoolean

/**
* Helper class, create {@link FlintOptions}.
*/
@@ -5,9 +5,18 @@

package org.opensearch.flint.spark

import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex

import org.apache.spark.sql.SparkSessionExtensions

/**
* Flint Spark extension entrypoint.
*/
class FlintSparkExtensions extends (SparkSessionExtensions => Unit) {

override def apply(v1: SparkSessionExtensions): Unit = {}
override def apply(extensions: SparkSessionExtensions): Unit = {
extensions.injectOptimizerRule { spark =>
new FlintSparkOptimizer(spark)
}
}
}
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import scala.collection.JavaConverters._

import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.flint.config.FlintSparkConf

/**
* Flint Spark optimizer that manages all Flint-related optimizer rules.
* @param spark
* Spark session
*/
class FlintSparkOptimizer(spark: SparkSession) extends Rule[LogicalPlan] {

/** Flint Spark API */
private val flint: FlintSpark = new FlintSpark(spark)

/** Only one Flint optimizer rule for now. Cost estimation will be needed once there is more than one. */
private val rule = new ApplyFlintSparkSkippingIndex(flint)

override def apply(plan: LogicalPlan): LogicalPlan = {
if (isOptimizerEnabled) {
rule.apply(plan)
} else {
plan
}
}

private def isOptimizerEnabled: Boolean = {
val flintConf = new FlintSparkConf(spark.conf.getAll.asJava)
flintConf.isOptimizerEnabled
}
}
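
Since `FlintSparkOptimizer` rebuilds `FlintSparkConf` from `spark.conf.getAll` on every invocation, the rewrite can presumably be toggled per query at runtime; a hedged sketch (the table name is hypothetical):

```scala
// Hypothetical usage: disable the rewrite for one query, then re-enable it.
spark.conf.set("spark.flint.optimizer.enabled", "false")
spark.sql("SELECT * FROM alb_logs WHERE year = 2023").explain() // plain file scan
spark.conf.set("spark.flint.optimizer.enabled", "true")
spark.sql("SELECT * FROM alb_logs WHERE year = 2023").explain() // may use the skipping index
```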
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{And, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE

/**
* Flint Spark skipping index rule that rewrites an applicable query's filter condition and table
* scan operator to leverage the additional skipping data structure, accelerating the query by
* significantly reducing the amount of data scanned.
*
* @param flint
* Flint Spark API
*/
class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter( // TODO: abstract pattern match logic for different table support
condition: Predicate,
relation @ LogicalRelation(
baseRelation @ HadoopFsRelation(location, _, _, _, _, _),
_,
Some(table),
false)) if !location.isInstanceOf[FlintSparkSkippingFileIndex] =>

val indexName = getSkippingIndexName(table.identifier.table) // TODO: database name
val index = flint.describeIndex(indexName)
if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
val indexPred = rewriteToIndexPredicate(skippingIndex, condition)

/*
* Replace the original file index with the Flint skipping file index:
* Filter(a=b)
*   |- LogicalRelation(A)
*       |- HadoopFsRelation
*           |- FileIndex  <== replaced with FlintSparkSkippingFileIndex
*/
if (indexPred.isDefined) {
val filterByIndex = buildFilterIndexQuery(skippingIndex, indexPred.get)
val fileIndex = new FlintSparkSkippingFileIndex(location, filterByIndex)
val indexRelation = baseRelation.copy(location = fileIndex)(baseRelation.sparkSession)
filter.copy(child = relation.copy(relation = indexRelation))
} else {
filter
}
} else {
filter
}
}

private def rewriteToIndexPredicate(
index: FlintSparkSkippingIndex,
condition: Predicate): Option[Predicate] = {

// TODO: currently only handles conjunctions, i.e. the given condition consists of
// one or more expressions concatenated by AND only.
index.indexedColumns
.flatMap(col => col.rewritePredicate(condition))
.reduceOption(And(_, _))
}

private def buildFilterIndexQuery(
index: FlintSparkSkippingIndex,
rewrittenPredicate: Predicate): DataFrame = {

// Get file list based on the rewritten predicates on index data
flint.spark.read
.format(FLINT_DATASOURCE)
.load(index.name())
.filter(new Column(rewrittenPredicate))
.select(FILE_PATH_COLUMN)
}
}
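
To make the rewrite concrete, an illustrative sketch of the query shape this rule targets (the table, column, and index names are hypothetical):

```scala
// Hypothetical example: `alb_logs` is partitioned by `year` and has a partition
// skipping index (its name derived via getSkippingIndexName).
val df = spark.sql("SELECT * FROM alb_logs WHERE year = 2023")

// Before the rule:  Filter(year = 2023)
//                     +- LogicalRelation
//                          +- HadoopFsRelation(location = <original FileIndex>)
//
// After the rule:   Filter(year = 2023)
//                     +- LogicalRelation
//                          +- HadoopFsRelation(location = FlintSparkSkippingFileIndex)
//
// The new file index lists only the files whose index entries satisfy the
// predicate rewritten onto the Flint index data (year = 2023).
df.queryExecution.optimizedPlan
```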
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.types.StructType

/**
* File index that skips source files based on the files selected by the Flint skipping index.
*
* @param baseFileIndex
* original file index
* @param filterByIndex
* pushed down filtering on index data
*/
class FlintSparkSkippingFileIndex(baseFileIndex: FileIndex, filterByIndex: DataFrame)
extends FileIndex {

override def listFiles(
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {

val selectedFiles =
filterByIndex.collect
.map(_.getString(0))
.toSet

// TODO: figure out if list file call can be avoided
val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
partitions
.map(p => p.copy(files = p.files.filter(f => isFileNotSkipped(selectedFiles, f))))
.filter(p => p.files.nonEmpty)
}

override def rootPaths: Seq[Path] = baseFileIndex.rootPaths

override def inputFiles: Array[String] = baseFileIndex.inputFiles

override def refresh(): Unit = baseFileIndex.refresh()

override def sizeInBytes: Long = baseFileIndex.sizeInBytes

override def partitionSchema: StructType = baseFileIndex.partitionSchema

private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
selectedFiles.contains(f.getPath.toUri.toString)
}
}
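
One detail worth noting: the selected-file set collected from the index holds URI-style strings (the `file_path` column appears to be populated via `input_file_name()` in `FlintSparkSkippingIndex`), which is why the comparison uses `getPath.toUri.toString` rather than `getPath.toString`. A small sketch of the matching, with a made-up path:

```scala
import org.apache.hadoop.fs.{FileStatus, Path}

// Hypothetical path: the file survives skipping only if its URI string is in
// the set of file paths returned by the index query.
val selectedFiles = Set("s3a://my-bucket/alb_logs/year=2023/part-00000.parquet")
val status = new FileStatus(0L, false, 1, 0L, 0L,
  new Path("s3a://my-bucket/alb_logs/year=2023/part-00000.parquet"))
selectedFiles.contains(status.getPath.toUri.toString) // true
```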
@@ -6,22 +6,27 @@
package org.opensearch.flint.spark.skipping

import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
* Flint skipping index in Spark.
*
* @param tableName
* source table name
*/
class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkSkippingStrategy])
class FlintSparkSkippingIndex(
tableName: String,
val indexedColumns: Seq[FlintSparkSkippingStrategy])
extends FlintSparkIndex {

/** Required by json4s write function */
@@ -30,15 +35,6 @@ class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkS
/** Skipping index type */
override val kind: String = SKIPPING_INDEX_TYPE

/** Output schema of the skipping index */
private val outputSchema: Map[String, String] = {
val schema = indexedColumns
.flatMap(_.outputSchema().toList)
.toMap

schema + (FILE_PATH_COLUMN -> "keyword")
}

override def name(): String = {
getSkippingIndexName(tableName)
}
@@ -74,9 +70,20 @@ class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkS
}

private def getSchema: String = {
Serialization.write(outputSchema.map { case (colName, colType) =>
colName -> ("type" -> colType)
})
val indexFieldTypes = indexedColumns.map { indexCol =>
val columnName = indexCol.columnName
// Data type INT from catalog is not recognized by Spark DataType.fromJson()
val columnType = if (indexCol.columnType == "int") "integer" else indexCol.columnType
val sparkType = DataType.fromJson("\"" + columnType + "\"")
StructField(columnName, sparkType, nullable = false)
}

val allFieldTypes =
indexFieldTypes :+ StructField(FILE_PATH_COLUMN, StringType, nullable = false)

// Convert StructType to {"properties": ...} and keep only the properties value
val properties = FlintDataType.serialize(StructType(allFieldTypes))
compact(render(parse(properties) \ "properties"))
}
}
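
For intuition, a hedged sketch of the mapping `getSchema` might produce for a skipping index over a single `int` partition column named `year`; the exact type strings come from `FlintDataType.serialize`, so treat this output as an assumption:

```scala
// Hypothetical output only: "integer" follows the int -> integer note above and
// "keyword" matches the previously hard-coded FILE_PATH_COLUMN type; the real
// values depend on FlintDataType.serialize.
val expectedSchema =
  """{"year":{"type":"integer"},"file_path":{"type":"keyword"}}"""
```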

Expand Up @@ -5,6 +5,7 @@

package org.opensearch.flint.spark.skipping

import org.apache.spark.sql.catalyst.expressions.Predicate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

/**
@@ -18,9 +19,14 @@ trait FlintSparkSkippingStrategy {
val kind: String

/**
* Indexed column name and its Spark SQL type.
* Indexed column name.
*/
val columnName: String

/**
* Indexed column Spark SQL type.
*/
@transient
val columnType: String

/**
@@ -34,4 +40,15 @@
* aggregators that generate skipping data structure
*/
def getAggregators: Seq[AggregateFunction]

/**
* Rewrite a filtering condition on source table into a new predicate on index data based on
* current skipping strategy.
*
* @param predicate
* filtering condition on source table
* @return
* new filtering condition on index data or empty if index not applicable
*/
def rewritePredicate(predicate: Predicate): Option[Predicate]
}
@@ -8,6 +8,8 @@ package org.opensearch.flint.spark.skipping.partition
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, Predicate}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}

/**
@@ -20,18 +22,18 @@ class PartitionSkippingStrategy(
extends FlintSparkSkippingStrategy {

override def outputSchema(): Map[String, String] = {
Map(columnName -> convertToFlintType(columnType))
Map(columnName -> columnType)
}

override def getAggregators: Seq[AggregateFunction] = {
Seq(First(new Column(columnName).expr, ignoreNulls = true))
}

// TODO: move this mapping info to single place
private def convertToFlintType(colType: String): String = {
colType match {
case "string" => "keyword"
case "int" => "integer"
}
override def rewritePredicate(predicate: Predicate): Option[Predicate] = {
// The column has the same name in the index data, so rewrite to the same equality on the index column
predicate.collect {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
EqualTo(UnresolvedAttribute(columnName), value)
}.headOption
}
}
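
To illustrate `rewritePredicate` (the column name and value are hypothetical): an equality on the source table's partition column becomes the same equality with an unresolved attribute, so it binds against the index data's schema instead of the source relation:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

// Source-side predicate as it appears in the query plan (hypothetical column).
val sourcePredicate = EqualTo(AttributeReference("year", IntegerType)(), Literal(2023))

// What the strategy returns: the same equality re-anchored on the index data.
val indexPredicate = EqualTo(UnresolvedAttribute("year"), Literal(2023))
```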
@@ -5,6 +5,8 @@

package org.apache.spark

import org.opensearch.flint.spark.FlintSparkExtensions

import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf}
@@ -22,6 +24,7 @@ trait FlintSuite extends SharedSparkSession {
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
.set("spark.sql.extensions", classOf[FlintSparkExtensions].getName)
conf
}
