Restrict the maximum size of value set by default limit (#208)
* Restrict the maximum size of collect set output

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add IT and change default limit to 2 temporarily for test

Signed-off-by: Chen Dai <daichen@amazon.com>

* Fix broken test and change query rewriter

Signed-off-by: Chen Dai <daichen@amazon.com>

* Restore default limit to 100

Signed-off-by: Chen Dai <daichen@amazon.com>

* Prepare for review

Signed-off-by: Chen Dai <daichen@amazon.com>

* Refactor IT and add user manual

Signed-off-by: Chen Dai <daichen@amazon.com>

---------

Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen committed Jan 5, 2024
1 parent 804b3aa commit c72d773
Showing 10 changed files with 291 additions and 305 deletions.
2 changes: 2 additions & 0 deletions docs/index.md
@@ -124,6 +124,8 @@ High level API is dependent on query engine implementation. Please see Query Eng

#### Skipping Index

+ The default maximum size for a value set is 100. If a file contains a column with more distinct values than this limit, its value set becomes null. This trade-off prevents excessive memory consumption, at the cost of not being able to skip that file.
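Not part of the diff, but a minimal spark-shell sketch of the behavior described above, using made-up file paths and column names and a deliberately tiny limit in place of the default 100:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val limit = 3 // stand-in for the default limit of 100

// Two "source files": part-0 has few distinct values, part-1 has many
val rows = Seq(
  ("part-0", "200"), ("part-0", "200"), ("part-0", "404"),
  ("part-1", "200"), ("part-1", "301"), ("part-1", "404"), ("part-1", "503"))
  .toDF("file_path", "status")

// Cap the collected value set: anything above the limit is stored as null
val valueSet = collect_set(col("status"))
val capped = when(size(valueSet) > limit, lit(null)).otherwise(valueSet)

// part-0 keeps its value set; part-1 exceeds the limit, so its value set is null
rows.groupBy("file_path").agg(capped.as("status")).show(truncate = false)
```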

```sql
CREATE SKIPPING INDEX [IF NOT EXISTS]
ON <object>
FlintSparkSkippingIndex.scala
@@ -74,7 +74,7 @@ case class FlintSparkSkippingIndex(
// Wrap aggregate function with output column name
val namedAggFuncs =
(outputNames, aggFuncs).zipped.map { case (name, aggFunc) =>
- new Column(aggFunc.toAggregateExpression().as(name))
+ new Column(aggFunc.as(name))
}

df.getOrElse(spark.read.table(tableName))
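For context (not from the repository): the strategies now hand back fully formed aggregate Expressions, so the builder above only needs to name them, and calling toAggregateExpression() again would no longer type-check on a plain Expression. A spark-shell sketch of the equivalent naming-and-aggregation step, with illustrative file, column, and output names, using the public Column API for the alias:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
import org.apache.spark.sql.functions.col
import spark.implicits._

val df = Seq(("part-0", 10), ("part-0", 25), ("part-1", 7)).toDF("file_path", "age")

// What a strategy now returns: expressions that are already aggregate expressions
val aggFuncs: Seq[Expression] = Seq(
  Min(col("age").expr).toAggregateExpression(),
  Max(col("age").expr).toAggregateExpression())

// What the index builder does with them: wrap each in a Column, alias it, aggregate per file
val outputNames = Seq("age_min", "age_max") // illustrative output column names
val namedAggFuncs = outputNames.zip(aggFuncs).map { case (name, aggFunc) =>
  new Column(aggFunc).as(name)
}

df.groupBy("file_path").agg(namedAggFuncs.head, namedAggFuncs.tail: _*).show()
```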
FlintSparkSkippingStrategy.scala
@@ -10,7 +10,6 @@ import org.json4s.JsonAST.JString
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind

import org.apache.spark.sql.catalyst.expressions.Expression
- import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

/**
* Skipping index strategy that defines skipping data structure building and reading logic.
@@ -42,7 +41,7 @@ trait FlintSparkSkippingStrategy {
* @return
* aggregators that generate skipping data structure
*/
- def getAggregators: Seq[AggregateFunction]
+ def getAggregators: Seq[Expression]

/**
* Rewrite a filtering condition on source table into a new predicate on index data based on
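A short illustration (not from the repository) of why the signature widens: a size-capped collect_set is a valid Catalyst Expression but not an AggregateFunction, so the old return type could not express it. In spark-shell:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet}
import org.apache.spark.sql.functions._

// Fits the old Seq[AggregateFunction] signature
val bare: AggregateFunction = CollectSet(col("status").expr)

// Only fits the new Seq[Expression] signature: the when/size wrapper is not an aggregate function
val capped: Expression =
  when(size(collect_set(col("status"))) > 100, lit(null))
    .otherwise(collect_set(col("status")))
    .expr
```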
MinMaxSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
- import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max, Min}
+ import org.apache.spark.sql.catalyst.expressions.aggregate.{Max, Min}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.functions.col

@@ -29,8 +29,11 @@ case class MinMaxSkippingStrategy(
override def outputSchema(): Map[String, String] =
Map(minColName -> columnType, maxColName -> columnType)

- override def getAggregators: Seq[AggregateFunction] =
-   Seq(Min(col(columnName).expr), Max(col(columnName).expr))
+ override def getAggregators: Seq[Expression] = {
+   Seq(
+     Min(col(columnName).expr).toAggregateExpression(),
+     Max(col(columnName).expr).toAggregateExpression())
+ }

override def rewritePredicate(predicate: Expression): Option[Expression] =
predicate match {
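The predicate-rewriting half of min-max skipping lies outside the lines shown above, so the following is only a sketch of the idea it serves, with invented index rows and column names: a filter on the source column is answered from the per-file min/max, and only files whose range could contain the value are scanned.

```scala
import org.apache.spark.sql.functions.col
import spark.implicits._

// Illustrative skipping-index rows: one per source file, holding min and max of `age`
val index = Seq(
  ("s3://bucket/part-0", 10, 25),
  ("s3://bucket/part-1", 20, 35),
  ("s3://bucket/part-2", 40, 90))
  .toDF("file_path", "age_min", "age_max")

// A source filter `age = 30` answered from the index: only part-1's range can contain 30
index
  .where(col("age_min") <= 30 && col("age_max") >= 30)
  .select("file_path")
  .show(truncate = false)
```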
PartitionSkippingStrategy.scala
@@ -9,7 +9,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{PARTITION, SkippingKind}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
- import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}
+ import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.functions.col

/**
@@ -25,8 +25,8 @@ case class PartitionSkippingStrategy(
Map(columnName -> columnType)
}

- override def getAggregators: Seq[AggregateFunction] = {
-   Seq(First(col(columnName).expr, ignoreNulls = true))
+ override def getAggregators: Seq[Expression] = {
+   Seq(First(col(columnName).expr, ignoreNulls = true).toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] =
ValueSetSkippingStrategy.scala
@@ -7,10 +7,10 @@ package org.opensearch.flint.spark.skipping.valueset

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{SkippingKind, VALUE_SET}
+ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.DEFAULT_VALUE_SET_SIZE_LIMIT

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
- import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, CollectSet}
- import org.apache.spark.sql.functions.col
+ import org.apache.spark.sql.functions._

/**
* Skipping strategy based on unique column value set.
@@ -24,8 +24,14 @@ case class ValueSetSkippingStrategy(
override def outputSchema(): Map[String, String] =
Map(columnName -> columnType)

- override def getAggregators: Seq[AggregateFunction] =
-   Seq(CollectSet(col(columnName).expr))
+ override def getAggregators: Seq[Expression] = {
+   val limit = DEFAULT_VALUE_SET_SIZE_LIMIT
+   val collectSet = collect_set(columnName)
+   val aggregator =
+     when(size(collectSet) > limit, lit(null))
+       .otherwise(collectSet)
+   Seq(aggregator.expr)
+ }

override def rewritePredicate(predicate: Expression): Option[Expression] =
/*
@@ -34,7 +40,16 @@
*/
predicate match {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
-   Some((col(columnName) === value).expr)
+   // Value set may be null due to the maximum size limit
+   Some((isnull(col(columnName)) || col(columnName) === value).expr)
case _ => None
}
}

+ object ValueSetSkippingStrategy {
+
+   /**
+    * Default limit on the size of the collected value set. TODO: make this a val once it's configurable
+    */
+   var DEFAULT_VALUE_SET_SIZE_LIMIT = 100
+ }
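To round out the picture, here is the query-side effect the rewritten predicate aims for, sketched in plain Spark with invented index rows (the equality on the multi-valued index column is written as array_contains here, whereas the actual rewrite relies on how the index store matches values against array fields). A file whose value set overflowed and was stored as null can no longer be ruled out, so it is scanned rather than skipped:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Illustrative index rows: part-2's value set exceeded the size limit and was stored as null
val rows: Seq[(String, Seq[String])] = Seq(
  ("s3://bucket/part-0", Seq("200", "404")),
  ("s3://bucket/part-1", Seq("200")),
  ("s3://bucket/part-2", null))
val index = rows.toDF("file_path", "status")

// Source filter `status = '404'` on the index: keep files whose set contains the value
// OR whose set is null (unknown because it overflowed)
index
  .where(isnull(col("status")) || array_contains(col("status"), "404"))
  .select("file_path")
  .show(truncate = false)
// part-0 (contains "404") and part-2 (overflowed) are scanned; part-1 is skipped
```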