diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 6bf26ae4458..f461d69e13d 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -99,26 +99,27 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) newRight = ConvertToNullable(newRight); } - var propertyFindingExpressionVisitor = new PropertyFindingExpressionVisitor(_model); - var property = propertyFindingExpressionVisitor.Find(binaryExpression.Left) - ?? propertyFindingExpressionVisitor.Find(binaryExpression.Right); - - if (property != null) + if (binaryExpression.NodeType == ExpressionType.Equal + || binaryExpression.NodeType == ExpressionType.NotEqual) { - var comparer = property.GetValueComparer(); - - if (comparer != null - && comparer.Type.IsAssignableFrom(newLeft.Type) - && comparer.Type.IsAssignableFrom(newRight.Type)) + var property = FindProperty(newLeft) ?? FindProperty(newRight); + if (property != null) { - if (binaryExpression.NodeType == ExpressionType.Equal) - { - return comparer.ExtractEqualsBody(newLeft, newRight); - } + var comparer = property.GetValueComparer(); - if (binaryExpression.NodeType == ExpressionType.NotEqual) + if (comparer != null + && comparer.Type.IsAssignableFrom(newLeft.Type) + && comparer.Type.IsAssignableFrom(newRight.Type)) { - return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight)); + if (binaryExpression.NodeType == ExpressionType.Equal) + { + return comparer.ExtractEqualsBody(newLeft, newRight); + } + + if (binaryExpression.NodeType == ExpressionType.NotEqual) + { + return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight)); + } } } } @@ -742,6 +743,26 @@ private Expression GetPredicateOnGrouping( throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); } + private IProperty FindProperty(Expression expression) + { + if (expression.NodeType == ExpressionType.Convert + && expression.Type.IsNullableType() + && expression is UnaryExpression unaryExpression + && expression.Type.UnwrapNullableType() == unaryExpression.Type) + { + expression = unaryExpression.Operand; + } + + if (expression is MethodCallExpression readValueMethodCall + && readValueMethodCall.Method.IsGenericMethod + && readValueMethodCall.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) + { + return (IProperty)((ConstantExpression)readValueMethodCall.Arguments[2]).Value; + } + + return null; + } + [DebuggerStepThrough] private static bool TranslationFailed(Expression original, Expression translation) => original != null && (translation == null || translation is EntityReferenceExpression); @@ -857,74 +878,6 @@ public override Expression Visit(Expression expression) } } - private sealed class PropertyFindingExpressionVisitor : ExpressionVisitor - { - private readonly IModel _model; - private IProperty _property; - - public PropertyFindingExpressionVisitor(IModel model) - { - _model = model; - } - - public IProperty Find(Expression expression) - { - Visit(expression); - - return _property; - } - - protected override Expression VisitMember(MemberExpression memberExpression) - { - var entityType = FindEntityType(memberExpression.Expression); - if (entityType != null) - { - _property = GetProperty(entityType, MemberIdentity.Create(memberExpression.Member)); - } - - return memberExpression; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName) - || methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) - { - var entityType = FindEntityType(source); - if (entityType != null) - { - _property = GetProperty(entityType, MemberIdentity.Create(propertyName)); - } - } - - return methodCallExpression; - } - - private static IProperty GetProperty(IEntityType entityType, MemberIdentity memberIdentity) - => memberIdentity.MemberInfo != null - ? entityType.FindProperty(memberIdentity.MemberInfo) - : entityType.FindProperty(memberIdentity.Name); - - private static IEntityType FindEntityType(Expression source) - { - source = source.UnwrapTypeConversion(out var convertedType); - - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - } - - return entityType; - } - - return null; - } - } - private sealed class EntityReferenceExpression : Expression { public EntityReferenceExpression(EntityShaperExpression parameter)