Skip to content

Commit

Permalink
Query: Simplify member access on conditional null check constructing …
Browse files Browse the repository at this point in the history
…anonymous type

Resolves #20711
  • Loading branch information
smitpatel committed Aug 12, 2020
1 parent b08b7bb commit ddb1974
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 0 deletions.
59 changes: 59 additions & 0 deletions src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,61 @@ public class QueryOptimizingExpressionVisitor : ExpressionVisitor

private static readonly Expression _constantNullString = Expression.Constant(null, typeof(string));

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMember(MemberExpression memberExpression)
{
var visitedExpression = base.VisitMember(memberExpression);
// Simplify (a != null ? new { Member = b, ... } : null).Member
// to a != null ? b : null
// Later null check removal will simplify it further
if (visitedExpression is MemberExpression visitedMemberExpression
&& visitedMemberExpression.Expression is ConditionalExpression conditionalExpression
&& conditionalExpression.Test is BinaryExpression binaryTest
&& (binaryTest.NodeType == ExpressionType.Equal
|| binaryTest.NodeType == ExpressionType.NotEqual)
// Exclude HasValue/Value over Nullable<> as they return non-null type and we don't have equivalent for it for null part
&& !(conditionalExpression.Type.IsNullableValueType()
&& (visitedMemberExpression.Member.Name == nameof(Nullable<int>.HasValue)
|| visitedMemberExpression.Member.Name == nameof(Nullable<int>.Value))))
{
var isLeftNullConstant = IsNullConstant(binaryTest.Left);
var isRightNullConstant = IsNullConstant(binaryTest.Right);

if (isLeftNullConstant != isRightNullConstant
&& ((binaryTest.NodeType == ExpressionType.Equal
&& IsNullConstant(conditionalExpression.IfTrue))
|| (binaryTest.NodeType == ExpressionType.NotEqual
&& IsNullConstant(conditionalExpression.IfFalse))))
{
var nonNullExpression = binaryTest.NodeType == ExpressionType.Equal
? conditionalExpression.IfFalse
: conditionalExpression.IfTrue;

// Use ReplacingExpressionVisitor rather than creating MemberExpression
// So that member access chain on NewExpression/MemberInitExpression condenses
nonNullExpression = ReplacingExpressionVisitor.Replace(
visitedMemberExpression.Expression, nonNullExpression, visitedMemberExpression);
if (!nonNullExpression.Type.IsNullableType())
{
nonNullExpression = Expression.Convert(nonNullExpression, nonNullExpression.Type.MakeNullable());
}
var nullExpression = Expression.Constant(null, nonNullExpression.Type);

return Expression.Condition(
conditionalExpression.Test,
binaryTest.NodeType == ExpressionType.Equal ? nullExpression : nonNullExpression,
binaryTest.NodeType == ExpressionType.Equal ? nonNullExpression : nullExpression);
}
}

return visitedExpression;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -228,5 +283,9 @@ when unaryExpression.IsLogicalNot():

return false;
}

private bool IsNullConstant(Expression expression)
=> expression is ConstantExpression constantExpression
&& constantExpression.Value == null;
}
}
1 change: 1 addition & 0 deletions src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public virtual Expression Process([NotNull] Expression query)
query = new NavigationExpandingExpressionVisitor(this, QueryCompilationContext, Dependencies.EvaluatableExpressionFilter)
.Expand(query);
query = new QueryOptimizingExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);

return query;
}
Expand Down
153 changes: 153 additions & 0 deletions test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8375,6 +8375,159 @@ private SqlServerTestStore CreateDatabase21807()

#endregion

#region Issue20711

[ConditionalFact]
public virtual void Simplify_member_access_on_null_conditional_check()
{
using (CreateDatabase20711())
{
using var context = new MyContext20711(_options);

var query = context.Set<SubRegion20711>()
.Select(
s => new SubRegionDto20711
{
Id = s.Id,
Name = s.Name,
Region = (s.Region == null)
? null
: new OptionDto20711
{
Id = s.Region.Id,
Name = s.Region.Name
}
})
.OrderBy(s => s.Region.Name)
.ToList();

AssertSql(
@"SELECT [s].[Id], [s].[Name], CAST(0 AS bit), [r].[Id], [r].[Name]
FROM [SubRegion20711] AS [s]
INNER JOIN [Region20711] AS [r] ON [s].[RegionId] = [r].[Id]
ORDER BY [r].[Name]");
}
}

[ConditionalFact]
public virtual void Simplify_member_access_on_null_conditional_check_nested()
{
using (CreateDatabase20711())
{
using var context = new MyContext20711(_options);

var query = context.Set<SubRegion20711>()
.Select(
s => new
{
Id = s.Id,
Name = s.Name,
Region = (s.Region == null)
? null
: new
{
Nested = new
{
Name = s.Region.Name
}
}
})
.OrderBy(s => s.Region.Nested.Name)
.ToList();

AssertSql(
@"SELECT [s].[Id], [s].[Name], CAST(0 AS bit), [r].[Name]
FROM [SubRegion20711] AS [s]
INNER JOIN [Region20711] AS [r] ON [s].[RegionId] = [r].[Id]
ORDER BY [r].[Name]");
}
}

[ConditionalFact]
public virtual void Simplify_member_access_on_null_conditional_check_nested_in_where()
{
using (CreateDatabase20711())
{
using var context = new MyContext20711(_options);

var query = context.Set<SubRegion20711>()
.Select(
s => new
{
Id = s.Id,
Name = s.Name,
Region = (s.Region == null)
? null
: new
{
Nested = new
{
Name = s.Region.Name
}
}
})
.Where(s => s.Region.Nested.Name == "A")
.ToList();

AssertSql(
@"SELECT [s].[Id], [s].[Name], CAST(0 AS bit), [r].[Name]
FROM [SubRegion20711] AS [s]
INNER JOIN [Region20711] AS [r] ON [s].[RegionId] = [r].[Id]
WHERE [r].[Name] = N'A'");
}
}

private class Region20711
{
public int Id { get; set; }
public string Name { get; set; }
public ICollection<SubRegion20711> SubRegion { get; set; }
}

private class SubRegion20711
{
public int Id { get; set; }
public int RegionId { get; set; }
public string Name { get; set; }
public Region20711 Region { get; set; }
}

private class SubRegionDto20711
{
public int Id { get; set; }
public string Name { get; set; }
public OptionDto20711 Region { get; set; }
}

private class OptionDto20711
{
public int Id { get; set; }
public string Name { get; set; }
}

private class MyContext20711 : DbContext
{
public MyContext20711(DbContextOptions options)
: base(options)
{
}

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
modelBuilder.Entity<Region20711>();
}
}

private SqlServerTestStore CreateDatabase20711()
=> CreateTestStore(
() => new MyContext20711(_options),
context =>
{
ClearLog();
});

#endregion

private DbContextOptions _options;

private SqlServerTestStore CreateTestStore<TContext>(
Expand Down

0 comments on commit ddb1974

Please sign in to comment.