diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
index 7df03ed425a..41d096e459e 100644
--- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
+++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
@@ -268,6 +268,12 @@ protected override Expression VisitExtension(Expression extensionExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: null;
+ case InMemoryGroupByShaperExpression inMemoryGroupByShaperExpression:
+ return new GroupingElementExpression(
+ inMemoryGroupByShaperExpression.GroupingParameter,
+ inMemoryGroupByShaperExpression.ElementSelector,
+ inMemoryGroupByShaperExpression.ValueBufferParameter);
+
default:
return null;
}
@@ -392,79 +398,193 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// GroupBy Aggregate case
if (methodCallExpression.Object == null
&& methodCallExpression.Method.DeclaringType == typeof(Enumerable)
- && methodCallExpression.Arguments.Count > 0
- && methodCallExpression.Arguments[0] is InMemoryGroupByShaperExpression groupByShaperExpression)
+ && methodCallExpression.Arguments.Count > 0)
{
- var methodName = methodCallExpression.Method.Name;
- switch (methodName)
+ if (methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) == null
+ && Visit(methodCallExpression.Arguments[0]) is GroupingElementExpression groupingElementExpression)
{
- case nameof(Enumerable.Average):
- case nameof(Enumerable.Max):
- case nameof(Enumerable.Min):
- case nameof(Enumerable.Sum):
+ Expression result = null;
+ switch (methodCallExpression.Method.Name)
{
- var translation = TranslateInternal(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression));
- if (translation == null)
+ case nameof(Enumerable.Average):
{
- return null;
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ var selector = GetSelector(groupingElementExpression);
+
+ result = selector == null
+ ? null
+ : Expression.Call(
+ EnumerableMethods.GetAverageWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)),
+ groupingElementExpression.Source,
+ selector);
+ break;
}
- var selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter);
- var method2 = GetMethod();
- method2 = method2.GetGenericArguments().Length == 2
- ? method2.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
- : method2.MakeGenericMethod(typeof(ValueBuffer));
+ case nameof(Enumerable.Count):
+ {
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplyPredicate(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ if (groupingElementExpression == null)
+ {
+ result = null;
+ break;
+ }
+ }
+
+ result = Expression.Call(
+ EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
+ groupingElementExpression.Source);
+ break;
+ }
- return Expression.Call(
- method2,
- groupByShaperExpression.GroupingParameter,
- selector);
+ case nameof(Enumerable.LongCount):
+ {
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplyPredicate(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ if (groupingElementExpression == null)
+ {
+ result = null;
+ break;
+ }
+ }
+
+ result = Expression.Call(
+ EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
+ groupingElementExpression.Source);
+ break;
+ }
- MethodInfo GetMethod()
- => methodName switch
+ case nameof(Enumerable.Max):
+ {
+ if (methodCallExpression.Arguments.Count == 2)
{
- nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType),
- nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType),
- nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType),
- nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType),
- _ => throw new InvalidOperationException(InMemoryStrings.InvalidStateEncountered("Aggregate Operator")),
- };
- }
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
- case nameof(Enumerable.Count):
- case nameof(Enumerable.LongCount):
- {
- var countMethod = string.Equals(methodName, nameof(Enumerable.Count));
- var predicate = GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression);
- if (predicate == null)
+ var selector = GetSelector(groupingElementExpression);
+ if (selector != null)
+ {
+ var aggregateMethod = EnumerableMethods.GetMaxWithSelector(selector.ReturnType);
+ aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2
+ ? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
+ : aggregateMethod.MakeGenericMethod(typeof(ValueBuffer));
+
+
+ result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector);
+ }
+ else
+ {
+ result = null;
+ }
+
+ break;
+ }
+
+ case nameof(Enumerable.Min):
{
- return Expression.Call(
- (countMethod
- ? EnumerableMethods.CountWithoutPredicate
- : EnumerableMethods.LongCountWithoutPredicate)
- .MakeGenericMethod(typeof(ValueBuffer)),
- groupByShaperExpression.GroupingParameter);
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ var selector = GetSelector(groupingElementExpression);
+
+ if (selector != null)
+ {
+ var aggregateMethod = EnumerableMethods.GetMinWithSelector(selector.ReturnType);
+ aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2
+ ? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
+ : aggregateMethod.MakeGenericMethod(typeof(ValueBuffer));
+
+
+ result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector);
+ }
+ else
+ {
+ result = null;
+ }
+
+ break;
}
- var translation = TranslateInternal(predicate);
- if (translation == null)
+ case nameof(Enumerable.Select):
+ result = ApplySelector(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ break;
+
+ case nameof(Enumerable.Sum):
{
- return null;
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ var selector = GetSelector(groupingElementExpression);
+
+ result = selector == null
+ ? null
+ : Expression.Call(
+ EnumerableMethods.GetSumWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)),
+ groupingElementExpression.Source,
+ selector);
+ break;
}
- predicate = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter);
+ case nameof(Enumerable.Where):
+ result = ApplyPredicate(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ break;
- return Expression.Call(
- (countMethod
- ? EnumerableMethods.CountWithPredicate
- : EnumerableMethods.LongCountWithPredicate)
- .MakeGenericMethod(typeof(ValueBuffer)),
- groupByShaperExpression.GroupingParameter,
- predicate);
+ default:
+ result = null;
+ break;
}
- default:
- throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
+ return result ?? throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
+
+ GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
+ {
+ var predicate = TranslateInternal(RemapLambda(groupingElement, lambdaExpression));
+
+ return predicate == null
+ ? null
+ : groupingElement.UpdateSource(
+ Expression.Call(
+ EnumerableMethods.Where.MakeGenericMethod(typeof(ValueBuffer)),
+ groupingElement.Source,
+ Expression.Lambda(predicate, groupingElement.ValueBufferParameter)));
+ }
+
+ LambdaExpression GetSelector(GroupingElementExpression groupingElement)
+ {
+ var selector = TranslateInternal(groupingElement.Selector);
+ return selector == null
+ ? null
+ : Expression.Lambda(selector, groupingElement.ValueBufferParameter);
+ }
+
+ static GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
+ {
+ var selector = RemapLambda(groupingElement, lambdaExpression);
+
+ return groupingElement.ApplySelector(selector);
+ }
+
+ static Expression RemapLambda(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
+ => ReplacingExpressionVisitor.Replace(
+ lambdaExpression.Parameters[0], groupingElement.Selector, lambdaExpression.Body);
}
}
@@ -997,46 +1117,6 @@ private static Expression ConvertToNonNullable(Expression expression)
? Expression.Convert(expression, expression.Type.UnwrapNullableType())
: expression;
- private static Expression GetSelectorOnGrouping(
- MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
- {
- if (methodCallExpression.Arguments.Count == 1)
- {
- return groupByShaperExpression.ElementSelector;
- }
-
- if (methodCallExpression.Arguments.Count == 2)
- {
- var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
- return ReplacingExpressionVisitor.Replace(
- selectorLambda.Parameters[0],
- groupByShaperExpression.ElementSelector,
- selectorLambda.Body);
- }
-
- throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
- }
-
- private Expression GetPredicateOnGrouping(
- MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
- {
- if (methodCallExpression.Arguments.Count == 1)
- {
- return null;
- }
-
- if (methodCallExpression.Arguments.Count == 2)
- {
- var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
- return ReplacingExpressionVisitor.Replace(
- selectorLambda.Parameters[0],
- groupByShaperExpression.ElementSelector,
- selectorLambda.Body);
- }
-
- throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
- }
-
private IProperty FindProperty(Expression expression)
{
if (expression.NodeType == ExpressionType.Convert
@@ -1481,5 +1561,35 @@ public Expression Convert(Type type)
return derivedEntityType == null ? null : new EntityReferenceExpression(this, derivedEntityType);
}
}
+
+ private sealed class GroupingElementExpression : Expression
+ {
+ public GroupingElementExpression(Expression source, Expression selector, ParameterExpression valueBufferParameter)
+ {
+ Source = source;
+ ValueBufferParameter = valueBufferParameter;
+ Selector = selector;
+ }
+ public Expression Source { get; private set; }
+ public Expression Selector { get; private set; }
+ public ParameterExpression ValueBufferParameter { get; }
+
+ public GroupingElementExpression ApplySelector(Expression expression)
+ {
+ Selector = expression;
+
+ return this;
+ }
+
+ public GroupingElementExpression UpdateSource(Expression source)
+ {
+ Source = source;
+
+ return this;
+ }
+
+ public override Type Type => typeof(IEnumerable<>).MakeGenericType(Selector.Type);
+ public override ExpressionType NodeType => ExpressionType.Extension;
+ }
}
}
diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs
index 05330794b8e..8f60facc6cd 100644
--- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs
@@ -260,7 +260,13 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression
return null;
}
- var projection = _sqlTranslator.TranslateAverage(newSelector);
+ var translatedSelector = TranslateExpression(newSelector);
+ if (translatedSelector == null)
+ {
+ return null;
+ }
+
+ var projection = _sqlTranslator.TranslateAverage(translatedSelector);
return projection != null
? AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType)
: null;
@@ -336,7 +342,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
HandleGroupByForAggregate(selectExpression, eraseProjection: true);
- var translation = _sqlTranslator.TranslateCount();
+ var translation = _sqlTranslator.TranslateCount(_sqlExpressionFactory.Fragment("*"));
if (translation == null)
{
return null;
@@ -692,7 +698,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
HandleGroupByForAggregate(selectExpression, eraseProjection: true);
- var translation = _sqlTranslator.TranslateLongCount();
+ var translation = _sqlTranslator.TranslateLongCount(_sqlExpressionFactory.Fragment("*"));
if (translation == null)
{
return null;
@@ -730,7 +736,13 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour
return null;
}
- var projection = _sqlTranslator.TranslateMax(newSelector);
+ var translatedSelector = TranslateExpression(newSelector);
+ if (translatedSelector == null)
+ {
+ return null;
+ }
+
+ var projection = _sqlTranslator.TranslateMax(translatedSelector);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}
@@ -756,7 +768,13 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour
return null;
}
- var projection = _sqlTranslator.TranslateMin(newSelector);
+ var translatedSelector = TranslateExpression(newSelector);
+ if (translatedSelector == null)
+ {
+ return null;
+ }
+
+ var projection = _sqlTranslator.TranslateMin(translatedSelector);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}
@@ -1053,7 +1071,13 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour
return null;
}
- var projection = _sqlTranslator.TranslateSum(newSelector);
+ var translatedSelector = TranslateExpression(newSelector);
+ if (translatedSelector == null)
+ {
+ return null;
+ }
+
+ var projection = _sqlTranslator.TranslateSum(translatedSelector);
return projection != null
? AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType)
: null;
diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
index d9c8a5daa9f..81a47af521f 100644
--- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
@@ -146,24 +146,11 @@ private SqlExpression TranslateInternal(Expression expression)
///
/// Translates Average over an expression to an equivalent SQL representation.
///
- /// An expression to translate Average over.
+ /// An expression to translate Average over.
/// A SQL translation of Average over the given expression.
- public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
+ public virtual SqlExpression TranslateAverage([NotNull] SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
-
- if (!(expression is SqlExpression sqlExpression))
- {
- sqlExpression = TranslateInternal(expression);
- }
-
- if (sqlExpression == null)
- {
- throw new InvalidOperationException(
- TranslationErrorDetails == null
- ? CoreStrings.TranslationFailed(expression.Print())
- : CoreStrings.TranslationFailedWithDetails(expression.Print(), TranslationErrorDetails));
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
var inputType = sqlExpression.Type;
if (inputType == typeof(int)
@@ -195,20 +182,16 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
///
/// Translates Count over an expression to an equivalent SQL representation.
///
- /// An expression to translate Count over.
+ /// An expression to translate Count over.
/// A SQL translation of Count over the given expression.
- public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = null)
+ public virtual SqlExpression TranslateCount([NotNull] SqlExpression sqlExpression)
{
- if (expression != null)
- {
- // TODO: Translate Count with predicate for GroupBy
- return null;
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
- new[] { _sqlExpressionFactory.Fragment("*") },
+ new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));
@@ -217,20 +200,16 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression =
///
/// Translates LongCount over an expression to an equivalent SQL representation.
///
- /// An expression to translate LongCount over.
+ /// An expression to translate LongCount over.
/// A SQL translation of LongCount over the given expression.
- public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expression = null)
+ public virtual SqlExpression TranslateLongCount([NotNull] SqlExpression sqlExpression)
{
- if (expression != null)
- {
- // TODO: Translate Count with predicate for GroupBy
- return null;
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
- new[] { _sqlExpressionFactory.Fragment("*") },
+ new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
@@ -239,16 +218,11 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio
///
/// Translates Max over an expression to an equivalent SQL representation.
///
- /// An expression to translate Max over.
+ /// An expression to translate Max over.
/// A SQL translation of Max over the given expression.
- public virtual SqlExpression TranslateMax([NotNull] Expression expression)
+ public virtual SqlExpression TranslateMax([NotNull] SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
-
- if (!(expression is SqlExpression sqlExpression))
- {
- sqlExpression = TranslateInternal(expression);
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
return sqlExpression != null
? _sqlExpressionFactory.Function(
@@ -264,16 +238,11 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression)
///
/// Translates Min over an expression to an equivalent SQL representation.
///
- /// An expression to translate Min over.
+ /// An expression to translate Min over.
/// A SQL translation of Min over the given expression.
- public virtual SqlExpression TranslateMin([NotNull] Expression expression)
+ public virtual SqlExpression TranslateMin([NotNull] SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
-
- if (!(expression is SqlExpression sqlExpression))
- {
- sqlExpression = TranslateInternal(expression);
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
return sqlExpression != null
? _sqlExpressionFactory.Function(
@@ -289,24 +258,11 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression)
///
/// Translates Sum over an expression to an equivalent SQL representation.
///
- /// An expression to translate Sum over.
+ /// An expression to translate Sum over.
/// A SQL translation of Sum over the given expression.
- public virtual SqlExpression TranslateSum([NotNull] Expression expression)
+ public virtual SqlExpression TranslateSum([NotNull] SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
-
- if (!(expression is SqlExpression sqlExpression))
- {
- sqlExpression = TranslateInternal(expression);
- }
-
- if (sqlExpression == null)
- {
- throw new InvalidOperationException(
- TranslationErrorDetails == null
- ? CoreStrings.TranslationFailed(expression.Print())
- : CoreStrings.TranslationFailedWithDetails(expression.Print(), TranslationErrorDetails));
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
var inputType = sqlExpression.Type;
@@ -448,6 +404,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: null;
+ case GroupByShaperExpression groupByShaperExpression:
+ return new GroupingElementExpression(groupByShaperExpression.ElementSelector);
+
default:
return null;
}
@@ -498,29 +457,160 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// GroupBy Aggregate case
if (methodCallExpression.Object == null
&& methodCallExpression.Method.DeclaringType == typeof(Enumerable)
- && methodCallExpression.Arguments.Count > 0
- && methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression)
+ && methodCallExpression.Arguments.Count > 0)
{
- var translatedAggregate = methodCallExpression.Method.Name switch
- {
- nameof(Enumerable.Average) => TranslateAverage(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
- nameof(Enumerable.Count) => TranslateCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)),
- nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)),
- nameof(Enumerable.Max) => TranslateMax(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
- nameof(Enumerable.Min) => TranslateMin(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
- nameof(Enumerable.Sum) => TranslateSum(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
- _ => null
- };
-
- if (translatedAggregate == null)
+ if (methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) == null
+ && Visit(methodCallExpression.Arguments[0]) is GroupingElementExpression groupingElementExpression)
{
- throw new InvalidOperationException(
- TranslationErrorDetails == null
- ? CoreStrings.TranslationFailed(methodCallExpression.Print())
- : CoreStrings.TranslationFailedWithDetails(methodCallExpression.Print(), TranslationErrorDetails));
- }
+ Expression result;
+ switch (methodCallExpression.Method.Name)
+ {
+ case nameof(Enumerable.Average):
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression averageExpression
+ ? TranslateAverage(averageExpression)
+ : null;
+ break;
+
+ case nameof(Enumerable.Count):
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplyPredicate(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ if (groupingElementExpression == null)
+ {
+ result = null;
+ break;
+ }
+ }
+
+ result = TranslateCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true));
+ break;
+
+ case nameof(Enumerable.LongCount):
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplyPredicate(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ if (groupingElementExpression == null)
+ {
+ result = null;
+ break;
+ }
+ }
+
+ result = TranslateLongCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true));
+ break;
+
+ case nameof(Enumerable.Max):
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression maxExpression
+ ? TranslateMax(maxExpression)
+ : null;
+ break;
+
+ case nameof(Enumerable.Min):
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression minExpression
+ ? TranslateMin(minExpression)
+ : null;
+ break;
+
+ case nameof(Enumerable.Select):
+ result = ApplySelector(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ break;
+
+ case nameof(Enumerable.Sum):
+ if (methodCallExpression.Arguments.Count == 2)
+ {
+ groupingElementExpression = ApplySelector(
+ groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ }
+
+ result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression sumExpression
+ ? TranslateSum(sumExpression)
+ : null;
+ break;
+
+ case nameof(Enumerable.Where):
+ result = ApplyPredicate(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ break;
+
+ default:
+ result = null;
+ break;
+ }
+
+ return result ?? throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
+
+ GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
+ {
+ var predicate = TranslateInternal(RemapLambda(groupingElement, lambdaExpression));
+
+ return predicate == null
+ ? null
+ : groupingElement.ApplyPredicate(predicate);
+ }
+
+ static GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
+ {
+ var selector = RemapLambda(groupingElement, lambdaExpression);
+
+ return groupingElement.ApplySelector(selector);
+ }
+
+ static Expression RemapLambda(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
+ => ReplacingExpressionVisitor.Replace(
+ lambdaExpression.Parameters[0], groupingElement.Element, lambdaExpression.Body);
+
+ SqlExpression GetExpressionForAggregation(GroupingElementExpression groupingElement, bool starProjection = false)
+ {
+ var selector = TranslateInternal(groupingElement.Element);
+ if (selector == null)
+ {
+ if (starProjection)
+ {
+ selector = _sqlExpressionFactory.Fragment("*");
+ }
+ else
+ {
+ return null;
+ }
+ }
+
+ if (groupingElement.Predicate != null)
+ {
+ if (selector is SqlFragmentExpression)
+ {
+ selector = _sqlExpressionFactory.Constant(1);
+ }
- return translatedAggregate;
+ return _sqlExpressionFactory.Case(
+ new List
+ {
+ new CaseWhenClause(groupingElement.Predicate, selector)
+ },
+ elseResult: null);
+ }
+
+ return selector;
+ }
+ }
}
// Subquery case
@@ -990,46 +1080,6 @@ private SqlExpression BindProperty(EntityReferenceExpression entityReferenceExpr
return null;
}
- private static Expression GetSelectorOnGrouping(
- MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
- {
- if (methodCallExpression.Arguments.Count == 1)
- {
- return groupByShaperExpression.ElementSelector;
- }
-
- if (methodCallExpression.Arguments.Count == 2)
- {
- var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
- return ReplacingExpressionVisitor.Replace(
- selectorLambda.Parameters[0],
- groupByShaperExpression.ElementSelector,
- selectorLambda.Body);
- }
-
- throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
- }
-
- private static Expression GetPredicateOnGrouping(
- MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
- {
- if (methodCallExpression.Arguments.Count == 1)
- {
- return null;
- }
-
- if (methodCallExpression.Arguments.Count == 2)
- {
- var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
- return ReplacingExpressionVisitor.Replace(
- selectorLambda.Parameters[0],
- groupByShaperExpression.ElementSelector,
- selectorLambda.Body);
- }
-
- throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
- }
-
private static Expression TryRemoveImplicitConvert(Expression expression)
{
if (expression is UnaryExpression unaryExpression
@@ -1412,6 +1462,56 @@ public Expression Convert(Type type)
}
}
+ private sealed class GroupingElementExpression : Expression
+ {
+ public GroupingElementExpression(Expression element)
+ {
+ Element = element;
+ }
+ public Expression Element { get; private set; }
+ public bool IsDistinct { get; private set; }
+ public SqlExpression Predicate { get; private set; }
+
+ public GroupingElementExpression ApplyDistinct()
+ {
+ IsDistinct = true;
+
+ return this;
+ }
+
+ public GroupingElementExpression ApplySelector(Expression expression)
+ {
+ Element = expression;
+
+ return this;
+ }
+
+ public GroupingElementExpression ApplyPredicate(SqlExpression expression)
+ {
+ Check.NotNull(expression, nameof(expression));
+
+ if (expression is SqlConstantExpression sqlConstant
+ && sqlConstant.Value is bool boolValue
+ && boolValue)
+ {
+ return this;
+ }
+
+ Predicate = Predicate == null
+ ? expression
+ : new SqlBinaryExpression(
+ ExpressionType.AndAlso,
+ Predicate,
+ expression,
+ typeof(bool),
+ expression.TypeMapping);
+
+ return this;
+ }
+ public override Type Type => typeof(IEnumerable<>).MakeGenericType(Element.Type);
+ public override ExpressionType NodeType => ExpressionType.Extension;
+ }
+
private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression extensionExpression)
diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs
index 643eadc10ac..fdd5f952b00 100644
--- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs
@@ -110,18 +110,14 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- public override SqlExpression TranslateLongCount(Expression expression = null)
+ public override SqlExpression TranslateLongCount(SqlExpression sqlExpression)
{
- if (expression != null)
- {
- // TODO: Translate Count with predicate for GroupBy
- return null;
- }
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
return Dependencies.SqlExpressionFactory.ApplyDefaultTypeMapping(
Dependencies.SqlExpressionFactory.Function(
"COUNT_BIG",
- new[] { Dependencies.SqlExpressionFactory.Fragment("*") },
+ new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs
index 662b3d158c8..4b72a4f0f74 100644
--- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs
@@ -196,11 +196,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- public override SqlExpression TranslateAverage(Expression expression)
+ public override SqlExpression TranslateAverage(SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
- var visitedExpression = base.TranslateAverage(expression);
+ var visitedExpression = base.TranslateAverage(sqlExpression);
var argumentType = GetProviderType(visitedExpression);
if (argumentType == typeof(decimal))
{
@@ -217,11 +217,11 @@ public override SqlExpression TranslateAverage(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- public override SqlExpression TranslateMax(Expression expression)
+ public override SqlExpression TranslateMax(SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
- var visitedExpression = base.TranslateMax(expression);
+ var visitedExpression = base.TranslateMax(sqlExpression);
var argumentType = GetProviderType(visitedExpression);
if (argumentType == typeof(DateTimeOffset)
|| argumentType == typeof(decimal)
@@ -247,11 +247,11 @@ public override SqlExpression TranslateMax(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- public override SqlExpression TranslateMin(Expression expression)
+ public override SqlExpression TranslateMin(SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
- var visitedExpression = base.TranslateMin(expression);
+ var visitedExpression = base.TranslateMin(sqlExpression);
var argumentType = GetProviderType(visitedExpression);
if (argumentType == typeof(DateTimeOffset)
|| argumentType == typeof(decimal)
@@ -271,11 +271,11 @@ public override SqlExpression TranslateMin(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
- public override SqlExpression TranslateSum(Expression expression)
+ public override SqlExpression TranslateSum(SqlExpression sqlExpression)
{
- Check.NotNull(expression, nameof(expression));
+ Check.NotNull(sqlExpression, nameof(sqlExpression));
- var visitedExpression = base.TranslateSum(expression);
+ var visitedExpression = base.TranslateSum(sqlExpression);
var argumentType = GetProviderType(visitedExpression);
if (argumentType == typeof(decimal))
{
diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
index 386b256e7bc..4501fcfa7ca 100644
--- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
+++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
@@ -221,7 +221,9 @@ public ConstantVerifyingExpressionVisitor(ITypeMappingSource typeMappingSource)
private bool ValidConstant(ConstantExpression constantExpression)
{
return constantExpression.Value == null
- || _typeMappingSource.FindMapping(constantExpression.Type) != null;
+ || _typeMappingSource.FindMapping(constantExpression.Type) != null
+ || constantExpression.Value is Array array
+ && array.Length == 0;
}
protected override Expression VisitConstant(ConstantExpression constantExpression)
diff --git a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs
index 10ede25a272..04b7085109e 100644
--- a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs
+++ b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs
@@ -4,6 +4,7 @@
using System;
using System.Linq;
using System.Threading.Tasks;
+using Castle.Components.DictionaryAdapter;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
@@ -37,7 +38,7 @@ public virtual Task GroupBy_Property_Select_Average(bool async)
[ConditionalTheory(Skip = "issue #18923")]
[MemberData(nameof(IsAsyncData))]
- public virtual Task GroupBy_Property_Select_Average_with_navigation_expansion(bool async)
+ public virtual Task GroupBy_Property_Select_Average_with_group_enumerable_projected(bool async)
{
return AssertQueryScalar(
async,
@@ -2007,7 +2008,7 @@ public virtual Task Distinct_GroupBy_OrderBy_key(bool async)
assertOrder: true);
}
- [ConditionalTheory(Skip = "Issue #18923")]
+ [ConditionalTheory(Skip = "Issue #15873")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_nested_collection_with_groupby(bool async)
{
@@ -2020,7 +2021,7 @@ public virtual Task Select_nested_collection_with_groupby(bool async)
: Array.Empty()));
}
- [ConditionalTheory(Skip = "Issue #18923")]
+ [ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_GroupBy_All(bool async)
{
@@ -2050,9 +2051,21 @@ public override bool Equals(object obj)
public override int GetHashCode() => Order.GetHashCode();
}
- [ConditionalTheory(Skip = "Issue #18836")]
+ [ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
- public virtual Task GroupBy_Where_in_aggregate(bool async)
+ public virtual Task GroupBy_Where_Average(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Select(e => (int?)e.OrderID).Average());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Count(bool async)
{
return AssertQueryScalar(
async,
@@ -2062,6 +2075,160 @@ into g
select g.Where(e => e.OrderID < 10300).Count());
}
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_LongCount(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).LongCount());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Max(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Select(e => (int?)e.OrderID).Max());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Min(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Select(e => (int?)e.OrderID).Min());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Sum(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Select(e => e.OrderID).Sum());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Count_with_predicate(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Count(e => e.OrderDate.HasValue && e.OrderDate.Value.Year == 1997));
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Where_Count(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Where(e => e.OrderDate.HasValue && e.OrderDate.Value.Year == 1997).Count());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Select_Where_Count(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300).Select(e => e.OrderDate).Where(e => e.HasValue && e.Value.Year == 1997).Count());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_Where_Select_Where_Select_Min(bool async)
+ {
+ return AssertQueryScalar(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select g.Where(e => e.OrderID < 10300)
+ .Select(e => new { e.OrderID, e.OrderDate })
+ .Where(e => e.OrderDate.HasValue && e.OrderDate.Value.Year == 1997)
+ .Select(e => (int?)e.OrderID).Min());
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_multiple_Count_with_predicate(bool async)
+ {
+ return AssertQuery(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select new
+ {
+ g.Key,
+ All = g.Count(),
+ TenK = g.Count(e => e.OrderID < 11000),
+ EleventK = g.Count(e => e.OrderID < 12000)
+ },
+ elementSorter: e => e.Key.CustomerID);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_multiple_Sum_with_conditional_projection(bool async)
+ {
+ return AssertQuery(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select new
+ {
+ g.Key,
+ TenK = g.Sum(e => e.OrderID < 11000 ? e.OrderID : 0),
+ EleventK = g.Sum(e => e.OrderID >= 11000 ? e.OrderID : 0)
+ },
+ elementSorter: e => e.Key.CustomerID);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task GroupBy_multiple_Sum_with_Select_conditional_projection(bool async)
+ {
+ return AssertQuery(
+ async,
+ ss => from o in ss.Set()
+ group o by new { o.CustomerID }
+ into g
+ select new
+ {
+ g.Key,
+ TenK = g.Select(e => e.OrderID < 11000 ? e.OrderID : 0).Sum(),
+ EleventK = g.Select(e => e.OrderID >= 11000 ? e.OrderID : 0).Sum()
+ },
+ elementSorter: e => e.Key.CustomerID);
+ }
+
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Key_as_part_of_element_selector(bool async)
@@ -2460,9 +2627,9 @@ public virtual Task Count_after_GroupBy_aggregate(bool async)
ss => ss.Set().GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync(default));
}
- [ConditionalTheory(Skip = "Issue #18836")]
+ [ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
- public virtual Task LongCount_after_client_GroupBy(bool async)
+ public virtual Task LongCount_after_GroupBy_aggregate(bool async)
{
return AssertSingleResult(
async,
@@ -2662,7 +2829,7 @@ public virtual Task Complex_query_with_groupBy_in_subquery3(bool async)
}
// also 15279
- [ConditionalTheory(Skip = "issue #11711")]
+ [ConditionalTheory(Skip = "issue #15873")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Complex_query_with_groupBy_in_subquery4(bool async)
{
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs
index 5fb1ad6505b..cfabc6d61a2 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs
@@ -37,9 +37,9 @@ FROM [Orders] AS [o]
Fixture.TestSqlLoggerFactory.Log.Select(l => l.Message));
}
- public override async Task GroupBy_Property_Select_Average_with_navigation_expansion(bool async)
+ public override async Task GroupBy_Property_Select_Average_with_group_enumerable_projected(bool async)
{
- await base.GroupBy_Property_Select_Average_with_navigation_expansion(async);
+ await base.GroupBy_Property_Select_Average_with_group_enumerable_projected(async);
AssertSql(
@"");
@@ -644,7 +644,7 @@ public override async Task GroupBy_Property_scalar_element_selector_Count(bool a
await base.GroupBy_Property_scalar_element_selector_Count(async);
AssertSql(
- @"SELECT COUNT(*)
+ @"SELECT COUNT([o].[OrderID])
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}
@@ -654,7 +654,7 @@ public override async Task GroupBy_Property_scalar_element_selector_LongCount(bo
await base.GroupBy_Property_scalar_element_selector_LongCount(async);
AssertSql(
- @"SELECT COUNT_BIG(*)
+ @"SELECT COUNT_BIG([o].[OrderID])
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}
@@ -1513,19 +1513,181 @@ public override async Task Select_GroupBy_All(bool async)
await base.Select_GroupBy_All(async);
AssertSql(
- @"SELECT [o].[OrderID] AS [Order], [o].[CustomerID] AS [Customer]
+ @"SELECT CASE
+ WHEN NOT EXISTS (
+ SELECT 1
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID]
+ HAVING ([o].[CustomerID] <> N'ALFKI') OR [o].[CustomerID] IS NULL) THEN CAST(1 AS bit)
+ ELSE CAST(0 AS bit)
+END");
+ }
+
+ public override async Task GroupBy_Where_Average(bool async)
+ {
+ await base.GroupBy_Where_Average(async);
+
+ AssertSql(
+ @"SELECT AVG(CAST(CASE
+ WHEN [o].[OrderID] < 10300 THEN [o].[OrderID]
+END AS float))
FROM [Orders] AS [o]
-ORDER BY [o].[CustomerID]");
+GROUP BY [o].[CustomerID]");
}
- public override async Task GroupBy_Where_in_aggregate(bool async)
+
+ public override async Task GroupBy_Where_Count(bool async)
{
- await base.GroupBy_Where_in_aggregate(async);
+ await base.GroupBy_Where_Count(async);
AssertSql(
- @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+ @"SELECT COUNT(CASE
+ WHEN [o].[OrderID] < 10300 THEN 1
+END)
FROM [Orders] AS [o]
-ORDER BY [o].[CustomerID]");
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_LongCount(bool async)
+ {
+ await base.GroupBy_Where_LongCount(async);
+
+ AssertSql(
+ @"SELECT COUNT_BIG(CASE
+ WHEN [o].[OrderID] < 10300 THEN 1
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Max(bool async)
+ {
+ await base.GroupBy_Where_Max(async);
+
+ AssertSql(
+ @"SELECT MAX(CASE
+ WHEN [o].[OrderID] < 10300 THEN [o].[OrderID]
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Min(bool async)
+ {
+ await base.GroupBy_Where_Min(async);
+
+ AssertSql(
+ @"SELECT MIN(CASE
+ WHEN [o].[OrderID] < 10300 THEN [o].[OrderID]
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Sum(bool async)
+ {
+ await base.GroupBy_Where_Sum(async);
+
+ AssertSql(
+ @"SELECT COALESCE(SUM(CASE
+ WHEN [o].[OrderID] < 10300 THEN [o].[OrderID]
+END), 0)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Count_with_predicate(bool async)
+ {
+ await base.GroupBy_Where_Count_with_predicate(async);
+
+ AssertSql(
+ @"SELECT COUNT(CASE
+ WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN 1
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Where_Count(bool async)
+ {
+ await base.GroupBy_Where_Where_Count(async);
+
+ AssertSql(
+ @"SELECT COUNT(CASE
+ WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN 1
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Select_Where_Count(bool async)
+ {
+ await base.GroupBy_Where_Select_Where_Count(async);
+
+ AssertSql(
+ @"SELECT COUNT(CASE
+ WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN [o].[OrderDate]
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_Where_Select_Where_Select_Min(bool async)
+ {
+ await base.GroupBy_Where_Select_Where_Select_Min(async);
+
+ AssertSql(
+ @"SELECT MIN(CASE
+ WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN [o].[OrderID]
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_multiple_Count_with_predicate(bool async)
+ {
+ await base.GroupBy_multiple_Count_with_predicate(async);
+
+ AssertSql(
+ @"SELECT [o].[CustomerID], COUNT(*) AS [All], COUNT(CASE
+ WHEN [o].[OrderID] < 11000 THEN 1
+END) AS [TenK], COUNT(CASE
+ WHEN [o].[OrderID] < 12000 THEN 1
+END) AS [EleventK]
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_multiple_Sum_with_conditional_projection(bool async)
+ {
+ await base.GroupBy_multiple_Sum_with_conditional_projection(async);
+
+ AssertSql(
+ @"SELECT [o].[CustomerID], COALESCE(SUM(CASE
+ WHEN [o].[OrderID] < 11000 THEN [o].[OrderID]
+ ELSE 0
+END), 0) AS [TenK], COALESCE(SUM(CASE
+ WHEN [o].[OrderID] >= 11000 THEN [o].[OrderID]
+ ELSE 0
+END), 0) AS [EleventK]
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+
+ public override async Task GroupBy_multiple_Sum_with_Select_conditional_projection(bool async)
+ {
+ await base.GroupBy_multiple_Sum_with_Select_conditional_projection(async);
+
+ AssertSql(
+ @"SELECT [o].[CustomerID], COALESCE(SUM(CASE
+ WHEN [o].[OrderID] < 11000 THEN [o].[OrderID]
+ ELSE 0
+END), 0) AS [TenK], COALESCE(SUM(CASE
+ WHEN [o].[OrderID] >= 11000 THEN [o].[OrderID]
+ ELSE 0
+END), 0) AS [EleventK]
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
}
public override async Task GroupBy_Key_as_part_of_element_selector(bool async)
@@ -1668,14 +1830,17 @@ GROUP BY [o].[CustomerID]
) AS [t]");
}
- public override async Task LongCount_after_client_GroupBy(bool async)
+ public override async Task LongCount_after_GroupBy_aggregate(bool async)
{
- await base.LongCount_after_client_GroupBy(async);
+ await base.LongCount_after_GroupBy_aggregate(async);
AssertSql(
- @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
-FROM [Orders] AS [o]
-ORDER BY [o].[CustomerID]");
+ @"SELECT COUNT_BIG(*)
+FROM (
+ SELECT [o].[CustomerID]
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID]
+) AS [t]");
}
public override async Task MinMax_after_GroupBy_aggregate(bool async)
@@ -1947,16 +2112,28 @@ OFFSET @__p_0 ROWS
GROUP BY [t].[CustomerID]");
}
- public override Task GroupBy_Property_Select_Count_with_predicate(bool async)
+ public override async Task GroupBy_Property_Select_Count_with_predicate(bool async)
{
- return Assert.ThrowsAsync(
- () => base.GroupBy_Property_Select_Count_with_predicate(async));
+ await base.GroupBy_Property_Select_Count_with_predicate(async);
+
+ AssertSql(
+ @"SELECT COUNT(CASE
+ WHEN [o].[OrderID] < 10300 THEN 1
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
}
- public override Task GroupBy_Property_Select_LongCount_with_predicate(bool async)
+ public override async Task GroupBy_Property_Select_LongCount_with_predicate(bool async)
{
- return Assert.ThrowsAsync(
- () => base.GroupBy_Property_Select_LongCount_with_predicate(async));
+ await base.GroupBy_Property_Select_LongCount_with_predicate(async);
+
+ AssertSql(
+ @"SELECT COUNT_BIG(CASE
+ WHEN [o].[OrderID] < 10300 THEN 1
+END)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
}
public override async Task GroupBy_orderby_projection_with_coalesce_operation(bool async)
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs
index fbe2b67c34d..cdbba5b44a8 100644
--- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs
@@ -18,17 +18,5 @@ public NorthwindGroupByQuerySqliteTest(NorthwindQuerySqliteFixture(
- () => base.GroupBy_Property_Select_Count_with_predicate(async));
- }
-
- public override Task GroupBy_Property_Select_LongCount_with_predicate(bool async)
- {
- return Assert.ThrowsAsync(
- () => base.GroupBy_Property_Select_LongCount_with_predicate(async));
- }
}
}