diff --git a/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs index d59111e8a1d..1cba6595b3d 100644 --- a/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs @@ -28,6 +28,61 @@ public class QueryOptimizingExpressionVisitor : ExpressionVisitor private static readonly Expression _constantNullString = Expression.Constant(null, typeof(string)); + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMember(MemberExpression memberExpression) + { + var visitedExpression = base.VisitMember(memberExpression); + // Simplify (a != null ? new { Member = b, ... } : null).Member + // to a != null ? b : null + // Later null check removal will simplify it further + if (visitedExpression is MemberExpression visitedMemberExpression + && visitedMemberExpression.Expression is ConditionalExpression conditionalExpression + && conditionalExpression.Test is BinaryExpression binaryTest + && (binaryTest.NodeType == ExpressionType.Equal + || binaryTest.NodeType == ExpressionType.NotEqual) + // Exclude HasValue/Value over Nullable<> as they return non-null type and we don't have equivalent for it for null part + && !(conditionalExpression.Type.IsNullableValueType() + && (visitedMemberExpression.Member.Name == nameof(Nullable.HasValue) + || visitedMemberExpression.Member.Name == nameof(Nullable.Value)))) + { + var isLeftNullConstant = IsNullConstant(binaryTest.Left); + var isRightNullConstant = IsNullConstant(binaryTest.Right); + + if (isLeftNullConstant != isRightNullConstant + && ((binaryTest.NodeType == ExpressionType.Equal + && IsNullConstant(conditionalExpression.IfTrue)) + || (binaryTest.NodeType == ExpressionType.NotEqual + && IsNullConstant(conditionalExpression.IfFalse)))) + { + var nonNullExpression = binaryTest.NodeType == ExpressionType.Equal + ? conditionalExpression.IfFalse + : conditionalExpression.IfTrue; + + // Use ReplacingExpressionVisitor rather than creating MemberExpression + // So that member access chain on NewExpression/MemberInitExpression condenses + nonNullExpression = ReplacingExpressionVisitor.Replace( + visitedMemberExpression.Expression, nonNullExpression, visitedMemberExpression); + if (!nonNullExpression.Type.IsNullableType()) + { + nonNullExpression = Expression.Convert(nonNullExpression, nonNullExpression.Type.MakeNullable()); + } + var nullExpression = Expression.Constant(null, nonNullExpression.Type); + + return Expression.Condition( + conditionalExpression.Test, + binaryTest.NodeType == ExpressionType.Equal ? nullExpression : nonNullExpression, + binaryTest.NodeType == ExpressionType.Equal ? nonNullExpression : nullExpression); + } + } + + return visitedExpression; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -228,5 +283,9 @@ when unaryExpression.IsLogicalNot(): return false; } + + private bool IsNullConstant(Expression expression) + => expression is ConstantExpression constantExpression + && constantExpression.Value == null; } } diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs index 015a842b214..5b27453c70a 100644 --- a/src/EFCore/Query/QueryTranslationPreprocessor.cs +++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs @@ -61,6 +61,7 @@ public virtual Expression Process([NotNull] Expression query) query = new NavigationExpandingExpressionVisitor(this, QueryCompilationContext, Dependencies.EvaluatableExpressionFilter) .Expand(query); query = new QueryOptimizingExpressionVisitor().Visit(query); + query = new NullCheckRemovingExpressionVisitor().Visit(query); return query; } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index 9e6707da8e1..7e96843bdee 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -8375,6 +8375,159 @@ private SqlServerTestStore CreateDatabase21807() #endregion + #region Issue20711 + + [ConditionalFact] + public virtual void Simplify_member_access_on_null_conditional_check() + { + using (CreateDatabase20711()) + { + using var context = new MyContext20711(_options); + + var query = context.Set() + .Select( + s => new SubRegionDto20711 + { + Id = s.Id, + Name = s.Name, + Region = (s.Region == null) + ? null + : new OptionDto20711 + { + Id = s.Region.Id, + Name = s.Region.Name + } + }) + .OrderBy(s => s.Region.Name) + .ToList(); + + AssertSql( + @"SELECT [s].[Id], [s].[Name], CAST(0 AS bit), [r].[Id], [r].[Name] +FROM [SubRegion20711] AS [s] +INNER JOIN [Region20711] AS [r] ON [s].[RegionId] = [r].[Id] +ORDER BY [r].[Name]"); + } + } + + [ConditionalFact] + public virtual void Simplify_member_access_on_null_conditional_check_nested() + { + using (CreateDatabase20711()) + { + using var context = new MyContext20711(_options); + + var query = context.Set() + .Select( + s => new + { + Id = s.Id, + Name = s.Name, + Region = (s.Region == null) + ? null + : new + { + Nested = new + { + Name = s.Region.Name + } + } + }) + .OrderBy(s => s.Region.Nested.Name) + .ToList(); + + AssertSql( + @"SELECT [s].[Id], [s].[Name], CAST(0 AS bit), [r].[Name] +FROM [SubRegion20711] AS [s] +INNER JOIN [Region20711] AS [r] ON [s].[RegionId] = [r].[Id] +ORDER BY [r].[Name]"); + } + } + + [ConditionalFact] + public virtual void Simplify_member_access_on_null_conditional_check_nested_in_where() + { + using (CreateDatabase20711()) + { + using var context = new MyContext20711(_options); + + var query = context.Set() + .Select( + s => new + { + Id = s.Id, + Name = s.Name, + Region = (s.Region == null) + ? null + : new + { + Nested = new + { + Name = s.Region.Name + } + } + }) + .Where(s => s.Region.Nested.Name == "A") + .ToList(); + + AssertSql( + @"SELECT [s].[Id], [s].[Name], CAST(0 AS bit), [r].[Name] +FROM [SubRegion20711] AS [s] +INNER JOIN [Region20711] AS [r] ON [s].[RegionId] = [r].[Id] +WHERE [r].[Name] = N'A'"); + } + } + + private class Region20711 + { + public int Id { get; set; } + public string Name { get; set; } + public ICollection SubRegion { get; set; } + } + + private class SubRegion20711 + { + public int Id { get; set; } + public int RegionId { get; set; } + public string Name { get; set; } + public Region20711 Region { get; set; } + } + + private class SubRegionDto20711 + { + public int Id { get; set; } + public string Name { get; set; } + public OptionDto20711 Region { get; set; } + } + + private class OptionDto20711 + { + public int Id { get; set; } + public string Name { get; set; } + } + + private class MyContext20711 : DbContext + { + public MyContext20711(DbContextOptions options) + : base(options) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(); + } + } + + private SqlServerTestStore CreateDatabase20711() + => CreateTestStore( + () => new MyContext20711(_options), + context => + { + ClearLog(); + }); + + #endregion + private DbContextOptions _options; private SqlServerTestStore CreateTestStore(