Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Allow terminating operator on GroupBy without aggregate when n… #21623

Merged
merged 1 commit into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,7 @@ void IPrintableExpression.Print(ExpressionPrinter expressionPrinter)
expressionPrinter.Visit(ServerQueryExpression);
}

expressionPrinter.AppendLine();
expressionPrinter.AppendLine("ProjectionMapping:");
using (expressionPrinter.Indent())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,26 @@ protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression sour
Check.NotNull(source, nameof(source));
Check.NotNull(predicate, nameof(predicate));

var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;
predicate = TranslateLambdaExpression(source, predicate, preserveType: true);
if (predicate == null)
predicate = Expression.Lambda(Expression.Not(predicate.Body), predicate.Parameters);
source = TranslateWhere(source, predicate);
if (source == null)
{
return null;
}

var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (source.ShaperExpression is GroupByShaperExpression)
{
inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary<ProjectionMember, Expression>());
inMemoryQueryExpression.PushdownIntoSubquery();
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.All.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression,
predicate));
Expression.Not(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression)));

return source.UpdateShaperExpression(inMemoryQueryExpression.GetSingleScalarProjection());
}
Expand All @@ -144,30 +152,28 @@ protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression sour
/// </summary>
protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression source, LambdaExpression predicate)
{
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (predicate == null)
{
inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression));
}
else
if (predicate != null)
{
predicate = TranslateLambdaExpression(source, predicate, preserveType: true);
if (predicate == null)
source = TranslateWhere(source, predicate);
if (source == null)
{
return null;
}
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.AnyWithPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression,
predicate));
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (source.ShaperExpression is GroupByShaperExpression)
{
inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary<ProjectionMember, Expression>());
inMemoryQueryExpression.PushdownIntoSubquery();
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression));

return source.UpdateShaperExpression(inMemoryQueryExpression.GetSingleScalarProjection());
}

Expand Down Expand Up @@ -256,30 +262,28 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
{
Check.NotNull(source, nameof(source));

var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (predicate == null)
{
inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression));
}
else
if (predicate != null)
{
predicate = TranslateLambdaExpression(source, predicate, preserveType: true);
if (predicate == null)
source = TranslateWhere(source, predicate);
if (source == null)
{
return null;
}
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.CountWithPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression,
predicate));
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (source.ShaperExpression is GroupByShaperExpression)
{
inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary<ProjectionMember, Expression>());
inMemoryQueryExpression.PushdownIntoSubquery();
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression));

return source.UpdateShaperExpression(inMemoryQueryExpression.GetSingleScalarProjection());
}

Expand Down Expand Up @@ -720,37 +724,33 @@ protected override ShapedQueryExpression TranslateLeftJoin(
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression TranslateLongCount(
ShapedQueryExpression source, LambdaExpression predicate)
protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpression source, LambdaExpression predicate)
{
Check.NotNull(source, nameof(source));

var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (predicate == null)
{
inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(
inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression));
}
else
if (predicate != null)
{
predicate = TranslateLambdaExpression(source, predicate, preserveType: true);
if (predicate == null)
source = TranslateWhere(source, predicate);
if (source == null)
{
return null;
}
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.LongCountWithPredicate.MakeGenericMethod(
inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression,
predicate));
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

if (source.ShaperExpression is GroupByShaperExpression)
{
inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary<ProjectionMember, Expression>());
inMemoryQueryExpression.PushdownIntoSubquery();
}

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(
inMemoryQueryExpression.CurrentParameter.Type),
inMemoryQueryExpression.ServerQueryExpression));

return source.UpdateShaperExpression(inMemoryQueryExpression.GetSingleScalarProjection());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression

var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);

var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
Expand Down Expand Up @@ -318,6 +319,8 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
}
}

HandleGroupByForAggregate(selectExpression, eraseProjection: true);

var translation = _sqlTranslator.TranslateCount();
if (translation == null)
{
Expand Down Expand Up @@ -668,6 +671,8 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
}
}

HandleGroupByForAggregate(selectExpression, eraseProjection: true);

var translation = _sqlTranslator.TranslateLongCount();
if (translation == null)
{
Expand All @@ -688,6 +693,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour

var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);

var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
Expand All @@ -713,6 +719,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour

var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);

var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
Expand Down Expand Up @@ -1048,6 +1055,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour

var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);

var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
Expand Down Expand Up @@ -1479,6 +1487,24 @@ private static Expression AccessField(
string fieldName)
=> Expression.Field(targetExpression, transparentIdentifierType.GetTypeInfo().GetDeclaredField(fieldName));

private static void HandleGroupByForAggregate(SelectExpression selectExpression, bool eraseProjection = false)
{
if (selectExpression.GroupBy.Count > 0)
{
if (eraseProjection)
{
selectExpression.ReplaceProjectionMapping(new Dictionary<ProjectionMember, Expression>());
selectExpression.AddToProjection(selectExpression.GroupBy[0]);
selectExpression.PushdownIntoSubquery();
selectExpression.ClearProjection();
}
else
{
selectExpression.PushdownIntoSubquery();
}
}
}

private ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,7 @@ public void PrepareForAggregate()
{
if (IsDistinct
|| Limit != null
|| Offset != null
|| GroupBy.Count > 0)
|| Offset != null)
{
PushdownIntoSubquery();
}
Expand Down Expand Up @@ -1652,7 +1651,7 @@ private SqlBinaryExpression ValidateKeyComparison(
bool allowNonEquality)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal
|| (allowNonEquality &&
|| (allowNonEquality &&
(sqlBinaryExpression.OperatorType == ExpressionType.NotEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThan
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThanOrEqual
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2518,7 +2518,7 @@ public virtual Task Any_after_GroupBy_aggregate(bool async)
ss => ss.Set<Order>().GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)));
}

[ConditionalTheory(Skip = "Issue#15097")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Count_after_GroupBy_without_aggregate(bool async)
{
Expand All @@ -2527,7 +2527,17 @@ public virtual Task Count_after_GroupBy_without_aggregate(bool async)
ss => ss.Set<Order>().GroupBy(o => o.CustomerID));
}

[ConditionalTheory(Skip = "Issue#15097")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Count_with_predicate_after_GroupBy_without_aggregate(bool async)
{
return AssertCount(
async,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID),
g => g.Count() > 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task LongCount_after_GroupBy_without_aggregate(bool async)
{
Expand All @@ -2536,6 +2546,45 @@ public virtual Task LongCount_after_GroupBy_without_aggregate(bool async)
ss => ss.Set<Order>().GroupBy(o => o.CustomerID));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task LongCount_with_predicate_after_GroupBy_without_aggregate(bool async)
{
return AssertLongCount(
async,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID),
g => g.Count() > 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Any_after_GroupBy_without_aggregate(bool async)
{
return AssertAny(
async,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Any_with_predicate_after_GroupBy_without_aggregate(bool async)
{
return AssertAny(
async,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID),
g => g.Count() > 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task All_with_predicate_after_GroupBy_without_aggregate(bool async)
{
return AssertAll(
async,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID),
g => g.Count() > 1);
}

#endregion

# region GroupByInSubquery
Expand Down Expand Up @@ -2657,8 +2706,6 @@ public virtual Task GroupBy_scalar_subquery(bool async)
elementSorter: e => e.Key);
}



[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_scalar_aggregate_in_set_operation(bool async)
Expand Down
15 changes: 15 additions & 0 deletions test/EFCore.Specification.Tests/Query/QueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,27 @@ protected Task AssertLongCount<TResult>(
Func<ISetSource, IQueryable<TResult>> query)
=> AssertLongCount(async, query, query);

protected Task AssertLongCount<TResult>(
bool async,
Func<ISetSource, IQueryable<TResult>> query,
Expression<Func<TResult, bool>> predicate)
=> AssertLongCount(async, query, query, predicate, predicate);

protected Task AssertLongCount<TResult>(
bool async,
Func<ISetSource, IQueryable<TResult>> actualQuery,
Func<ISetSource, IQueryable<TResult>> expectedQuery)
=> QueryAsserter.AssertLongCount(actualQuery, expectedQuery, async);

protected Task AssertLongCount<TResult>(
bool async,
Func<ISetSource, IQueryable<TResult>> actualQuery,
Func<ISetSource, IQueryable<TResult>> expectedQuery,
Expression<Func<TResult, bool>> actualPredicate,
Expression<Func<TResult, bool>> expectedPredicate)
=> QueryAsserter.AssertLongCount(
actualQuery, expectedQuery, actualPredicate, expectedPredicate, async);

protected Task AssertMin<TResult>(
bool async,
Func<ISetSource, IQueryable<TResult>> query,
Expand Down
Loading