Skip to content

Commit

Permalink
Query: Translate aggregate over grouping element in separate pass
Browse files Browse the repository at this point in the history
Design:
- Introduce `SqlEnumerableExpression` - a holder class which indicates the `SqlExpression` is in form of a enumerable (or group) coming as a result of whole table selection or a grouping element. It also stores details about if `Distinct` is applied over grouping or if there are any orderings.
- Due to above `DistinctExpression` has been removed. The token while used to denote `Distinct` over grouping element were not valid in other parts of SQL tree hence it makes more sense to combine it with `SqlEnumerableExpression`.
- To support dual pass, `GroupByShaperExpression` contains 2 forms of grouping element. One element selector form which correlates directly with the parent grouped query, second subquery form which correlates to parent grouped query through a correlation predicate. Element selector is first used to translate aggregation. If that fails we use subquery form to translate as a subquery. Due to 2 forms of same component, GroupByShaperExpression disallows calling into VisitChildren method, any visitor which is visiting a tree containing GroupByShaperExpression (which appears only in `QueryExpression.ShaperExpression` or LINQ expression after remapping but before translation) must intercept the tree and either ignore or process it appropriately.
- An internal visitor (`GroupByAggregateChainProcessor`) inside SqlTranslator visits and process chain of queryable operations on a grouping element before aggregate is called and condense it into `SqlEnumerableExpression` which is then passed to method which translates aggregate. This visitor only processes Where/Distinct/Select for now. Future PR will add processing for OrderBy/ThenBy(Descending) operations to generate orderings.
- Side-effect above is that joins expanded over the grouping element (due to navigations used on aggregate chain), doesn't translate to aggregate anymore since we need to translate the join on parent query, remove the translated join if the chain didn't end in aggregate and also de-dupe same joins. Filing issue to improve this in future. Due to fragile nature of matching to lift the join, we shouldn't try to lift joins.
- To support custom aggregate operations, we will either reused `IMethodCallTranslator` or create a parallel structure for aggregate methods and call into it from SqlTranslator by passing translated SqlEnumerableExpression as appropriate.
- For complex grouping key, we cause a pushdown so that we can reference the grouping key through columns only. This allows us to reference the grouping key in correlation predicate for subquery without generating invalid SQL in many cases.
- With complex grouping key converting to columns, now we are able to correctly generate identifiers for grouping queries which makes more queries with correlated collections (where either parent or inner both queries can be groupby query) translatable.
- Erase client projection when applying aggregate operation over GroupBy result.
- When processing result selector in GroupBy use the updated key selector if the select expression was pushed down.

Resolves #27132
Resolves #27266
Resolves #27433
Resolves #23601
Resolves #27721
Resolves #27796
Resolves #27801
Resolves #19683

Relates to #22957
  • Loading branch information
smitpatel committed May 4, 2022
1 parent 7ad294c commit 9b181c9
Show file tree
Hide file tree
Showing 40 changed files with 1,659 additions and 1,011 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ public virtual GroupByShaperExpression ApplyGrouping(

return new GroupByShaperExpression(
groupingKey,
shaperExpression,
new ShapedQueryExpression(
clonedInMemoryQueryExpression,
new QueryExpressionReplacingExpressionVisitor(this, clonedInMemoryQueryExpression).Visit(shaperExpression)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
protected override Expression VisitExtension(Expression extensionExpression)
=> extensionExpression is EntityShaperExpression
|| extensionExpression is ShapedQueryExpression
|| extensionExpression is GroupByShaperExpression
? extensionExpression
: base.VisitExtension(extensionExpression);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,63 +146,58 @@ public virtual Expression Translate(SelectExpression selectExpression, Expressio
new ProjectionBindingExpression(_selectExpression, _clientProjections.Count - 1, expression.Type),
materializeCollectionNavigationExpression.Navigation,
materializeCollectionNavigationExpression.Navigation.ClrType.GetSequenceType());
}

var translation = _sqlTranslator.Translate(expression);
if (translation != null)
{
return AddClientProjection(translation, expression.Type.MakeNullable());
}

case MethodCallExpression methodCallExpression:
if (methodCallExpression.Method.IsGenericMethod
if (expression is MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.DeclaringType == typeof(Enumerable)
&& methodCallExpression.Method.Name == nameof(Enumerable.ToList)
&& methodCallExpression.Arguments.Count == 1
&& methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) != null)
{
var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(
methodCallExpression.Arguments[0]);
if (subquery != null)
{
var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(
methodCallExpression.Arguments[0]);
if (subquery != null)
{
_clientProjections!.Add(subquery);
// expression.Type here will be List<T>
return new CollectionResultExpression(
new ProjectionBindingExpression(_selectExpression, _clientProjections.Count - 1, expression.Type),
navigation: null,
methodCallExpression.Method.GetGenericArguments()[0]);
}
_clientProjections!.Add(subquery);
// expression.Type here will be List<T>
return new CollectionResultExpression(
new ProjectionBindingExpression(_selectExpression, _clientProjections.Count - 1, expression.Type),
navigation: null,
methodCallExpression.Method.GetGenericArguments()[0]);
}
else
}
else
{
var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression);
if (subquery != null)
{
var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression);
if (subquery != null)
_clientProjections!.Add(subquery);
var type = expression.Type;
if (type.IsGenericType
&& type.GetGenericTypeDefinition() == typeof(IQueryable<>))
{
// This simplifies the check when subquery is translated and can be lifted as scalar.
var scalarTranslation = _sqlTranslator.Translate(subquery);
if (scalarTranslation != null)
{
return AddClientProjection(scalarTranslation, expression.Type.MakeNullable());
}

_clientProjections!.Add(subquery);
var type = expression.Type;

if (type.IsGenericType
&& type.GetGenericTypeDefinition() == typeof(IQueryable<>))
{
type = typeof(List<>).MakeGenericType(type.GetSequenceType());
}

var projectionBindingExpression = new ProjectionBindingExpression(
_selectExpression, _clientProjections.Count - 1, type);
return subquery.ResultCardinality == ResultCardinality.Enumerable
? new CollectionResultExpression(
projectionBindingExpression, navigation: null, subquery.ShaperExpression.Type)
: projectionBindingExpression;
type = typeof(List<>).MakeGenericType(type.GetSequenceType());
}
}

break;
var projectionBindingExpression = new ProjectionBindingExpression(
_selectExpression, _clientProjections.Count - 1, type);
return subquery.ResultCardinality == ResultCardinality.Enumerable
? new CollectionResultExpression(
projectionBindingExpression, navigation: null, subquery.ShaperExpression.Type)
: projectionBindingExpression;
}
}
}

var translation = _sqlTranslator.Translate(expression);
return translation != null
? AddClientProjection(translation, expression.Type.MakeNullable())
: base.Visit(expression);
return base.Visit(expression);
}
else
{
Expand Down
35 changes: 25 additions & 10 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,31 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres
return sqlBinaryExpression;
}

/// <inheritdoc />
protected override Expression VisitSqlEnumerable(SqlEnumerableExpression sqlEnumerableExpression)
{
if (sqlEnumerableExpression.Orderings.Count != 0)
{
// TODO: Throw error here because we don't know how to print orderings.
// Though providers can override this method and generate orderings if they have a way to print it.
throw new InvalidOperationException();
}

if (sqlEnumerableExpression.IsDistinct)
{
_relationalCommandBuilder.Append("DISTINCT (");
}

Visit(sqlEnumerableExpression.SqlExpression);

if (sqlEnumerableExpression.IsDistinct)
{
_relationalCommandBuilder.Append(")");
}

return sqlEnumerableExpression;
}

/// <inheritdoc />
protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression)
{
Expand Down Expand Up @@ -609,16 +634,6 @@ protected override Expression VisitCollate(CollateExpression collateExpression)
return collateExpression;
}

/// <inheritdoc />
protected override Expression VisitDistinct(DistinctExpression distinctExpression)
{
_relationalCommandBuilder.Append("DISTINCT (");
Visit(distinctExpression.Operand);
_relationalCommandBuilder.Append(")");

return distinctExpression;
}

/// <inheritdoc />
protected override Expression VisitCase(CaseExpression caseExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe
private readonly QueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _subquery;
private SqlExpression? _groupingElementCorrelationalPredicate;

/// <summary>
/// Creates a new instance of the <see cref="QueryableMethodTranslatingExpressionVisitor" /> class.
Expand Down Expand Up @@ -135,7 +134,6 @@ when queryRootExpression.GetType() == typeof(QueryRootExpression)
case GroupByShaperExpression groupByShaperExpression:
var groupShapedQueryExpression = groupByShaperExpression.GroupingEnumerable;
var groupClonedSelectExpression = ((SelectExpression)groupShapedQueryExpression.QueryExpression).Clone();
_groupingElementCorrelationalPredicate = groupClonedSelectExpression.Predicate;
return new ShapedQueryExpression(
groupClonedSelectExpression,
new QueryExpressionReplacingExpressionVisitor(
Expand Down Expand Up @@ -418,7 +416,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent

var newResultSelectorBody = new ReplacingExpressionVisitor(
new Expression[] { original1, original2 },
new[] { translatedKey, groupByShaper })
new[] { groupByShaper.KeySelector, groupByShaper })
.Visit(resultSelector.Body);

newResultSelectorBody = ExpandSharedTypeEntities(selectExpression, newResultSelectorBody);
Expand Down Expand Up @@ -1032,6 +1030,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
protected override Expression VisitExtension(Expression extensionExpression)
=> extensionExpression is EntityShaperExpression
|| extensionExpression is ShapedQueryExpression
|| extensionExpression is GroupByShaperExpression
? extensionExpression
: base.VisitExtension(extensionExpression);

Expand Down Expand Up @@ -1356,7 +1355,7 @@ public DeferredOwnedExpansionRemovingVisitor(SelectExpression selectExpression)
{
DeferredOwnedExpansionExpression doee => UnwrapDeferredEntityProjectionExpression(doee),
// For the source entity shaper or owned collection expansion
EntityShaperExpression or ShapedQueryExpression => expression,
EntityShaperExpression or ShapedQueryExpression or GroupByShaperExpression => expression,
_ => base.Visit(expression)
};

Expand Down Expand Up @@ -1406,7 +1405,8 @@ private static void HandleGroupByForAggregate(SelectExpression selectExpression,
{
if (eraseProjection)
{
selectExpression.ReplaceProjection(new Dictionary<ProjectionMember, Expression>());
// Erasing client projections erase projectionMapping projections too
selectExpression.ReplaceProjection(new List<Expression>());
}

selectExpression.PushdownIntoSubquery();
Expand Down Expand Up @@ -1461,14 +1461,11 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
private ShapedQueryExpression? TranslateAggregateWithPredicate(
ShapedQueryExpression source,
LambdaExpression? predicate,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
Func<SqlEnumerableExpression, SqlExpression?> aggregateTranslator,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (_groupingElementCorrelationalPredicate == null)
{
selectExpression.PrepareForAggregate();
}
selectExpression.PrepareForAggregate();

if (predicate != null)
{
Expand All @@ -1481,37 +1478,9 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
source = translatedSource;
}

SqlExpression sqlExpression = _sqlExpressionFactory.Fragment("*");

if (_groupingElementCorrelationalPredicate != null)
{
if (selectExpression.IsDistinct)
{
var shaperExpression = source.ShaperExpression;
if (shaperExpression is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
shaperExpression = unaryExpression.Operand;
}

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression)
{
sqlExpression = (SqlExpression)selectExpression.GetProjection(projectionBindingExpression);
}
else
{
return null;
}
}

sqlExpression = CombineGroupByAggregateTerms(selectExpression, sqlExpression);
}
else
{
HandleGroupByForAggregate(selectExpression, eraseProjection: true);
}
HandleGroupByForAggregate(selectExpression, eraseProjection: true);

var translation = aggregateTranslator(sqlExpression);
var translation = aggregateTranslator(new SqlEnumerableExpression(_sqlExpressionFactory.Fragment("*"), distinct: false, null));
if (translation == null)
{
return null;
Expand All @@ -1531,16 +1500,13 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
private ShapedQueryExpression? TranslateAggregateWithSelector(
ShapedQueryExpression source,
LambdaExpression? selector,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
Func<SqlEnumerableExpression, SqlExpression?> aggregateTranslator,
bool throwWhenEmpty,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (_groupingElementCorrelationalPredicate == null)
{
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);
}
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);

SqlExpression translatedSelector;
if (selector == null
Expand Down Expand Up @@ -1575,12 +1541,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
}
}

if (_groupingElementCorrelationalPredicate != null)
{
translatedSelector = CombineGroupByAggregateTerms(selectExpression, translatedSelector);
}

var projection = aggregateTranslator(translatedSelector);
var projection = aggregateTranslator(new SqlEnumerableExpression(translatedSelector, distinct: false, null));
if (projection == null)
{
return null;
Expand Down Expand Up @@ -1636,52 +1597,4 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape

return source.UpdateShaperExpression(shaper);
}

private SqlExpression CombineGroupByAggregateTerms(SelectExpression selectExpression, SqlExpression selector)
{
if (selectExpression.Predicate != null
&& !selectExpression.Predicate.Equals(_groupingElementCorrelationalPredicate))
{
if (selector is SqlFragmentExpression { Sql: "*" })
{
selector = _sqlExpressionFactory.Constant(1);
}

var correlationTerms = new List<SqlExpression>();
var predicateTerms = new List<SqlExpression>();
PopulatePredicateTerms(_groupingElementCorrelationalPredicate!, correlationTerms);
PopulatePredicateTerms(selectExpression.Predicate, predicateTerms);
var predicate = predicateTerms.Skip(correlationTerms.Count)
.Aggregate((l, r) => _sqlExpressionFactory.AndAlso(l, r));
selector = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(predicate, selector) },
elseResult: null);
selectExpression.UpdatePredicate(_groupingElementCorrelationalPredicate!);
}

if (selectExpression.IsDistinct)
{
if (selector is SqlFragmentExpression { Sql: "*" })
{
selector = _sqlExpressionFactory.Constant(1);
}

selector = new DistinctExpression(selector);
}

return selector;

static void PopulatePredicateTerms(SqlExpression predicate, List<SqlExpression> terms)
{
if (predicate is SqlBinaryExpression { OperatorType: ExpressionType.AndAlso } sqlBinaryExpression)
{
PopulatePredicateTerms(sqlBinaryExpression.Left, terms);
PopulatePredicateTerms(sqlBinaryExpression.Right, terms);
}
else
{
terms.Add(predicate);
}
}
}
}
Loading

0 comments on commit 9b181c9

Please sign in to comment.