Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Translate Distinct operator over group element before aggregate #21921

Merged
merged 1 commit into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.InMemory.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Storage;
Expand Down Expand Up @@ -414,14 +413,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);
var expression = ApplySelect(groupingElementExpression);

result = selector == null
result = expression == null
? null
: Expression.Call(
EnumerableMethods.GetAverageWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source,
selector);
: Expression.Call(EnumerableMethods.GetAverageWithoutSelector(expression.Type.TryGetSequenceType()), expression);
break;
}

Expand All @@ -439,12 +435,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

result = Expression.Call(
EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source);
var expression = ApplySelect(groupingElementExpression);

result = expression == null
? null
: Expression.Call(
EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(expression.Type.TryGetSequenceType()),
expression);
break;
}

case nameof(Enumerable.Distinct):
result = groupingElementExpression.Selector is EntityShaperExpression
? groupingElementExpression
: groupingElementExpression.IsDistinct
? null
: groupingElementExpression.ApplyDistinct();
break;

case nameof(Enumerable.LongCount):
{
if (methodCallExpression.Arguments.Count == 2)
Expand All @@ -459,9 +467,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

result = Expression.Call(
EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source);
var expression = ApplySelect(groupingElementExpression);

result = expression == null
? null
: Expression.Call(
EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(expression.Type.TryGetSequenceType()),
expression);
break;
}

Expand All @@ -473,20 +485,22 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);
if (selector != null)
var expression = ApplySelect(groupingElementExpression);
if (expression == null
|| expression is ParameterExpression)
{
var aggregateMethod = EnumerableMethods.GetMaxWithSelector(selector.ReturnType);
aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2
? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
: aggregateMethod.MakeGenericMethod(typeof(ValueBuffer));


result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector);
result = null;
}
else
{
result = null;
var type = expression.Type.TryGetSequenceType();
var aggregateMethod = EnumerableMethods.GetMaxWithoutSelector(type);
if (aggregateMethod.IsGenericMethod)
{
aggregateMethod = aggregateMethod.MakeGenericMethod(type);
}

result = Expression.Call(aggregateMethod, expression);
}

break;
Expand All @@ -500,21 +514,22 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);

if (selector != null)
var expression = ApplySelect(groupingElementExpression);
if (expression == null
|| expression is ParameterExpression)
{
var aggregateMethod = EnumerableMethods.GetMinWithSelector(selector.ReturnType);
aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2
? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
: aggregateMethod.MakeGenericMethod(typeof(ValueBuffer));


result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector);
result = null;
}
else
{
result = null;
var type = expression.Type.TryGetSequenceType();
var aggregateMethod = EnumerableMethods.GetMinWithoutSelector(type);
if (aggregateMethod.IsGenericMethod)
{
aggregateMethod = aggregateMethod.MakeGenericMethod(type);
}

result = Expression.Call(aggregateMethod, expression);
}

break;
Expand All @@ -532,14 +547,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);
var expression = ApplySelect(groupingElementExpression);

result = selector == null
result = expression == null
? null
: Expression.Call(
EnumerableMethods.GetSumWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source,
selector);
: Expression.Call(EnumerableMethods.GetSumWithoutSelector(expression.Type.TryGetSequenceType()), expression);
break;
}

Expand Down Expand Up @@ -567,12 +579,30 @@ GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingEleme
Expression.Lambda(predicate, groupingElement.ValueBufferParameter)));
}

LambdaExpression GetSelector(GroupingElementExpression groupingElement)
Expression ApplySelect(GroupingElementExpression groupingElement)
{
var selector = TranslateInternal(groupingElement.Selector);
return selector == null
? null
: Expression.Lambda(selector, groupingElement.ValueBufferParameter);

if (selector == null)
{
return groupingElement.Selector is EntityShaperExpression
? groupingElement.Source
: null;
}

var result = Expression.Call(
EnumerableMethods.Select.MakeGenericMethod(typeof(ValueBuffer), selector.Type),
groupingElement.Source,
Expression.Lambda(selector, groupingElement.ValueBufferParameter));

if (groupingElement.IsDistinct)
{
result = Expression.Call(
EnumerableMethods.Distinct.MakeGenericMethod(selector.Type),
result);
}

return result;
}

static GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
Expand Down Expand Up @@ -1571,9 +1601,15 @@ public GroupingElementExpression(Expression source, Expression selector, Paramet
Selector = selector;
}
public Expression Source { get; private set; }
public bool IsDistinct { get; private set; }
public Expression Selector { get; private set; }
public ParameterExpression ValueBufferParameter { get; }
public GroupingElementExpression ApplyDistinct()
{
IsDistinct = true;

return this;
}
public GroupingElementExpression ApplySelector(Expression expression)
{
Selector = expression;
Expand Down
12 changes: 12 additions & 0 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,18 @@ protected override Expression VisitCollate(CollateExpression collateExpresion)
return collateExpresion;
}

/// <inheritdoc />
protected override Expression VisitDistinctSql(DistinctSqlExpression distinctSqlExpression)
{
Check.NotNull(distinctSqlExpression, nameof(distinctSqlExpression));

_relationalCommandBuilder.Append("DISTINCT (");
Visit(distinctSqlExpression.Operand);
_relationalCommandBuilder.Append(")");

return distinctSqlExpression;
}

/// <inheritdoc />
protected override Expression VisitCase(CaseExpression caseExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,11 @@ public virtual SqlExpression TranslateAverage([NotNull] SqlExpression sqlExpress
if (inputType == typeof(int)
|| inputType == typeof(long))
{
sqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
sqlExpression = sqlExpression is DistinctSqlExpression distinctSqlExpression
? new DistinctSqlExpression(_sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(distinctSqlExpression.Operand, typeof(double))))
: _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
}

return inputType == typeof(float)
Expand Down Expand Up @@ -492,6 +495,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
result = TranslateCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true));
break;

case nameof(Enumerable.Distinct):
result = groupingElementExpression.Element is EntityShaperExpression
? groupingElementExpression
: groupingElementExpression.IsDistinct
? null
: groupingElementExpression.ApplyDistinct();
break;

case nameof(Enumerable.LongCount):
if (methodCallExpression.Arguments.Count == 2)
{
Expand Down Expand Up @@ -600,14 +611,20 @@ SqlExpression GetExpressionForAggregation(GroupingElementExpression groupingElem
selector = _sqlExpressionFactory.Constant(1);
}

return _sqlExpressionFactory.Case(
selector = _sqlExpressionFactory.Case(
new List<CaseWhenClause>
{
new CaseWhenClause(groupingElement.Predicate, selector)
},
elseResult: null);
}

if (groupingElement.IsDistinct
&& !(selector is SqlFragmentExpression))
{
selector = new DistinctSqlExpression(selector);
}

return selector;
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public virtual SqlExpression ApplyTypeMapping(SqlExpression sqlExpression, Relat
{
CaseExpression e => ApplyTypeMappingOnCase(e, typeMapping),
CollateExpression e => ApplyTypeMappingOnCollate(e, typeMapping),
DistinctSqlExpression e => ApplyTypeMappingOnDistinctSql(e, typeMapping),
LikeExpression e => ApplyTypeMappingOnLike(e),
SqlBinaryExpression e => ApplyTypeMappingOnSqlBinary(e, typeMapping),
SqlUnaryExpression e => ApplyTypeMappingOnSqlUnary(e, typeMapping),
Expand Down Expand Up @@ -108,9 +109,11 @@ private SqlExpression ApplyTypeMappingOnCase(

private SqlExpression ApplyTypeMappingOnCollate(
CollateExpression collateExpression, RelationalTypeMapping typeMapping)
=> new CollateExpression(
ApplyTypeMapping(collateExpression.Operand, typeMapping),
collateExpression.Collation);
=> collateExpression.Update(ApplyTypeMapping(collateExpression.Operand, typeMapping));

private SqlExpression ApplyTypeMappingOnDistinctSql(
DistinctSqlExpression distinctSqlExpression, RelationalTypeMapping typeMapping)
=> distinctSqlExpression.Update(ApplyTypeMapping(distinctSqlExpression.Operand, typeMapping));

private SqlExpression ApplyTypeMappingOnSqlUnary(
SqlUnaryExpression sqlUnaryExpression, RelationalTypeMapping typeMapping)
Expand Down
9 changes: 9 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case CrossJoinExpression crossJoinExpression:
return VisitCrossJoin(crossJoinExpression);

case DistinctSqlExpression distinctSqlExpression:
return VisitDistinctSql(distinctSqlExpression);

case ExceptExpression exceptExpression:
return VisitExcept(exceptExpression);

Expand Down Expand Up @@ -149,6 +152,12 @@ protected override Expression VisitExtension(Expression extensionExpression)
/// <returns> The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. </returns>
protected abstract Expression VisitCrossJoin([NotNull] CrossJoinExpression crossJoinExpression);
/// <summary>
/// Visits the children of the distinct SQL expression.
/// </summary>
/// <param name="distinctSqlExpression"> The expression to visit. </param>
/// <returns> The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. </returns>
protected abstract Expression VisitDistinctSql([NotNull] DistinctSqlExpression distinctSqlExpression);
/// <summary>
/// Visits the children of the except expression.
/// </summary>
/// <param name="exceptExpression"> The expression to visit. </param>
Expand Down
Loading