Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Fix Max/Min after DefaultIfEmpty #20635

Merged
merged 1 commit into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression
projection = _sqlExpressionFactory.Function(
"AVG", new[] { projection }, projection.Type, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

/// <summary>
Expand Down Expand Up @@ -551,7 +551,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

/// <summary>
Expand Down Expand Up @@ -581,7 +581,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

/// <summary>
Expand Down Expand Up @@ -796,7 +796,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour
projection = _sqlExpressionFactory.Function(
"SUM", new[] { projection }, serverOutputType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType);
}

/// <summary>
Expand Down Expand Up @@ -912,29 +912,35 @@ private static Expression RemapLambdaBody(Expression shaperBody, LambdaExpressio
}

private ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType)
ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.ReplaceProjectionMapping(
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), projection } });

selectExpression.ClearOrdering();
Expression shaper;

var nullableResultType = resultType.MakeNullable();
Expression shaper = new ProjectionBindingExpression(
source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type);

if (throwOnNullResult)
if (throwWhenEmpty)
{
// Avg/Max/Min case.
// We always read nullable value
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
// otherwise, server would return null only if it is empty, and we throw
var nullableResultType = resultType.MakeNullable();
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);
? Expression.Constant(null, resultType)
: projection.Type.IsNullableType()
? (Expression)Expression.Default(resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expand All @@ -946,9 +952,15 @@ private ShapedQueryExpression AggregateResultShaper(
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable));
}
else if (resultType != shaper.Type)
else
{
shaper = Expression.Convert(shaper, resultType);
// Sum case. Projection is always non-null. We read non-nullable value (0 if empty)
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type);
// Cast to nullable type if required
if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
}

return source.UpdateShaperExpression(shaper);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression

var projection = _sqlTranslator.TranslateAverage(newSelector);
return projection != null
? AggregateResultShaper(source, projection, throwOnNullResult: true, resultType)
? AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType)
: null;
}

Expand Down Expand Up @@ -648,7 +648,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour

var projection = _sqlTranslator.TranslateMax(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
Expand All @@ -665,7 +665,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour

var projection = _sqlTranslator.TranslateMin(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType);
}

protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType)
Expand Down Expand Up @@ -984,7 +984,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour

var projection = _sqlTranslator.TranslateSum(newSelector);
return projection != null
? AggregateResultShaper(source, projection, throwOnNullResult: false, resultType)
? AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType)
: null;
}

Expand Down Expand Up @@ -1315,7 +1315,7 @@ private static IDictionary<IProperty, ColumnExpression> GetPropertyExpressionsFr
}

private ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType)
ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType)
{
if (projection == null)
{
Expand All @@ -1327,22 +1327,28 @@ private ShapedQueryExpression AggregateResultShaper(
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), projection } });

selectExpression.ClearOrdering();
Expression shaper;

var nullableResultType = resultType.MakeNullable();
Expression shaper = new ProjectionBindingExpression(
source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type);

if (throwOnNullResult)
if (throwWhenEmpty)
{
// Avg/Max/Min case.
// We always read nullable value
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
// otherwise, server would return null only if it is empty, and we throw
var nullableResultType = resultType.MakeNullable();
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);
? Expression.Constant(null, resultType)
: projection.Type.IsNullableType()
? (Expression)Expression.Default(resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expand All @@ -1354,9 +1360,15 @@ private ShapedQueryExpression AggregateResultShaper(
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable));
}
else if (resultType != shaper.Type)
else
{
shaper = Expression.Convert(shaper, resultType);
// Sum case. Projection is always non-null. We read non-nullable value (0 if empty)
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type);
// Cast to nullable type if required
if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
}

return source.UpdateShaperExpression(shaper);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,24 @@ public override Task Contains_over_entityType_should_materialize_when_composite2
return base.Contains_over_entityType_should_materialize_when_composite2(async);
}

[ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")]
public override Task Average_after_default_if_empty_does_not_throw(bool isAsync)
{
return base.Average_after_default_if_empty_does_not_throw(isAsync);
}

[ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")]
public override Task Max_after_default_if_empty_does_not_throw(bool isAsync)
{
return base.Max_after_default_if_empty_does_not_throw(isAsync);
}

[ConditionalTheory(Skip = "Issue#17246 (DefaultIfEmpty is not translated)")]
public override Task Min_after_default_if_empty_does_not_throw(bool isAsync)
{
return base.Min_after_default_if_empty_does_not_throw(isAsync);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1930,5 +1930,32 @@ public virtual Task Min_over_default_returns_default(bool isAsync)
ss => ss.Set<Order>().Where(o => o.OrderID == 10248),
o => o.OrderID - 10248);
}

[ConditionalTheory(Skip = "Issue#20637")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Average_after_default_if_empty_does_not_throw(bool isAsync)
{
return AssertAverage(
isAsync,
ss => ss.Set<Order>().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty());
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Max_after_default_if_empty_does_not_throw(bool isAsync)
{
return AssertMax(
isAsync,
ss => ss.Set<Order>().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty());
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Min_after_default_if_empty_does_not_throw(bool isAsync)
{
return AssertMin(
isAsync,
ss => ss.Set<Order>().Where(o => o.OrderID == 10243).Select(o => o.OrderID).DefaultIfEmpty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ public static TResult Maybe<TSource, TResult>(this TSource caller, Func<TSource,
public static TResult? MaybeScalar<TSource, TResult>(this TSource caller, Func<TSource, TResult> result)
where TResult : struct
=> caller != null ? (TResult?)result(caller) : null;

public static TResult? MaybeScalar<TSource, TResult>(this TSource caller, Func<TSource, TResult?> result)
where TResult : struct
=> caller != null ? result(caller) : null;

public static IEnumerable<TResult> MaybeDefaultIfEmpty<TResult>(this IEnumerable<TResult> caller)
where TResult : class
=> caller == null ? new List<TResult> { default } : caller.DefaultIfEmpty();
Expand Down