Skip to content

Commit

Permalink
Additional refactoring of Null Semantics:
Browse files Browse the repository at this point in the history
- moving NullSemantics visitor after 2nd level cache - we need to know the parameter values to properly handle IN expressions wrt null semantics,
- NullSemantics visitor needs to go before SqlExpressionOptimizer and SearchCondition, so those two are also moved after 2nd level cache,
- moving optimizations that depend on knowing the nullability to NullSemantics visitor - optimizer now only contains optimizations that also work in 3-value logic, or when we know nulls can't happen,
- merging InExpressionValuesExpandingExpressionVisitor int NullSemantics visitor, so that we don't apply the rewrite for UseRelationalNulls.

Resolves #11464
Resolves #15722
Resolved #18338
Resolves #18597
Resolves #18689
  • Loading branch information
maumar committed Dec 7, 2019
1 parent f765b9d commit 40e4479
Show file tree
Hide file tree
Showing 20 changed files with 945 additions and 688 deletions.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq.Expressions;
Expand Down Expand Up @@ -38,17 +37,19 @@ public virtual (SelectExpression selectExpression, bool canCache) Optimize(
SelectExpression selectExpression, IReadOnlyDictionary<string, object> parametersValues)
{
var canCache = true;
var nullSemanticsRewritingExpressionVisitor = new NullSemanticsRewritingExpressionVisitor(
UseRelationalNulls,
Dependencies.SqlExpressionFactory,
parametersValues);

var inExpressionOptimized = new InExpressionValuesExpandingExpressionVisitor(
Dependencies.SqlExpressionFactory, parametersValues).Visit(selectExpression);

if (!ReferenceEquals(selectExpression, inExpressionOptimized))
var nullSemanticsOptimized = nullSemanticsRewritingExpressionVisitor.Visit(selectExpression);
if (!nullSemanticsRewritingExpressionVisitor.CanCache)
{
canCache = false;
}

var nullParametersOptimized = new ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor(
Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(inExpressionOptimized);
var nullParametersOptimized = new SqlExpressionOptimizingExpressionVisitor(
Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(nullSemanticsOptimized);

var fromSqlParameterOptimized = new FromSqlParameterApplyingExpressionVisitor(
Dependencies.SqlExpressionFactory,
Expand All @@ -63,163 +64,6 @@ public virtual (SelectExpression selectExpression, bool canCache) Optimize(
return (selectExpression: (SelectExpression)fromSqlParameterOptimized, canCache);
}

private sealed class ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor : SqlExpressionOptimizingExpressionVisitor
{
private readonly IReadOnlyDictionary<string, object> _parametersValues;

public ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor(
ISqlExpressionFactory sqlExpressionFactory,
bool useRelationalNulls,
IReadOnlyDictionary<string, object> parametersValues)
: base(sqlExpressionFactory, useRelationalNulls)
{
_parametersValues = parametersValues;
}

protected override Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression)
{
var result = base.VisitSqlUnaryExpression(sqlUnaryExpression);
if (result is SqlUnaryExpression newUnaryExpression
&& newUnaryExpression.Operand is SqlParameterExpression parameterOperand)
{
var parameterValue = _parametersValues[parameterOperand.Name];
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal)
{
return SqlExpressionFactory.Constant(parameterValue == null, sqlUnaryExpression.TypeMapping);
}

if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual)
{
return SqlExpressionFactory.Constant(parameterValue != null, sqlUnaryExpression.TypeMapping);
}
}

return result;
}

protected override Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
{
var result = base.VisitSqlBinaryExpression(sqlBinaryExpression);
if (result is SqlBinaryExpression sqlBinaryResult)
{
var leftNullParameter = sqlBinaryResult.Left is SqlParameterExpression leftParameter
&& _parametersValues[leftParameter.Name] == null;

var rightNullParameter = sqlBinaryResult.Right is SqlParameterExpression rightParameter
&& _parametersValues[rightParameter.Name] == null;

if ((sqlBinaryResult.OperatorType == ExpressionType.Equal || sqlBinaryResult.OperatorType == ExpressionType.NotEqual)
&& (leftNullParameter || rightNullParameter))
{
return SimplifyNullComparisonExpression(
sqlBinaryResult.OperatorType,
sqlBinaryResult.Left,
sqlBinaryResult.Right,
leftNullParameter,
rightNullParameter,
sqlBinaryResult.TypeMapping);
}
}

return result;
}
}

private sealed class InExpressionValuesExpandingExpressionVisitor : ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IReadOnlyDictionary<string, object> _parametersValues;

public InExpressionValuesExpandingExpressionVisitor(
ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<string, object> parametersValues)
{
_sqlExpressionFactory = sqlExpressionFactory;
_parametersValues = parametersValues;
}

public override Expression Visit(Expression expression)
{
if (expression is InExpression inExpression
&& inExpression.Values != null)
{
var inValues = new List<object>();
var hasNullValue = false;
RelationalTypeMapping typeMapping = null;

switch (inExpression.Values)
{
case SqlConstantExpression sqlConstant:
{
typeMapping = sqlConstant.TypeMapping;
var values = (IEnumerable)sqlConstant.Value;
foreach (var value in values)
{
if (value == null)
{
hasNullValue = true;
continue;
}

inValues.Add(value);
}

break;
}

case SqlParameterExpression sqlParameter:
{
typeMapping = sqlParameter.TypeMapping;
var values = (IEnumerable)_parametersValues[sqlParameter.Name];
foreach (var value in values)
{
if (value == null)
{
hasNullValue = true;
continue;
}

inValues.Add(value);
}

break;
}
}

var updatedInExpression = inValues.Count > 0
? _sqlExpressionFactory.In(
(SqlExpression)Visit(inExpression.Item),
_sqlExpressionFactory.Constant(inValues, typeMapping),
inExpression.IsNegated)
: null;

var nullCheckExpression = hasNullValue
? inExpression.IsNegated
? _sqlExpressionFactory.IsNotNull(inExpression.Item)
: _sqlExpressionFactory.IsNull(inExpression.Item)
: null;

if (updatedInExpression != null
&& nullCheckExpression != null)
{
return inExpression.IsNegated
? _sqlExpressionFactory.AndAlso(updatedInExpression, nullCheckExpression)
: _sqlExpressionFactory.OrElse(updatedInExpression, nullCheckExpression);
}

if (updatedInExpression == null
&& nullCheckExpression == null)
{
return _sqlExpressionFactory.Equal(
_sqlExpressionFactory.Constant(true), _sqlExpressionFactory.Constant(inExpression.IsNegated));
}

return (SqlExpression)updatedInExpression ?? nullCheckExpression;
}

return base.Visit(expression);
}
}

private sealed class FromSqlParameterApplyingExpressionVisitor : ExpressionVisitor
{
private readonly IDictionary<FromSqlExpression, Expression> _visitedFromSqlExpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ namespace Microsoft.EntityFrameworkCore.Query
{
public class RelationalQueryTranslationPostprocessor : QueryTranslationPostprocessor
{
private readonly SqlExpressionOptimizingExpressionVisitor _sqlExpressionOptimizingExpressionVisitor;

public RelationalQueryTranslationPostprocessor(
QueryTranslationPostprocessorDependencies dependencies,
RelationalQueryTranslationPostprocessorDependencies relationalDependencies,
Expand All @@ -20,8 +18,6 @@ public RelationalQueryTranslationPostprocessor(
RelationalDependencies = relationalDependencies;
UseRelationalNulls = RelationalOptionsExtension.Extract(queryCompilationContext.ContextOptions).UseRelationalNulls;
SqlExpressionFactory = relationalDependencies.SqlExpressionFactory;
_sqlExpressionOptimizingExpressionVisitor
= new SqlExpressionOptimizingExpressionVisitor(SqlExpressionFactory, UseRelationalNulls);
}

protected virtual RelationalQueryTranslationPostprocessorDependencies RelationalDependencies { get; }
Expand All @@ -37,17 +33,12 @@ public override Expression Process(Expression query)
query = new CollectionJoinApplyingExpressionVisitor().Visit(query);
query = new TableAliasUniquifyingExpressionVisitor().Visit(query);
query = new CaseWhenFlatteningExpressionVisitor(SqlExpressionFactory).Visit(query);

if (!UseRelationalNulls)
{
query = new NullSemanticsRewritingExpressionVisitor(SqlExpressionFactory).Visit(query);
}

query = OptimizeSqlExpression(query);

return query;
}

protected virtual Expression OptimizeSqlExpression(Expression query) => _sqlExpressionOptimizingExpressionVisitor.Visit(query);
protected virtual Expression OptimizeSqlExpression(Expression query)
=> query;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Generic;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal
Expand All @@ -25,7 +26,10 @@ public override (SelectExpression selectExpression, bool canCache) Optimize(
var searchConditionOptimized = (SelectExpression)new SearchConditionConvertingExpressionVisitor(Dependencies.SqlExpressionFactory)
.Visit(optimizedSelectExpression);

return (searchConditionOptimized, canCache);
var optimized = (SelectExpression)new SqlExpressionOptimizingExpressionVisitor(
Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(searchConditionOptimized);

return (optimized, canCache);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal
Expand All @@ -15,13 +14,5 @@ public SqlServerQueryTranslationPostprocessor(
: base(dependencies, relationalDependencies, queryCompilationContext)
{
}

public override Expression Process(Expression query)
{
query = base.Process(query);
query = new SearchConditionConvertingExpressionVisitor(SqlExpressionFactory).Visit(query);

return query;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ public virtual void Contains_with_local_array_closure_false_with_null()
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.NullableStringA)));
}

[ConditionalFact(Skip = "issue #14171")]
[ConditionalFact]
public virtual void Contains_with_local_nullable_array_closure_negated()
{
string[] ids = { "Foo" };
Expand Down Expand Up @@ -946,40 +946,58 @@ join e2 in _clientData._entities2
}
}

[ConditionalFact(Skip = "issue #14171")]
[ConditionalFact]
public virtual void Null_semantics_contains()
{
using var ctx = CreateContext();
var ids = new List<int?> { 1, 2 };
var query1 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA));
var result1 = query1.ToList();
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.NullableIntA)));

var query2 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA));
var result2 = query2.ToList();
var ids2 = new List<int?> { 1, 2, null };
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids2.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids2.Contains(e.NullableIntA)));

var ids2 = new List<int?>
{
1,
2,
null
};
var query3 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA));
var result3 = query3.ToList();
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?> { 1, 2 }.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?> { 1, 2 }.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?> { 1, 2, null }.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?> { 1, 2, null }.Contains(e.NullableIntA)));
}

var query4 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA));
var result4 = query4.ToList();
[ConditionalFact]
public virtual void Null_semantics_contains_array_with_no_values()
{
var ids = new List<int?>();
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.NullableIntA)));

var query5 = ctx.Entities1.Where(e => !new List<int?> { 1, 2 }.Contains(e.NullableIntA));
var result5 = query5.ToList();
var ids2 = new List<int?> { null };
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids2.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids2.Contains(e.NullableIntA)));

var query6 = ctx.Entities1.Where(
e => !new List<int?>
{
1,
2,
null
}.Contains(e.NullableIntA));
var result6 = query6.ToList();
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?>().Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?>().Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => new List<int?> { null }.Contains(e.NullableIntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !new List<int?> { null }.Contains(e.NullableIntA)));
}

[ConditionalFact]
public virtual void Null_semantics_contains_non_nullable_argument()
{
var ids = new List<int?> { 1, 2, null };
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids.Contains(e.IntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids.Contains(e.IntA)));

var ids2 = new List<int?> { 1, 2, };
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids2.Contains(e.IntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids2.Contains(e.IntA)));

var ids3 = new List<int?>();
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids3.Contains(e.IntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids3.Contains(e.IntA)));

var ids4 = new List<int?> { null };
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => ids4.Contains(e.IntA)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !ids4.Contains(e.IntA)));
}

[ConditionalFact]
Expand Down Expand Up @@ -1044,6 +1062,26 @@ public virtual void Coalesce_not_equal()
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => (e.NullableIntA ?? 0) != 0));
}

[ConditionalFact]
public virtual void Negated_order_comparison_on_non_nullable_arguments_gets_optimized()
{
var i = 1;
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA > i)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA >= i)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA < i)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.IntA <= i)));
}

[ConditionalFact(Skip = "issue #9544")]
public virtual void Negated_order_comparison_on_nullable_arguments_doesnt_get_optimized()
{
var i = 1;
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA > i)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA >= i)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA < i)));
AssertQuery<NullSemanticsEntity1>(es => es.Where(e => !(e.NullableIntA <= i)));
}

protected static TResult Maybe<TResult>(object caller, Func<TResult> expression)
where TResult : class
{
Expand Down
Loading

0 comments on commit 40e4479

Please sign in to comment.