Skip to content

Commit

Permalink
[FLINK-35816][table-planner] Non-mergeable proctime tvf window aggreg…
Browse files Browse the repository at this point in the history
…ate needs to fallback to group aggregate

This closes #25082
  • Loading branch information
lincoln-lil committed Jul 13, 2024
1 parent 9081036 commit 41a1409
Show file tree
Hide file tree
Showing 4 changed files with 808 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.stream

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.logical.WindowingStrategy
import org.apache.flink.table.planner.plan.logical.{WindowAttachedWindowingStrategy, WindowingStrategy}
import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty}
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecWindowAggregate
import org.apache.flink.table.planner.plan.utils._
Expand Down Expand Up @@ -112,6 +113,12 @@ class StreamPhysicalWindowAggregate(

override def translateToExecNode(): ExecNode[_] = {
checkEmitConfiguration(unwrapTableConfig(this))

if (windowing.isInstanceOf[WindowAttachedWindowingStrategy] && windowing.isProctime) {
throw new TableException(
"Non-mergeable processing time window tvf aggregation is invalid, should fallback to group " +
"aggregation instead. This is a bug and should not happen. Please file an issue.")
}
new StreamExecWindowAggregate(
unwrapTableConfig(this),
grouping,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.flink.table.planner.functions.sql.{FlinkSqlOperatorTable, SqlW
import org.apache.flink.table.planner.plan.`trait`.RelWindowProperties
import org.apache.flink.table.planner.plan.logical._
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalAggregate, FlinkLogicalJoin, FlinkLogicalRank, FlinkLogicalTableFunctionScan}
import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalAggregate, FlinkLogicalMatch, FlinkLogicalOverAggregate, FlinkLogicalRank, FlinkLogicalTableFunctionScan}
import org.apache.flink.table.planner.plan.utils.AggregateUtil.inferAggAccumulatorNames
import org.apache.flink.table.planner.plan.utils.WindowEmitStrategy.{TABLE_EXEC_EMIT_EARLY_FIRE_ENABLED, TABLE_EXEC_EMIT_LATE_FIRE_ENABLED}
import org.apache.flink.table.planner.typeutils.RowTypeUtils
Expand All @@ -35,7 +35,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.canBeTimeAtt

import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rel.{BiRel, RelNode, RelVisitor}
import org.apache.calcite.rel.core._
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.SqlTypeFamily
Expand All @@ -45,10 +45,9 @@ import org.apache.calcite.util.{ImmutableBitSet, Util}
import java.time.Duration
import java.util.Collections

import scala.annotation.tailrec
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/** Utilities for window table-valued functions. */
object WindowUtil {
Expand Down Expand Up @@ -333,7 +332,8 @@ object WindowUtil {

/**
* For rowtime window, return true if the given aggregate grouping contains window start and end.
* For proctime window, we should also check if it exists a neighbour windowTableFunctionCall.
* For proctime window, we should also check if it exists a neighbour windowTableFunctionCall and
* doesn't exist any [[RexCall]] on window time columns.
*
* If the window is a session window, we should also check if the partition keys are the same as
* the group keys. See more at [[WindowUtil.validGroupKeyPartitionKey()]].
Expand All @@ -346,7 +346,7 @@ object WindowUtil {
return false
}
if (WindowUtil.groupingContainsWindowStartEnd(grouping, windowProperties)) {
windowProperties.isRowtime || existNeighbourWindowTableFunc(agg.getInput)
isValidRowtimeWindow(windowProperties) || isValidProcTimeWindow(windowProperties, fmq, agg)
} else {
false
}
Expand Down Expand Up @@ -385,35 +385,94 @@ object WindowUtil {
}
}

private def existNeighbourWindowTableFunc(rel: RelNode): Boolean = {
private def isValidRowtimeWindow(windowProperties: RelWindowProperties): Boolean = {
// rowtime tvf window can support calculation on window columns even before aggregation
windowProperties.isRowtime
}

@tailrec
def find(rel: RelNode): Unit = {
rel match {
case rss: RelSubset =>
val innerRel = Option.apply(rss.getBest).getOrElse(rss.getOriginal)
find(innerRel)
/**
* If the middle Calc(s) contains call(s) on window columns, we should not convert the Aggregate
* into WindowAggregate but GroupAggregate instead.
*
* The valid plan structure is like:
*
* {{{
* Aggregate
* |
* Calc (should not contain call on window columns)
* |
* WindowTableFunctionScan
* }}}
*
* and unlike:
*
* {{{
* Aggregate
* |
* Calc
* |
* Aggregate
* |
* Calc
* |
* WindowTableFunctionScan
* }}}
*/
private def isValidProcTimeWindow(
windowProperties: RelWindowProperties,
fmq: FlinkRelMetadataQuery,
agg: FlinkLogicalAggregate): Boolean = {
val calcMatcher = new CalcWindowFunctionScanMatcher
try {
calcMatcher.go(agg.getInput(0))
} catch {
case _: Throwable => // do nothing
}
if (!calcMatcher.existNeighbourWindowTableFunc) {
return false
}
var existCallOnWindowColumns = calcMatcher.calcNodes.nonEmpty &&
calcMatcher.calcNodes.exists(calc => calcContainsCallsOnWindowColumns(calc, fmq))

// aggregate call shouldn't be on window columns
val aggInputWindowProps = windowProperties.getWindowColumns
existCallOnWindowColumns = existCallOnWindowColumns || !agg.getAggCallList.forall {
call => aggInputWindowProps.intersect(ImmutableBitSet.of(call.getArgList)).isEmpty
}
// proctime tvf window can't support calculation on window columns before aggregation
!existCallOnWindowColumns
}

private class CalcWindowFunctionScanMatcher extends RelVisitor {
val calcNodes: ListBuffer[Calc] = ListBuffer[Calc]()
var existNeighbourWindowTableFunc = false

override def visit(node: RelNode, ordinal: Int, parent: RelNode): Unit = {
node match {
case calc: Calc =>
calcNodes += calc
// continue to visit children
super.visit(calc, 0, parent)
case scan: FlinkLogicalTableFunctionScan =>
if (WindowUtil.isWindowTableFunctionCall(scan.getCall)) {
existNeighbourWindowTableFunc = true
// stop visiting
throw new Util.FoundOne
}
find(scan.getInput(0))

// proctime attribute comes from these operators can not be used directly for proctime
// window aggregate, so further traversal of child nodes is unnecessary
case _: FlinkLogicalAggregate | _: FlinkLogicalRank | _: FlinkLogicalJoin =>

case sr: SingleRel => find(sr.getInput)
case rss: RelSubset =>
val innerRel = Option.apply(rss.getBest).getOrElse(rss.getOriginal)
// special case doesn't call super.visit for RelSubSet because it has no children
visit(innerRel, 0, rss)
case _: FlinkLogicalAggregate | _: FlinkLogicalMatch | _: FlinkLogicalOverAggregate |
_: FlinkLogicalRank | _: BiRel | _: SetOp =>
// proctime attribute comes from these operators can't be used directly for proctime
// window aggregate, so further tree walk is unnecessary
throw new Util.FoundOne
case _ =>
// continue to visit children
super.visit(node, ordinal, parent)
}
}

try {
find(rel)
} catch {
case _: Util.FoundOne => return true
}
false
}

/**
Expand Down
Loading

0 comments on commit 41a1409

Please sign in to comment.