Skip to content

Commit

Permalink
Query: Fix Avg/Max/Min after DefaultIfEmpty
Browse files Browse the repository at this point in the history
Resolves #20589
  • Loading branch information
smitpatel committed Apr 14, 2020
1 parent 5f4fab4 commit 7129f40
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression
projection = _sqlExpressionFactory.Function(
"AVG", new[] { projection }, projection.Type, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

/// <summary>
Expand Down Expand Up @@ -551,7 +551,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

/// <summary>
Expand Down Expand Up @@ -581,7 +581,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

/// <summary>
Expand Down Expand Up @@ -796,7 +796,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour
projection = _sqlExpressionFactory.Function(
"SUM", new[] { projection }, serverOutputType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType);
}

/// <summary>
Expand Down Expand Up @@ -912,29 +912,35 @@ private static Expression RemapLambdaBody(Expression shaperBody, LambdaExpressio
}

private ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType)
ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.ReplaceProjectionMapping(
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), projection } });

selectExpression.ClearOrdering();
Expression shaper;

var nullableResultType = resultType.MakeNullable();
Expression shaper = new ProjectionBindingExpression(
source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type);

if (throwOnNullResult)
if (throwWhenEmpty)
{
// Avg/Max/Min case.
// We always read nullable value
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
// otherwise, server would return null only if it is empty, and we throw
var nullableResultType = resultType.MakeNullable();
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);
? Expression.Constant(null, resultType)
: projection.Type.IsNullableType()
? (Expression)Expression.Default(resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expand All @@ -946,9 +952,15 @@ private ShapedQueryExpression AggregateResultShaper(
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable));
}
else if (resultType != shaper.Type)
else
{
shaper = Expression.Convert(shaper, resultType);
// Sum case. Projection is always non-null. We read non-nullable value (0 if empty)
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type);
// Cast to nullable type if required
if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
}

return source.UpdateShaperExpression(shaper);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression

var projection = _sqlTranslator.TranslateAverage(newSelector);
return projection != null
? AggregateResultShaper(source, projection, throwOnNullResult: true, resultType)
? AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType)
: null;
}

Expand Down Expand Up @@ -648,7 +648,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour

var projection = _sqlTranslator.TranslateMax(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
Expand All @@ -665,7 +665,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour

var projection = _sqlTranslator.TranslateMin(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType)
Expand Down Expand Up @@ -984,7 +984,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour

var projection = _sqlTranslator.TranslateSum(newSelector);
return projection != null
? AggregateResultShaper(source, projection, throwOnNullResult: false, resultType)
? AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType)
: null;
}

Expand Down Expand Up @@ -1315,7 +1315,7 @@ private static IDictionary<IProperty, ColumnExpression> GetPropertyExpressionsFr
}

private ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType)
ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType)
{
if (projection == null)
{
Expand All @@ -1327,22 +1327,28 @@ private ShapedQueryExpression AggregateResultShaper(
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), projection } });

selectExpression.ClearOrdering();
Expression shaper;

var nullableResultType = resultType.MakeNullable();
Expression shaper = new ProjectionBindingExpression(
source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type);

if (throwOnNullResult)
if (throwWhenEmpty)
{
// Avg/Max/Min case.
// We always read nullable value
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
// otherwise, server would return null only if it is empty, and we throw
var nullableResultType = resultType.MakeNullable();
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);
? Expression.Constant(null, resultType)
: projection.Type.IsNullableType()
? (Expression)Expression.Default(resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expand All @@ -1354,9 +1360,15 @@ private ShapedQueryExpression AggregateResultShaper(
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable));
}
else if (resultType != shaper.Type)
else
{
shaper = Expression.Convert(shaper, resultType);
// Sum case. Projection is always non-null. We read non-nullable value (0 if empty)
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type);
// Cast to nullable type if required
if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
}

return source.UpdateShaperExpression(shaper);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,24 @@ public override Task Contains_over_entityType_should_materialize_when_composite2
return base.Contains_over_entityType_should_materialize_when_composite2(async);
}

[ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")]
public override Task Average_after_default_if_empty_does_not_throw(bool isAsync)
{
return base.Average_after_default_if_empty_does_not_throw(isAsync);
}

[ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")]
public override Task Max_after_default_if_empty_does_not_throw(bool isAsync)
{
return base.Max_after_default_if_empty_does_not_throw(isAsync);
}

[ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")]
public override Task Min_after_default_if_empty_does_not_throw(bool isAsync)
{
return base.Min_after_default_if_empty_does_not_throw(isAsync);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1930,5 +1930,32 @@ public virtual Task Min_over_default_returns_default(bool isAsync)
ss => ss.Set<Order>().Where(o => o.OrderID == 10248),
o => o.OrderID - 10248);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Average_after_default_if_empty_does_not_throw(bool isAsync)
{
return AssertAverage(
isAsync,
ss => ss.Set<Order>().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty());
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Max_after_default_if_empty_does_not_throw(bool isAsync)
{
return AssertMax(
isAsync,
ss => ss.Set<Order>().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty());
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Min_after_default_if_empty_does_not_throw(bool isAsync)
{
return AssertMin(
isAsync,
ss => ss.Set<Order>().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ public static TResult Maybe<TSource, TResult>(this TSource caller, Func<TSource,
public static TResult? MaybeScalar<TSource, TResult>(this TSource caller, Func<TSource, TResult> result)
where TResult : struct
=> caller != null ? (TResult?)result(caller) : null;

public static TResult? MaybeScalar<TSource, TResult>(this TSource caller, Func<TSource, TResult?> result)
where TResult : struct
=> caller != null ? result(caller) : null;

public static IEnumerable<TResult> MaybeDefaultIfEmpty<TResult>(this IEnumerable<TResult> caller)
where TResult : class
=> caller == null ? new List<TResult> { default } : caller.DefaultIfEmpty();
Expand Down

0 comments on commit 7129f40

Please sign in to comment.