diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 7df03ed425a..41d096e459e 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -268,6 +268,12 @@ protected override Expression VisitExtension(Expression extensionExpression) .GetMappedProjection(projectionBindingExpression.ProjectionMember) : null; + case InMemoryGroupByShaperExpression inMemoryGroupByShaperExpression: + return new GroupingElementExpression( + inMemoryGroupByShaperExpression.GroupingParameter, + inMemoryGroupByShaperExpression.ElementSelector, + inMemoryGroupByShaperExpression.ValueBufferParameter); + default: return null; } @@ -392,79 +398,193 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // GroupBy Aggregate case if (methodCallExpression.Object == null && methodCallExpression.Method.DeclaringType == typeof(Enumerable) - && methodCallExpression.Arguments.Count > 0 - && methodCallExpression.Arguments[0] is InMemoryGroupByShaperExpression groupByShaperExpression) + && methodCallExpression.Arguments.Count > 0) { - var methodName = methodCallExpression.Method.Name; - switch (methodName) + if (methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) == null + && Visit(methodCallExpression.Arguments[0]) is GroupingElementExpression groupingElementExpression) { - case nameof(Enumerable.Average): - case nameof(Enumerable.Max): - case nameof(Enumerable.Min): - case nameof(Enumerable.Sum): + Expression result = null; + switch (methodCallExpression.Method.Name) { - var translation = TranslateInternal(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)); - if (translation == null) + case nameof(Enumerable.Average): { - return null; + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + var selector = GetSelector(groupingElementExpression); + + result = selector == null + ? null + : Expression.Call( + EnumerableMethods.GetAverageWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)), + groupingElementExpression.Source, + selector); + break; } - var selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter); - var method2 = GetMethod(); - method2 = method2.GetGenericArguments().Length == 2 - ? method2.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) - : method2.MakeGenericMethod(typeof(ValueBuffer)); + case nameof(Enumerable.Count): + { + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplyPredicate( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + if (groupingElementExpression == null) + { + result = null; + break; + } + } + + result = Expression.Call( + EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)), + groupingElementExpression.Source); + break; + } - return Expression.Call( - method2, - groupByShaperExpression.GroupingParameter, - selector); + case nameof(Enumerable.LongCount): + { + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplyPredicate( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + if (groupingElementExpression == null) + { + result = null; + break; + } + } + + result = Expression.Call( + EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)), + groupingElementExpression.Source); + break; + } - MethodInfo GetMethod() - => methodName switch + case nameof(Enumerable.Max): + { + if (methodCallExpression.Arguments.Count == 2) { - nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType), - nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType), - nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType), - nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType), - _ => throw new InvalidOperationException(InMemoryStrings.InvalidStateEncountered("Aggregate Operator")), - }; - } + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } - case nameof(Enumerable.Count): - case nameof(Enumerable.LongCount): - { - var countMethod = string.Equals(methodName, nameof(Enumerable.Count)); - var predicate = GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression); - if (predicate == null) + var selector = GetSelector(groupingElementExpression); + if (selector != null) + { + var aggregateMethod = EnumerableMethods.GetMaxWithSelector(selector.ReturnType); + aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2 + ? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) + : aggregateMethod.MakeGenericMethod(typeof(ValueBuffer)); + + + result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector); + } + else + { + result = null; + } + + break; + } + + case nameof(Enumerable.Min): { - return Expression.Call( - (countMethod - ? EnumerableMethods.CountWithoutPredicate - : EnumerableMethods.LongCountWithoutPredicate) - .MakeGenericMethod(typeof(ValueBuffer)), - groupByShaperExpression.GroupingParameter); + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + var selector = GetSelector(groupingElementExpression); + + if (selector != null) + { + var aggregateMethod = EnumerableMethods.GetMinWithSelector(selector.ReturnType); + aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2 + ? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) + : aggregateMethod.MakeGenericMethod(typeof(ValueBuffer)); + + + result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector); + } + else + { + result = null; + } + + break; } - var translation = TranslateInternal(predicate); - if (translation == null) + case nameof(Enumerable.Select): + result = ApplySelector(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + break; + + case nameof(Enumerable.Sum): { - return null; + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + var selector = GetSelector(groupingElementExpression); + + result = selector == null + ? null + : Expression.Call( + EnumerableMethods.GetSumWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)), + groupingElementExpression.Source, + selector); + break; } - predicate = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter); + case nameof(Enumerable.Where): + result = ApplyPredicate(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + break; - return Expression.Call( - (countMethod - ? EnumerableMethods.CountWithPredicate - : EnumerableMethods.LongCountWithPredicate) - .MakeGenericMethod(typeof(ValueBuffer)), - groupByShaperExpression.GroupingParameter, - predicate); + default: + result = null; + break; } - default: - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + return result ?? throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + + GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression) + { + var predicate = TranslateInternal(RemapLambda(groupingElement, lambdaExpression)); + + return predicate == null + ? null + : groupingElement.UpdateSource( + Expression.Call( + EnumerableMethods.Where.MakeGenericMethod(typeof(ValueBuffer)), + groupingElement.Source, + Expression.Lambda(predicate, groupingElement.ValueBufferParameter))); + } + + LambdaExpression GetSelector(GroupingElementExpression groupingElement) + { + var selector = TranslateInternal(groupingElement.Selector); + return selector == null + ? null + : Expression.Lambda(selector, groupingElement.ValueBufferParameter); + } + + static GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression) + { + var selector = RemapLambda(groupingElement, lambdaExpression); + + return groupingElement.ApplySelector(selector); + } + + static Expression RemapLambda(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression) + => ReplacingExpressionVisitor.Replace( + lambdaExpression.Parameters[0], groupingElement.Selector, lambdaExpression.Body); } } @@ -997,46 +1117,6 @@ private static Expression ConvertToNonNullable(Expression expression) ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) : expression; - private static Expression GetSelectorOnGrouping( - MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return groupByShaperExpression.ElementSelector; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - - private Expression GetPredicateOnGrouping( - MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return null; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - private IProperty FindProperty(Expression expression) { if (expression.NodeType == ExpressionType.Convert @@ -1481,5 +1561,35 @@ public Expression Convert(Type type) return derivedEntityType == null ? null : new EntityReferenceExpression(this, derivedEntityType); } } + + private sealed class GroupingElementExpression : Expression + { + public GroupingElementExpression(Expression source, Expression selector, ParameterExpression valueBufferParameter) + { + Source = source; + ValueBufferParameter = valueBufferParameter; + Selector = selector; + } + public Expression Source { get; private set; } + public Expression Selector { get; private set; } + public ParameterExpression ValueBufferParameter { get; } + + public GroupingElementExpression ApplySelector(Expression expression) + { + Selector = expression; + + return this; + } + + public GroupingElementExpression UpdateSource(Expression source) + { + Source = source; + + return this; + } + + public override Type Type => typeof(IEnumerable<>).MakeGenericType(Selector.Type); + public override ExpressionType NodeType => ExpressionType.Extension; + } } } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 05330794b8e..8f60facc6cd 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -260,7 +260,13 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression return null; } - var projection = _sqlTranslator.TranslateAverage(newSelector); + var translatedSelector = TranslateExpression(newSelector); + if (translatedSelector == null) + { + return null; + } + + var projection = _sqlTranslator.TranslateAverage(translatedSelector); return projection != null ? AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType) : null; @@ -336,7 +342,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so HandleGroupByForAggregate(selectExpression, eraseProjection: true); - var translation = _sqlTranslator.TranslateCount(); + var translation = _sqlTranslator.TranslateCount(_sqlExpressionFactory.Fragment("*")); if (translation == null) { return null; @@ -692,7 +698,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio HandleGroupByForAggregate(selectExpression, eraseProjection: true); - var translation = _sqlTranslator.TranslateLongCount(); + var translation = _sqlTranslator.TranslateLongCount(_sqlExpressionFactory.Fragment("*")); if (translation == null) { return null; @@ -730,7 +736,13 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour return null; } - var projection = _sqlTranslator.TranslateMax(newSelector); + var translatedSelector = TranslateExpression(newSelector); + if (translatedSelector == null) + { + return null; + } + + var projection = _sqlTranslator.TranslateMax(translatedSelector); return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } @@ -756,7 +768,13 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour return null; } - var projection = _sqlTranslator.TranslateMin(newSelector); + var translatedSelector = TranslateExpression(newSelector); + if (translatedSelector == null) + { + return null; + } + + var projection = _sqlTranslator.TranslateMin(translatedSelector); return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } @@ -1053,7 +1071,13 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour return null; } - var projection = _sqlTranslator.TranslateSum(newSelector); + var translatedSelector = TranslateExpression(newSelector); + if (translatedSelector == null) + { + return null; + } + + var projection = _sqlTranslator.TranslateSum(translatedSelector); return projection != null ? AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType) : null; diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index d9c8a5daa9f..81a47af521f 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -146,24 +146,11 @@ private SqlExpression TranslateInternal(Expression expression) /// /// Translates Average over an expression to an equivalent SQL representation. /// - /// An expression to translate Average over. + /// An expression to translate Average over. /// A SQL translation of Average over the given expression. - public virtual SqlExpression TranslateAverage([NotNull] Expression expression) + public virtual SqlExpression TranslateAverage([NotNull] SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); - - if (!(expression is SqlExpression sqlExpression)) - { - sqlExpression = TranslateInternal(expression); - } - - if (sqlExpression == null) - { - throw new InvalidOperationException( - TranslationErrorDetails == null - ? CoreStrings.TranslationFailed(expression.Print()) - : CoreStrings.TranslationFailedWithDetails(expression.Print(), TranslationErrorDetails)); - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); var inputType = sqlExpression.Type; if (inputType == typeof(int) @@ -195,20 +182,16 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression) /// /// Translates Count over an expression to an equivalent SQL representation. /// - /// An expression to translate Count over. + /// An expression to translate Count over. /// A SQL translation of Count over the given expression. - public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = null) + public virtual SqlExpression TranslateCount([NotNull] SqlExpression sqlExpression) { - if (expression != null) - { - // TODO: Translate Count with predicate for GroupBy - return null; - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); return _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( "COUNT", - new[] { _sqlExpressionFactory.Fragment("*") }, + new[] { sqlExpression }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(int))); @@ -217,20 +200,16 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = /// /// Translates LongCount over an expression to an equivalent SQL representation. /// - /// An expression to translate LongCount over. + /// An expression to translate LongCount over. /// A SQL translation of LongCount over the given expression. - public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expression = null) + public virtual SqlExpression TranslateLongCount([NotNull] SqlExpression sqlExpression) { - if (expression != null) - { - // TODO: Translate Count with predicate for GroupBy - return null; - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); return _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( "COUNT", - new[] { _sqlExpressionFactory.Fragment("*") }, + new[] { sqlExpression }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(long))); @@ -239,16 +218,11 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio /// /// Translates Max over an expression to an equivalent SQL representation. /// - /// An expression to translate Max over. + /// An expression to translate Max over. /// A SQL translation of Max over the given expression. - public virtual SqlExpression TranslateMax([NotNull] Expression expression) + public virtual SqlExpression TranslateMax([NotNull] SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); - - if (!(expression is SqlExpression sqlExpression)) - { - sqlExpression = TranslateInternal(expression); - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); return sqlExpression != null ? _sqlExpressionFactory.Function( @@ -264,16 +238,11 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression) /// /// Translates Min over an expression to an equivalent SQL representation. /// - /// An expression to translate Min over. + /// An expression to translate Min over. /// A SQL translation of Min over the given expression. - public virtual SqlExpression TranslateMin([NotNull] Expression expression) + public virtual SqlExpression TranslateMin([NotNull] SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); - - if (!(expression is SqlExpression sqlExpression)) - { - sqlExpression = TranslateInternal(expression); - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); return sqlExpression != null ? _sqlExpressionFactory.Function( @@ -289,24 +258,11 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression) /// /// Translates Sum over an expression to an equivalent SQL representation. /// - /// An expression to translate Sum over. + /// An expression to translate Sum over. /// A SQL translation of Sum over the given expression. - public virtual SqlExpression TranslateSum([NotNull] Expression expression) + public virtual SqlExpression TranslateSum([NotNull] SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); - - if (!(expression is SqlExpression sqlExpression)) - { - sqlExpression = TranslateInternal(expression); - } - - if (sqlExpression == null) - { - throw new InvalidOperationException( - TranslationErrorDetails == null - ? CoreStrings.TranslationFailed(expression.Print()) - : CoreStrings.TranslationFailedWithDetails(expression.Print(), TranslationErrorDetails)); - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); var inputType = sqlExpression.Type; @@ -448,6 +404,9 @@ protected override Expression VisitExtension(Expression extensionExpression) .GetMappedProjection(projectionBindingExpression.ProjectionMember) : null; + case GroupByShaperExpression groupByShaperExpression: + return new GroupingElementExpression(groupByShaperExpression.ElementSelector); + default: return null; } @@ -498,29 +457,160 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // GroupBy Aggregate case if (methodCallExpression.Object == null && methodCallExpression.Method.DeclaringType == typeof(Enumerable) - && methodCallExpression.Arguments.Count > 0 - && methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression) + && methodCallExpression.Arguments.Count > 0) { - var translatedAggregate = methodCallExpression.Method.Name switch - { - nameof(Enumerable.Average) => TranslateAverage(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Count) => TranslateCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Max) => TranslateMax(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Min) => TranslateMin(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Sum) => TranslateSum(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), - _ => null - }; - - if (translatedAggregate == null) + if (methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) == null + && Visit(methodCallExpression.Arguments[0]) is GroupingElementExpression groupingElementExpression) { - throw new InvalidOperationException( - TranslationErrorDetails == null - ? CoreStrings.TranslationFailed(methodCallExpression.Print()) - : CoreStrings.TranslationFailedWithDetails(methodCallExpression.Print(), TranslationErrorDetails)); - } + Expression result; + switch (methodCallExpression.Method.Name) + { + case nameof(Enumerable.Average): + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression averageExpression + ? TranslateAverage(averageExpression) + : null; + break; + + case nameof(Enumerable.Count): + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplyPredicate( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + if (groupingElementExpression == null) + { + result = null; + break; + } + } + + result = TranslateCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true)); + break; + + case nameof(Enumerable.LongCount): + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplyPredicate( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + if (groupingElementExpression == null) + { + result = null; + break; + } + } + + result = TranslateLongCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true)); + break; + + case nameof(Enumerable.Max): + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression maxExpression + ? TranslateMax(maxExpression) + : null; + break; + + case nameof(Enumerable.Min): + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression minExpression + ? TranslateMin(minExpression) + : null; + break; + + case nameof(Enumerable.Select): + result = ApplySelector(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + break; + + case nameof(Enumerable.Sum): + if (methodCallExpression.Arguments.Count == 2) + { + groupingElementExpression = ApplySelector( + groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = GetExpressionForAggregation(groupingElementExpression) is SqlExpression sumExpression + ? TranslateSum(sumExpression) + : null; + break; + + case nameof(Enumerable.Where): + result = ApplyPredicate(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + break; + + default: + result = null; + break; + } + + return result ?? throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + + GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression) + { + var predicate = TranslateInternal(RemapLambda(groupingElement, lambdaExpression)); + + return predicate == null + ? null + : groupingElement.ApplyPredicate(predicate); + } + + static GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression) + { + var selector = RemapLambda(groupingElement, lambdaExpression); + + return groupingElement.ApplySelector(selector); + } + + static Expression RemapLambda(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression) + => ReplacingExpressionVisitor.Replace( + lambdaExpression.Parameters[0], groupingElement.Element, lambdaExpression.Body); + + SqlExpression GetExpressionForAggregation(GroupingElementExpression groupingElement, bool starProjection = false) + { + var selector = TranslateInternal(groupingElement.Element); + if (selector == null) + { + if (starProjection) + { + selector = _sqlExpressionFactory.Fragment("*"); + } + else + { + return null; + } + } + + if (groupingElement.Predicate != null) + { + if (selector is SqlFragmentExpression) + { + selector = _sqlExpressionFactory.Constant(1); + } - return translatedAggregate; + return _sqlExpressionFactory.Case( + new List + { + new CaseWhenClause(groupingElement.Predicate, selector) + }, + elseResult: null); + } + + return selector; + } + } } // Subquery case @@ -990,46 +1080,6 @@ private SqlExpression BindProperty(EntityReferenceExpression entityReferenceExpr return null; } - private static Expression GetSelectorOnGrouping( - MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return groupByShaperExpression.ElementSelector; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - - private static Expression GetPredicateOnGrouping( - MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return null; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - private static Expression TryRemoveImplicitConvert(Expression expression) { if (expression is UnaryExpression unaryExpression @@ -1412,6 +1462,56 @@ public Expression Convert(Type type) } } + private sealed class GroupingElementExpression : Expression + { + public GroupingElementExpression(Expression element) + { + Element = element; + } + public Expression Element { get; private set; } + public bool IsDistinct { get; private set; } + public SqlExpression Predicate { get; private set; } + + public GroupingElementExpression ApplyDistinct() + { + IsDistinct = true; + + return this; + } + + public GroupingElementExpression ApplySelector(Expression expression) + { + Element = expression; + + return this; + } + + public GroupingElementExpression ApplyPredicate(SqlExpression expression) + { + Check.NotNull(expression, nameof(expression)); + + if (expression is SqlConstantExpression sqlConstant + && sqlConstant.Value is bool boolValue + && boolValue) + { + return this; + } + + Predicate = Predicate == null + ? expression + : new SqlBinaryExpression( + ExpressionType.AndAlso, + Predicate, + expression, + typeof(bool), + expression.TypeMapping); + + return this; + } + public override Type Type => typeof(IEnumerable<>).MakeGenericType(Element.Type); + public override ExpressionType NodeType => ExpressionType.Extension; + } + private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor { protected override Expression VisitExtension(Expression extensionExpression) diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 643eadc10ac..fdd5f952b00 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -110,18 +110,14 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression TranslateLongCount(Expression expression = null) + public override SqlExpression TranslateLongCount(SqlExpression sqlExpression) { - if (expression != null) - { - // TODO: Translate Count with predicate for GroupBy - return null; - } + Check.NotNull(sqlExpression, nameof(sqlExpression)); return Dependencies.SqlExpressionFactory.ApplyDefaultTypeMapping( Dependencies.SqlExpressionFactory.Function( "COUNT_BIG", - new[] { Dependencies.SqlExpressionFactory.Fragment("*") }, + new[] { sqlExpression }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(long))); diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 662b3d158c8..4b72a4f0f74 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -196,11 +196,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression TranslateAverage(Expression expression) + public override SqlExpression TranslateAverage(SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); + Check.NotNull(sqlExpression, nameof(sqlExpression)); - var visitedExpression = base.TranslateAverage(expression); + var visitedExpression = base.TranslateAverage(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(decimal)) { @@ -217,11 +217,11 @@ public override SqlExpression TranslateAverage(Expression expression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression TranslateMax(Expression expression) + public override SqlExpression TranslateMax(SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); + Check.NotNull(sqlExpression, nameof(sqlExpression)); - var visitedExpression = base.TranslateMax(expression); + var visitedExpression = base.TranslateMax(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(DateTimeOffset) || argumentType == typeof(decimal) @@ -247,11 +247,11 @@ public override SqlExpression TranslateMax(Expression expression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression TranslateMin(Expression expression) + public override SqlExpression TranslateMin(SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); + Check.NotNull(sqlExpression, nameof(sqlExpression)); - var visitedExpression = base.TranslateMin(expression); + var visitedExpression = base.TranslateMin(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(DateTimeOffset) || argumentType == typeof(decimal) @@ -271,11 +271,11 @@ public override SqlExpression TranslateMin(Expression expression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression TranslateSum(Expression expression) + public override SqlExpression TranslateSum(SqlExpression sqlExpression) { - Check.NotNull(expression, nameof(expression)); + Check.NotNull(sqlExpression, nameof(sqlExpression)); - var visitedExpression = base.TranslateSum(expression); + var visitedExpression = base.TranslateSum(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(decimal)) { diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index 386b256e7bc..4501fcfa7ca 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -221,7 +221,9 @@ public ConstantVerifyingExpressionVisitor(ITypeMappingSource typeMappingSource) private bool ValidConstant(ConstantExpression constantExpression) { return constantExpression.Value == null - || _typeMappingSource.FindMapping(constantExpression.Type) != null; + || _typeMappingSource.FindMapping(constantExpression.Type) != null + || constantExpression.Value is Array array + && array.Length == 0; } protected override Expression VisitConstant(ConstantExpression constantExpression) diff --git a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs index 10ede25a272..04b7085109e 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs @@ -4,6 +4,7 @@ using System; using System.Linq; using System.Threading.Tasks; +using Castle.Components.DictionaryAdapter; using Microsoft.EntityFrameworkCore.TestModels.Northwind; using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; @@ -37,7 +38,7 @@ public virtual Task GroupBy_Property_Select_Average(bool async) [ConditionalTheory(Skip = "issue #18923")] [MemberData(nameof(IsAsyncData))] - public virtual Task GroupBy_Property_Select_Average_with_navigation_expansion(bool async) + public virtual Task GroupBy_Property_Select_Average_with_group_enumerable_projected(bool async) { return AssertQueryScalar( async, @@ -2007,7 +2008,7 @@ public virtual Task Distinct_GroupBy_OrderBy_key(bool async) assertOrder: true); } - [ConditionalTheory(Skip = "Issue #18923")] + [ConditionalTheory(Skip = "Issue #15873")] [MemberData(nameof(IsAsyncData))] public virtual Task Select_nested_collection_with_groupby(bool async) { @@ -2020,7 +2021,7 @@ public virtual Task Select_nested_collection_with_groupby(bool async) : Array.Empty())); } - [ConditionalTheory(Skip = "Issue #18923")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Select_GroupBy_All(bool async) { @@ -2050,9 +2051,21 @@ public override bool Equals(object obj) public override int GetHashCode() => Order.GetHashCode(); } - [ConditionalTheory(Skip = "Issue #18836")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual Task GroupBy_Where_in_aggregate(bool async) + public virtual Task GroupBy_Where_Average(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Select(e => (int?)e.OrderID).Average()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Count(bool async) { return AssertQueryScalar( async, @@ -2062,6 +2075,160 @@ into g select g.Where(e => e.OrderID < 10300).Count()); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_LongCount(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).LongCount()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Max(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Select(e => (int?)e.OrderID).Max()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Min(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Select(e => (int?)e.OrderID).Min()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Sum(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Select(e => e.OrderID).Sum()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Count_with_predicate(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Count(e => e.OrderDate.HasValue && e.OrderDate.Value.Year == 1997)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Where_Count(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Where(e => e.OrderDate.HasValue && e.OrderDate.Value.Year == 1997).Count()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Select_Where_Count(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300).Select(e => e.OrderDate).Where(e => e.HasValue && e.Value.Year == 1997).Count()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_Where_Select_Where_Select_Min(bool async) + { + return AssertQueryScalar( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select g.Where(e => e.OrderID < 10300) + .Select(e => new { e.OrderID, e.OrderDate }) + .Where(e => e.OrderDate.HasValue && e.OrderDate.Value.Year == 1997) + .Select(e => (int?)e.OrderID).Min()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_multiple_Count_with_predicate(bool async) + { + return AssertQuery( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select new + { + g.Key, + All = g.Count(), + TenK = g.Count(e => e.OrderID < 11000), + EleventK = g.Count(e => e.OrderID < 12000) + }, + elementSorter: e => e.Key.CustomerID); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_multiple_Sum_with_conditional_projection(bool async) + { + return AssertQuery( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select new + { + g.Key, + TenK = g.Sum(e => e.OrderID < 11000 ? e.OrderID : 0), + EleventK = g.Sum(e => e.OrderID >= 11000 ? e.OrderID : 0) + }, + elementSorter: e => e.Key.CustomerID); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_multiple_Sum_with_Select_conditional_projection(bool async) + { + return AssertQuery( + async, + ss => from o in ss.Set() + group o by new { o.CustomerID } + into g + select new + { + g.Key, + TenK = g.Select(e => e.OrderID < 11000 ? e.OrderID : 0).Sum(), + EleventK = g.Select(e => e.OrderID >= 11000 ? e.OrderID : 0).Sum() + }, + elementSorter: e => e.Key.CustomerID); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task GroupBy_Key_as_part_of_element_selector(bool async) @@ -2460,9 +2627,9 @@ public virtual Task Count_after_GroupBy_aggregate(bool async) ss => ss.Set().GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync(default)); } - [ConditionalTheory(Skip = "Issue #18836")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual Task LongCount_after_client_GroupBy(bool async) + public virtual Task LongCount_after_GroupBy_aggregate(bool async) { return AssertSingleResult( async, @@ -2662,7 +2829,7 @@ public virtual Task Complex_query_with_groupBy_in_subquery3(bool async) } // also 15279 - [ConditionalTheory(Skip = "issue #11711")] + [ConditionalTheory(Skip = "issue #15873")] [MemberData(nameof(IsAsyncData))] public virtual Task Complex_query_with_groupBy_in_subquery4(bool async) { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs index 5fb1ad6505b..cfabc6d61a2 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs @@ -37,9 +37,9 @@ FROM [Orders] AS [o] Fixture.TestSqlLoggerFactory.Log.Select(l => l.Message)); } - public override async Task GroupBy_Property_Select_Average_with_navigation_expansion(bool async) + public override async Task GroupBy_Property_Select_Average_with_group_enumerable_projected(bool async) { - await base.GroupBy_Property_Select_Average_with_navigation_expansion(async); + await base.GroupBy_Property_Select_Average_with_group_enumerable_projected(async); AssertSql( @""); @@ -644,7 +644,7 @@ public override async Task GroupBy_Property_scalar_element_selector_Count(bool a await base.GroupBy_Property_scalar_element_selector_Count(async); AssertSql( - @"SELECT COUNT(*) + @"SELECT COUNT([o].[OrderID]) FROM [Orders] AS [o] GROUP BY [o].[CustomerID]"); } @@ -654,7 +654,7 @@ public override async Task GroupBy_Property_scalar_element_selector_LongCount(bo await base.GroupBy_Property_scalar_element_selector_LongCount(async); AssertSql( - @"SELECT COUNT_BIG(*) + @"SELECT COUNT_BIG([o].[OrderID]) FROM [Orders] AS [o] GROUP BY [o].[CustomerID]"); } @@ -1513,19 +1513,181 @@ public override async Task Select_GroupBy_All(bool async) await base.Select_GroupBy_All(async); AssertSql( - @"SELECT [o].[OrderID] AS [Order], [o].[CustomerID] AS [Customer] + @"SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] + HAVING ([o].[CustomerID] <> N'ALFKI') OR [o].[CustomerID] IS NULL) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END"); + } + + public override async Task GroupBy_Where_Average(bool async) + { + await base.GroupBy_Where_Average(async); + + AssertSql( + @"SELECT AVG(CAST(CASE + WHEN [o].[OrderID] < 10300 THEN [o].[OrderID] +END AS float)) FROM [Orders] AS [o] -ORDER BY [o].[CustomerID]"); +GROUP BY [o].[CustomerID]"); } - public override async Task GroupBy_Where_in_aggregate(bool async) + + public override async Task GroupBy_Where_Count(bool async) { - await base.GroupBy_Where_in_aggregate(async); + await base.GroupBy_Where_Count(async); AssertSql( - @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] + @"SELECT COUNT(CASE + WHEN [o].[OrderID] < 10300 THEN 1 +END) FROM [Orders] AS [o] -ORDER BY [o].[CustomerID]"); +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_LongCount(bool async) + { + await base.GroupBy_Where_LongCount(async); + + AssertSql( + @"SELECT COUNT_BIG(CASE + WHEN [o].[OrderID] < 10300 THEN 1 +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Max(bool async) + { + await base.GroupBy_Where_Max(async); + + AssertSql( + @"SELECT MAX(CASE + WHEN [o].[OrderID] < 10300 THEN [o].[OrderID] +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Min(bool async) + { + await base.GroupBy_Where_Min(async); + + AssertSql( + @"SELECT MIN(CASE + WHEN [o].[OrderID] < 10300 THEN [o].[OrderID] +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Sum(bool async) + { + await base.GroupBy_Where_Sum(async); + + AssertSql( + @"SELECT COALESCE(SUM(CASE + WHEN [o].[OrderID] < 10300 THEN [o].[OrderID] +END), 0) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Count_with_predicate(bool async) + { + await base.GroupBy_Where_Count_with_predicate(async); + + AssertSql( + @"SELECT COUNT(CASE + WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN 1 +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Where_Count(bool async) + { + await base.GroupBy_Where_Where_Count(async); + + AssertSql( + @"SELECT COUNT(CASE + WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN 1 +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Select_Where_Count(bool async) + { + await base.GroupBy_Where_Select_Where_Count(async); + + AssertSql( + @"SELECT COUNT(CASE + WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN [o].[OrderDate] +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_Where_Select_Where_Select_Min(bool async) + { + await base.GroupBy_Where_Select_Where_Select_Min(async); + + AssertSql( + @"SELECT MIN(CASE + WHEN ([o].[OrderID] < 10300) AND ([o].[OrderDate] IS NOT NULL AND (DATEPART(year, [o].[OrderDate]) = 1997)) THEN [o].[OrderID] +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_multiple_Count_with_predicate(bool async) + { + await base.GroupBy_multiple_Count_with_predicate(async); + + AssertSql( + @"SELECT [o].[CustomerID], COUNT(*) AS [All], COUNT(CASE + WHEN [o].[OrderID] < 11000 THEN 1 +END) AS [TenK], COUNT(CASE + WHEN [o].[OrderID] < 12000 THEN 1 +END) AS [EleventK] +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_multiple_Sum_with_conditional_projection(bool async) + { + await base.GroupBy_multiple_Sum_with_conditional_projection(async); + + AssertSql( + @"SELECT [o].[CustomerID], COALESCE(SUM(CASE + WHEN [o].[OrderID] < 11000 THEN [o].[OrderID] + ELSE 0 +END), 0) AS [TenK], COALESCE(SUM(CASE + WHEN [o].[OrderID] >= 11000 THEN [o].[OrderID] + ELSE 0 +END), 0) AS [EleventK] +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_multiple_Sum_with_Select_conditional_projection(bool async) + { + await base.GroupBy_multiple_Sum_with_Select_conditional_projection(async); + + AssertSql( + @"SELECT [o].[CustomerID], COALESCE(SUM(CASE + WHEN [o].[OrderID] < 11000 THEN [o].[OrderID] + ELSE 0 +END), 0) AS [TenK], COALESCE(SUM(CASE + WHEN [o].[OrderID] >= 11000 THEN [o].[OrderID] + ELSE 0 +END), 0) AS [EleventK] +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); } public override async Task GroupBy_Key_as_part_of_element_selector(bool async) @@ -1668,14 +1830,17 @@ GROUP BY [o].[CustomerID] ) AS [t]"); } - public override async Task LongCount_after_client_GroupBy(bool async) + public override async Task LongCount_after_GroupBy_aggregate(bool async) { - await base.LongCount_after_client_GroupBy(async); + await base.LongCount_after_GroupBy_aggregate(async); AssertSql( - @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] -FROM [Orders] AS [o] -ORDER BY [o].[CustomerID]"); + @"SELECT COUNT_BIG(*) +FROM ( + SELECT [o].[CustomerID] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] +) AS [t]"); } public override async Task MinMax_after_GroupBy_aggregate(bool async) @@ -1947,16 +2112,28 @@ OFFSET @__p_0 ROWS GROUP BY [t].[CustomerID]"); } - public override Task GroupBy_Property_Select_Count_with_predicate(bool async) + public override async Task GroupBy_Property_Select_Count_with_predicate(bool async) { - return Assert.ThrowsAsync( - () => base.GroupBy_Property_Select_Count_with_predicate(async)); + await base.GroupBy_Property_Select_Count_with_predicate(async); + + AssertSql( + @"SELECT COUNT(CASE + WHEN [o].[OrderID] < 10300 THEN 1 +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); } - public override Task GroupBy_Property_Select_LongCount_with_predicate(bool async) + public override async Task GroupBy_Property_Select_LongCount_with_predicate(bool async) { - return Assert.ThrowsAsync( - () => base.GroupBy_Property_Select_LongCount_with_predicate(async)); + await base.GroupBy_Property_Select_LongCount_with_predicate(async); + + AssertSql( + @"SELECT COUNT_BIG(CASE + WHEN [o].[OrderID] < 10300 THEN 1 +END) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); } public override async Task GroupBy_orderby_projection_with_coalesce_operation(bool async) diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs index fbe2b67c34d..cdbba5b44a8 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs @@ -18,17 +18,5 @@ public NorthwindGroupByQuerySqliteTest(NorthwindQuerySqliteFixture( - () => base.GroupBy_Property_Select_Count_with_predicate(async)); - } - - public override Task GroupBy_Property_Select_LongCount_with_predicate(bool async) - { - return Assert.ThrowsAsync( - () => base.GroupBy_Property_Select_LongCount_with_predicate(async)); - } } }