diff --git a/src/EFCore/Query/ReplacingExpressionVisitor.cs b/src/EFCore/Query/ReplacingExpressionVisitor.cs index ccbb066bdc8..b7dfa94e1ca 100644 --- a/src/EFCore/Query/ReplacingExpressionVisitor.cs +++ b/src/EFCore/Query/ReplacingExpressionVisitor.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; @@ -12,7 +13,8 @@ namespace Microsoft.EntityFrameworkCore.Query { public class ReplacingExpressionVisitor : ExpressionVisitor { - private readonly bool _quirkMode; + private readonly bool _quirkMode19737; + private readonly bool _quirkMode19087; private readonly Expression[] _originals; private readonly Expression[] _replacements; @@ -26,9 +28,10 @@ public static Expression Replace(Expression original, Expression replacement, Ex public ReplacingExpressionVisitor(Expression[] originals, Expression[] replacements) { - _quirkMode = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue19737", out var enabled) && enabled; + _quirkMode19737 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue19737", out var enabled) && enabled; + _quirkMode19087 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue19087", out var enabled2) && enabled2; - if (_quirkMode) + if (_quirkMode19737) { _quirkReplacements = new Dictionary(); for (var i = 0; i < originals.Length; i++) @@ -45,9 +48,9 @@ public ReplacingExpressionVisitor(Expression[] originals, Expression[] replaceme public ReplacingExpressionVisitor(IDictionary replacements) { - _quirkMode = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue19737", out var enabled) && enabled; + _quirkMode19737 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue19737", out var enabled) && enabled; - if (_quirkMode) + if (_quirkMode19737) { _quirkReplacements = replacements; } @@ -65,7 +68,7 @@ public override Expression Visit(Expression expression) return expression; } - if (_quirkMode) + if (_quirkMode19737) { if (_quirkReplacements.TryGetValue(expression, out var replacement)) { @@ -107,11 +110,23 @@ protected override Expression VisitMember(MemberExpression memberExpression) } } - if (innerExpression is MemberInitExpression memberInitExpression - && memberInitExpression.Bindings.SingleOrDefault( - mb => mb.Member == memberExpression.Member) is MemberAssignment memberAssignment) + if (_quirkMode19087) { - return memberAssignment.Expression; + if (innerExpression is MemberInitExpression memberInitExpression + && memberInitExpression.Bindings.SingleOrDefault( + mb => mb.Member == memberExpression.Member) is MemberAssignment memberAssignment) + { + return memberAssignment.Expression; + } + } + else + { + if (innerExpression is MemberInitExpression memberInitExpression + && memberInitExpression.Bindings.SingleOrDefault( + mb => mb.Member.IsSameAs(memberExpression.Member)) is MemberAssignment memberAssignment) + { + return memberAssignment.Expression; + } } return memberExpression.Update(innerExpression); diff --git a/src/Shared/MemberInfoExtensions.cs b/src/Shared/MemberInfoExtensions.cs index 13cd017d2e1..380871815dd 100644 --- a/src/Shared/MemberInfoExtensions.cs +++ b/src/Shared/MemberInfoExtensions.cs @@ -31,16 +31,5 @@ public static string GetSimpleMemberName(this MemberInfo member) var index = name.LastIndexOf('.'); return index >= 0 ? name.Substring(index + 1) : name; } - - private class MemberInfoComparer : IEqualityComparer - { - public static readonly MemberInfoComparer Instance = new MemberInfoComparer(); - - public bool Equals(MemberInfo x, MemberInfo y) - => x.IsSameAs(y); - - public int GetHashCode(MemberInfo obj) - => obj.GetHashCode(); - } } } diff --git a/test/EFCore.Specification.Tests/Query/InheritanceTestBase.cs b/test/EFCore.Specification.Tests/Query/InheritanceTestBase.cs index f4aa852af0c..4f427b42a8d 100644 --- a/test/EFCore.Specification.Tests/Query/InheritanceTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/InheritanceTestBase.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Linq; +using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.TestModels.Inheritance; @@ -593,8 +594,6 @@ public virtual void Setting_foreign_key_to_a_different_type_throws() } } - protected virtual bool EnforcesFkConstraints => true; - [ConditionalFact] public virtual void Byte_enum_value_constant_used_in_projection() { @@ -608,8 +607,32 @@ public virtual void Byte_enum_value_constant_used_in_projection() } } + [ConditionalFact] + public virtual void Member_access_on_intermediate_type_works() + { + using var context = CreateContext(); + var query = context.Set().Select(k => new Kiwi { Name = k.Name }); + + var parameter = Expression.Parameter(query.ElementType, "p"); + var property = Expression.Property(parameter, "Name"); + var getProperty = Expression.Lambda(property, new[] { parameter }); + + var expression = Expression.Call(typeof(Queryable), nameof(Queryable.OrderBy), + new[] { query.ElementType, typeof(string) }, + new[] { query.Expression, Expression.Quote(getProperty) }); + + query = query.Provider.CreateQuery(expression); + + var result = query.ToList(); + + var kiwi = Assert.Single(result); + Assert.Equal("Great spotted kiwi", kiwi.Name); + } + protected InheritanceContext CreateContext() => Fixture.CreateContext(); + protected virtual bool EnforcesFkConstraints => true; + protected virtual void ClearLog() { } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceSqlServerTest.cs index 75683b321a1..02496474422 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceSqlServerTest.cs @@ -502,6 +502,17 @@ FROM [Animal] AS [a0] WHERE CAST(0 AS bit) = CAST(1 AS bit)"); } + public override void Member_access_on_intermediate_type_works() + { + base.Member_access_on_intermediate_type_works(); + + AssertSql( + @"SELECT [a].[Name] +FROM [Animal] AS [a] +WHERE [a].[Discriminator] = N'Kiwi' +ORDER BY [a].[Name]"); + } + protected override void UseTransaction(DatabaseFacade facade, IDbContextTransaction transaction) => facade.UseTransaction(transaction.GetDbTransaction());