From de7531276eafdb804b1f7295d96ce80f811ad2c3 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 14 Apr 2020 16:01:23 -0700 Subject: [PATCH] Query: Fix Max/Min after DefaultIfEmpty Resolves #20589 --- ...yableMethodTranslatingExpressionVisitor.cs | 50 ++++++++++++------- ...yableMethodTranslatingExpressionVisitor.cs | 50 ++++++++++++------- ...thwindAggregateOperatorsQueryCosmosTest.cs | 18 +++++++ ...orthwindAggregateOperatorsQueryTestBase.cs | 27 ++++++++++ .../TestUtilities/QueryTestExtensions.cs | 4 +- 5 files changed, 109 insertions(+), 40 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 8d5c3e31279..e9d59f3cbea 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -194,7 +194,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression projection = _sqlExpressionFactory.Function( "AVG", new[] { projection }, projection.Type, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } /// @@ -551,7 +551,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } /// @@ -581,7 +581,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } /// @@ -796,7 +796,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour projection = _sqlExpressionFactory.Function( "SUM", new[] { projection }, serverOutputType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); + return AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType); } /// @@ -912,29 +912,35 @@ private static Expression RemapLambdaBody(Expression shaperBody, LambdaExpressio } private ShapedQueryExpression AggregateResultShaper( - ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType) + ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; selectExpression.ReplaceProjectionMapping( new Dictionary { { new ProjectionMember(), projection } }); selectExpression.ClearOrdering(); + Expression shaper; - var nullableResultType = resultType.MakeNullable(); - Expression shaper = new ProjectionBindingExpression( - source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type); - - if (throwOnNullResult) + if (throwWhenEmpty) { + // Avg/Max/Min case. + // We always read nullable value + // If resultType is nullable then we always return null. Only non-null result shows throwing behavior. + // otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default + // otherwise, server would return null only if it is empty, and we throw + var nullableResultType = resultType.MakeNullable(); + shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); var resultVariable = Expression.Variable(nullableResultType, "result"); var returnValueForNull = resultType.IsNullableType() - ? (Expression)Expression.Constant(null, resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.NoElements)), - resultType); + ? Expression.Constant(null, resultType) + : projection.Type.IsNullableType() + ? (Expression)Expression.Default(resultType) + : Expression.Throw( + Expression.New( + typeof(InvalidOperationException).GetConstructors() + .Single(ci => ci.GetParameters().Length == 1), + Expression.Constant(CoreStrings.NoElements)), + resultType); shaper = Expression.Block( new[] { resultVariable }, @@ -946,9 +952,15 @@ private ShapedQueryExpression AggregateResultShaper( ? Expression.Convert(resultVariable, resultType) : (Expression)resultVariable)); } - else if (resultType != shaper.Type) + else { - shaper = Expression.Convert(shaper, resultType); + // Sum case. Projection is always non-null. We read non-nullable value (0 if empty) + shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type); + // Cast to nullable type if required + if (resultType != shaper.Type) + { + shaper = Expression.Convert(shaper, resultType); + } } return source.UpdateShaperExpression(shaper); diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 509d3291960..d6a8ee7bc44 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -204,7 +204,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression var projection = _sqlTranslator.TranslateAverage(newSelector); return projection != null - ? AggregateResultShaper(source, projection, throwOnNullResult: true, resultType) + ? AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType) : null; } @@ -648,7 +648,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour var projection = _sqlTranslator.TranslateMax(newSelector); - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType) @@ -665,7 +665,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour var projection = _sqlTranslator.TranslateMin(newSelector); - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); } protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType) @@ -984,7 +984,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour var projection = _sqlTranslator.TranslateSum(newSelector); return projection != null - ? AggregateResultShaper(source, projection, throwOnNullResult: false, resultType) + ? AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType) : null; } @@ -1315,7 +1315,7 @@ private static IDictionary GetPropertyExpressionsFr } private ShapedQueryExpression AggregateResultShaper( - ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType) + ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType) { if (projection == null) { @@ -1327,22 +1327,28 @@ private ShapedQueryExpression AggregateResultShaper( new Dictionary { { new ProjectionMember(), projection } }); selectExpression.ClearOrdering(); + Expression shaper; - var nullableResultType = resultType.MakeNullable(); - Expression shaper = new ProjectionBindingExpression( - source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type); - - if (throwOnNullResult) + if (throwWhenEmpty) { + // Avg/Max/Min case. + // We always read nullable value + // If resultType is nullable then we always return null. Only non-null result shows throwing behavior. + // otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default + // otherwise, server would return null only if it is empty, and we throw + var nullableResultType = resultType.MakeNullable(); + shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); var resultVariable = Expression.Variable(nullableResultType, "result"); var returnValueForNull = resultType.IsNullableType() - ? (Expression)Expression.Constant(null, resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.NoElements)), - resultType); + ? Expression.Constant(null, resultType) + : projection.Type.IsNullableType() + ? (Expression)Expression.Default(resultType) + : Expression.Throw( + Expression.New( + typeof(InvalidOperationException).GetConstructors() + .Single(ci => ci.GetParameters().Length == 1), + Expression.Constant(CoreStrings.NoElements)), + resultType); shaper = Expression.Block( new[] { resultVariable }, @@ -1354,9 +1360,15 @@ private ShapedQueryExpression AggregateResultShaper( ? Expression.Convert(resultVariable, resultType) : (Expression)resultVariable)); } - else if (resultType != shaper.Type) + else { - shaper = Expression.Convert(shaper, resultType); + // Sum case. Projection is always non-null. We read non-nullable value (0 if empty) + shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type); + // Cast to nullable type if required + if (resultType != shaper.Type) + { + shaper = Expression.Convert(shaper, resultType); + } } return source.UpdateShaperExpression(shaper); diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs index 7af3a15cfc0..361517cf452 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs @@ -1582,6 +1582,24 @@ public override Task Contains_over_entityType_should_materialize_when_composite2 return base.Contains_over_entityType_should_materialize_when_composite2(async); } + [ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")] + public override Task Average_after_default_if_empty_does_not_throw(bool isAsync) + { + return base.Average_after_default_if_empty_does_not_throw(isAsync); + } + + [ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")] + public override Task Max_after_default_if_empty_does_not_throw(bool isAsync) + { + return base.Max_after_default_if_empty_does_not_throw(isAsync); + } + + [ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")] + public override Task Min_after_default_if_empty_does_not_throw(bool isAsync) + { + return base.Min_after_default_if_empty_does_not_throw(isAsync); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs index ac6cd1d6a12..646edc41142 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs @@ -1930,5 +1930,32 @@ public virtual Task Min_over_default_returns_default(bool isAsync) ss => ss.Set().Where(o => o.OrderID == 10248), o => o.OrderID - 10248); } + + [ConditionalTheory(Skip = "Issue#20637")] + [MemberData(nameof(IsAsyncData))] + public virtual Task Average_after_default_if_empty_does_not_throw(bool isAsync) + { + return AssertAverage( + isAsync, + ss => ss.Set().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Max_after_default_if_empty_does_not_throw(bool isAsync) + { + return AssertMax( + isAsync, + ss => ss.Set().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Min_after_default_if_empty_does_not_throw(bool isAsync) + { + return AssertMin( + isAsync, + ss => ss.Set().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty()); + } } } diff --git a/test/EFCore.Specification.Tests/TestUtilities/QueryTestExtensions.cs b/test/EFCore.Specification.Tests/TestUtilities/QueryTestExtensions.cs index 5891267d2ef..51a36f62f36 100644 --- a/test/EFCore.Specification.Tests/TestUtilities/QueryTestExtensions.cs +++ b/test/EFCore.Specification.Tests/TestUtilities/QueryTestExtensions.cs @@ -16,11 +16,11 @@ public static TResult Maybe(this TSource caller, Func(this TSource caller, Func result) where TResult : struct => caller != null ? (TResult?)result(caller) : null; - + public static TResult? MaybeScalar(this TSource caller, Func result) where TResult : struct => caller != null ? result(caller) : null; - + public static IEnumerable MaybeDefaultIfEmpty(this IEnumerable caller) where TResult : class => caller == null ? new List { default } : caller.DefaultIfEmpty();