Add covering index in Flint Spark API (#22)
* Refactor index builder

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add covering index builder and empty impl

Signed-off-by: Chen Dai <daichen@amazon.com>

* Implement covering index metadata

Signed-off-by: Chen Dai <daichen@amazon.com>

* Implement covering index build

Signed-off-by: Chen Dai <daichen@amazon.com>

* Refactor IT class

Signed-off-by: Chen Dai <daichen@amazon.com>

* Rename build method

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add more IT

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add UT for new covering index class

Signed-off-by: Chen Dai <daichen@amazon.com>

* Refactor flint index name prefix and suffix

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add comment

Signed-off-by: Chen Dai <daichen@amazon.com>

* Remove parse flint index name logic

Signed-off-by: Chen Dai <daichen@amazon.com>

---------

Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen committed Sep 18, 2023
1 parent 0faa95a commit 7434e5a
Showing 10 changed files with 558 additions and 158 deletions.
@@ -12,10 +12,11 @@ import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSpark._
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.{FlintSparkSkippingIndex, FlintSparkSkippingStrategy}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
@@ -25,12 +26,10 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.SaveMode._
import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
import org.apache.spark.sql.streaming.OutputMode.Append
import org.apache.spark.sql.streaming.StreamingQuery

/**
* Flint Spark integration API entrypoint.
@@ -42,8 +41,7 @@ class FlintSpark(val spark: SparkSession) {
FlintSparkConf(
Map(
DOC_ID_COLUMN_NAME.optionKey -> ID_COLUMN,
IGNORE_DOC_ID_COLUMN.optionKey -> "true"
).asJava)
IGNORE_DOC_ID_COLUMN.optionKey -> "true").asJava)

/** Flint client for low-level index operation */
private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions())
@@ -57,8 +55,18 @@ class FlintSpark(val spark: SparkSession) {
* @return
* index builder
*/
def skippingIndex(): IndexBuilder = {
new IndexBuilder(this)
def skippingIndex(): FlintSparkSkippingIndex.Builder = {
new FlintSparkSkippingIndex.Builder(this)
}

/**
* Create covering index builder for creating covering index with fluent API.
*
* @return
* index builder
*/
def coveringIndex(): FlintSparkCoveringIndex.Builder = {
new FlintSparkCoveringIndex.Builder(this)
}

/**
@@ -199,6 +207,7 @@ class FlintSpark(val spark: SparkSession) {
*/
private def deserialize(metadata: FlintMetadata): FlintSparkIndex = {
val meta = parse(metadata.getContent) \ "_meta"
val indexName = (meta \ "name").extract[String]
val tableName = (meta \ "source").extract[String]
val indexType = (meta \ "kind").extract[String]
val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray]
@@ -222,6 +231,13 @@
}
}
new FlintSparkSkippingIndex(tableName, strategies)
case COVERING_INDEX_TYPE =>
new FlintSparkCoveringIndex(
indexName,
tableName,
indexedColumns.arr.map { obj =>
((obj \ "columnName").extract[String], (obj \ "columnType").extract[String])
}.toMap)
}
}
}
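
The deserialize method above reads the index name, source table, kind, and indexed columns back from the "_meta" section of the Flint metadata document. Below is a minimal sketch (not taken from this diff) of a metadata document matching those fields for a covering index; the index and table names are made up, and the literal value of COVERING_INDEX_TYPE is assumed to be "covering" for illustration.

// Hypothetical metadata whose _meta fields line up with what deserialize() extracts;
// "covering" is an assumed value for COVERING_INDEX_TYPE.
val sampleCoveringMetadata: String =
  """{
    |  "_meta": {
    |    "name": "ci_lineitem",
    |    "source": "default.lineitem",
    |    "kind": "covering",
    |    "indexedColumns": [
    |      { "columnName": "l_orderkey", "columnType": "bigint" },
    |      { "columnName": "l_shipdate", "columnType": "date" }
    |    ]
    |  }
    |}""".stripMargin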
@@ -236,102 +252,4 @@ object FlintSpark {
type RefreshMode = Value
val FULL, INCREMENTAL = Value
}

/**
* Helper class for index class construct. For now only skipping index supported.
*/
class IndexBuilder(flint: FlintSpark) {
var tableName: String = ""
var indexedColumns: Seq[FlintSparkSkippingStrategy] = Seq()

lazy val allColumns: Map[String, Column] = {
flint.spark.catalog
.listColumns(tableName)
.collect()
.map(col => (col.name, col))
.toMap
}

/**
* Configure which source table the index is based on.
*
* @param tableName
* full table name
* @return
* index builder
*/
def onTable(tableName: String): IndexBuilder = {
this.tableName = tableName
this
}

/**
* Add partition skipping indexed columns.
*
* @param colNames
* indexed column names
* @return
* index builder
*/
def addPartitions(colNames: String*): IndexBuilder = {
require(tableName.nonEmpty, "table name cannot be empty")

colNames
.map(findColumn)
.map(col => PartitionSkippingStrategy(columnName = col.name, columnType = col.dataType))
.foreach(addIndexedColumn)
this
}

/**
* Add value set skipping indexed column.
*
* @param colName
* indexed column name
* @return
* index builder
*/
def addValueSet(colName: String): IndexBuilder = {
require(tableName.nonEmpty, "table name cannot be empty")

val col = findColumn(colName)
addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType))
this
}

/**
* Add min max skipping indexed column.
*
* @param colName
* indexed column name
* @return
* index builder
*/
def addMinMax(colName: String): IndexBuilder = {
val col = findColumn(colName)
indexedColumns =
indexedColumns :+ MinMaxSkippingStrategy(columnName = col.name, columnType = col.dataType)
this
}

/**
* Create index.
*/
def create(): Unit = {
flint.createIndex(new FlintSparkSkippingIndex(tableName, indexedColumns))
}

private def findColumn(colName: String): Column =
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist"))

private def addIndexedColumn(indexedCol: FlintSparkSkippingStrategy): Unit = {
require(
indexedColumns.forall(_.columnName != indexedCol.columnName),
s"${indexedCol.columnName} is already indexed")

indexedColumns = indexedColumns :+ indexedCol
}
}
}
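
With the builder entry points above, both index types are created through the fluent API. Below is a minimal usage sketch: the skipping index calls mirror the builder methods moved out of this file, while the covering index builder methods (name, onTable, addIndexColumns) and all index, table, and column names are assumptions, since FlintSparkCoveringIndex.Builder itself is not shown in this excerpt.

// Hypothetical usage; spark is an existing SparkSession, and the covering index
// builder methods and all names below are assumed for illustration.
val flint = new FlintSpark(spark)

flint
  .skippingIndex()
  .onTable("default.lineitem")
  .addPartitions("l_shipdate")
  .addValueSet("l_returnflag")
  .addMinMax("l_quantity")
  .create()

flint
  .coveringIndex()
  .name("ci_lineitem")                          // assumed builder method
  .onTable("default.lineitem")
  .addIndexColumns("l_orderkey", "l_shipdate")  // assumed builder method
  .create()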
@@ -49,4 +49,15 @@ object FlintSparkIndex {
* ID column name.
*/
val ID_COLUMN: String = "__id__"

/**
* Common prefix of Flint index names, in the form "flint_<database>_<table>_".
*
* @param fullTableName
* source full table name
* @return
* Flint index name prefix
*/
def flintIndexNamePrefix(fullTableName: String): String =
s"flint_${fullTableName.replace(".", "_")}_"
}
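
For example, applying the prefix helper to a hypothetical table name:

// "default.alb_logs" is a made-up table name for illustration.
FlintSparkIndex.flintIndexNamePrefix("default.alb_logs")  // returns "flint_default_alb_logs_"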
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import org.apache.spark.sql.catalog.Column

/**
* Flint Spark index builder base class.
*
* @param flint
* Flint Spark API entrypoint
*/
abstract class FlintSparkIndexBuilder(flint: FlintSpark) {

/** Source table name */
protected var tableName: String = ""

/** All columns of the given source table */
lazy protected val allColumns: Map[String, Column] = {
require(tableName.nonEmpty, "Source table name is not provided")

flint.spark.catalog
.listColumns(tableName)
.collect()
.map(col => (col.name, col))
.toMap
}

/**
* Create Flint index.
*/
def create(): Unit = flint.createIndex(buildIndex())

/**
* Build method for concrete builder class to implement
*/
protected def buildIndex(): FlintSparkIndex

protected def findColumn(colName: String): Column =
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist"))
}
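
A minimal sketch of how a concrete builder could extend this base class; MyIndexBuilder and its methods are hypothetical and not part of this change.

import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder}

import org.apache.spark.sql.catalog.Column

// Hypothetical subclass showing the extension points of FlintSparkIndexBuilder.
class MyIndexBuilder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) {
  private var columns: Seq[Column] = Seq.empty

  def onTable(name: String): MyIndexBuilder = {
    tableName = name                              // protected field inherited from the base class
    this
  }

  def addColumns(colNames: String*): MyIndexBuilder = {
    columns = columns ++ colNames.map(findColumn) // findColumn validates against the source table
    this
  }

  // Called by create() in the base class to construct the concrete index.
  override protected def buildIndex(): FlintSparkIndex =
    ??? // e.g. build a custom FlintSparkIndex from tableName and columns
}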