From 555e4ebb681a5e118e7e515128b0a8df2c9fda4e Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 26 Mar 2020 20:28:25 -0700 Subject: [PATCH] Query: Add ability for translation pipeline to translate property access over a subquery This removes need to run subquery member pushdown after entity equality Part of #20164 Part of #18923 This does not change anything on Cosmos as Cosmos does not have subquery scalar selection --- ...yExpressionTranslatingExpressionVisitor.cs | 749 +++++++++--------- ...lationalSqlTranslatingExpressionVisitor.cs | 631 ++++++++------- src/EFCore/Properties/CoreStrings.Designer.cs | 10 +- src/EFCore/Properties/CoreStrings.resx | 4 +- 4 files changed, 718 insertions(+), 676 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index b1a31ce5550..49fbb379e1e 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -25,6 +25,12 @@ public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor { private const string _compiledQueryParameterPrefix = "__"; + private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; + + private static readonly MethodInfo _getParameterValueMethodInfo + = typeof(InMemoryExpressionTranslatingExpressionVisitor) + .GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue)); + private static readonly MethodInfo _likeMethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod( nameof(DbFunctionsExtensions.Like), @@ -52,7 +58,7 @@ private static string BuildEscapeRegexCharsPattern(IEnumerable regexSpecia => string.Join("|", regexSpecialChars.Select(c => @"\" + c)); private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; - private readonly EntityProjectionFindingExpressionVisitor _entityProjectionFindingExpressionVisitor; + private readonly EntityReferenceFindingExpressionVisitor _entityReferenceFindingExpressionVisitor; private readonly IModel _model; public InMemoryExpressionTranslatingExpressionVisitor( @@ -60,113 +66,15 @@ public InMemoryExpressionTranslatingExpressionVisitor( [NotNull] IModel model) { _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; - _entityProjectionFindingExpressionVisitor = new EntityProjectionFindingExpressionVisitor(); + _entityReferenceFindingExpressionVisitor = new EntityReferenceFindingExpressionVisitor(); _model = model; } - private sealed class EntityProjectionFindingExpressionVisitor : ExpressionVisitor - { - private bool _found; - - public bool Find(Expression expression) - { - _found = false; - - Visit(expression); - - return _found; - } - - public override Expression Visit(Expression expression) - { - if (_found) - { - return expression; - } - - if (expression is EntityProjectionExpression) - { - _found = true; - return expression; - } - - return base.Visit(expression); - } - } - - private sealed class PropertyFindingExpressionVisitor : ExpressionVisitor - { - private readonly IModel _model; - private IProperty _property; - - public PropertyFindingExpressionVisitor(IModel model) - { - _model = model; - } - - public IProperty Find(Expression expression) - { - Visit(expression); - - return _property; - } - - protected override Expression VisitMember(MemberExpression memberExpression) - { - var entityType = FindEntityType(memberExpression.Expression); - if (entityType != null) - { - _property = GetProperty(entityType, MemberIdentity.Create(memberExpression.Member)); - } - - return memberExpression; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName) - || methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) - { - var entityType = FindEntityType(source); - if (entityType != null) - { - _property = GetProperty(entityType, MemberIdentity.Create(propertyName)); - } - } - - return methodCallExpression; - } - - private static IProperty GetProperty(IEntityType entityType, MemberIdentity memberIdentity) - => memberIdentity.MemberInfo != null - ? entityType.FindProperty(memberIdentity.MemberInfo) - : entityType.FindProperty(memberIdentity.Name); - - private static IEntityType FindEntityType(Expression source) - { - source = source.UnwrapTypeConversion(out var convertedType); - - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - } - - return entityType; - } - - return null; - } - } - public virtual Expression Translate([NotNull] Expression expression) { var result = Visit(expression); - return _entityProjectionFindingExpressionVisitor.Find(result) + return _entityReferenceFindingExpressionVisitor.Find(result) ? null : result; } @@ -254,18 +162,36 @@ protected override Expression VisitConditional(ConditionalExpression conditional return Expression.Condition(test, ifTrue, ifFalse); } - protected override Expression VisitMember(MemberExpression memberExpression) + protected override Expression VisitExtension(Expression extensionExpression) { - Check.NotNull(memberExpression, nameof(memberExpression)); + Check.NotNull(extensionExpression, nameof(extensionExpression)); - if (TryBindMember( - memberExpression.Expression, - MemberIdentity.Create(memberExpression.Member), - memberExpression.Type, - out var result)) + switch (extensionExpression) { - return result; + case EntityProjectionExpression _: + return extensionExpression; + + case EntityShaperExpression entityShaperExpression: + return new EntityReferenceExpression(entityShaperExpression); + + case ProjectionBindingExpression projectionBindingExpression: + return projectionBindingExpression.ProjectionMember != null + ? ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression) + .GetMappedProjection(projectionBindingExpression.ProjectionMember) + : null; + + default: + return null; } + } + + protected override Expression VisitInvocation(InvocationExpression node) => null; + protected override Expression VisitLambda(Expression node) => null; + protected override Expression VisitListInit(ListInitExpression node) => null; + + protected override Expression VisitMember(MemberExpression memberExpression) + { + Check.NotNull(memberExpression, nameof(memberExpression)); var innerExpression = Visit(memberExpression.Expression); if (memberExpression.Expression != null @@ -274,6 +200,11 @@ protected override Expression VisitMember(MemberExpression memberExpression) return null; } + if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type) is Expression result) + { + return result; + } + var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); if (innerExpression != null && innerExpression.Type.IsNullableType() @@ -295,104 +226,20 @@ static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string mem && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); } - private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result) - { - source = source.UnwrapTypeConversion(out var convertedType); - result = null; - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - if (entityType == null) - { - return false; - } - } - - var property = memberIdentity.MemberInfo != null - ? entityType.FindProperty(memberIdentity.MemberInfo) - : entityType.FindProperty(memberIdentity.Name); - if (property != null - && Visit(entityShaperExpression.ValueBufferExpression) is EntityProjectionExpression entityProjectionExpression - && (entityProjectionExpression.EntityType.IsAssignableFrom(property.DeclaringEntityType) - || property.DeclaringEntityType.IsAssignableFrom(entityProjectionExpression.EntityType))) - { - result = BindProperty(entityProjectionExpression, property); - - // if the result type change was just nullability change e.g from int to int? - // we want to preserve the new type for null propagation - if (result.Type != type - && !(result.Type.IsNullableType() - && !type.IsNullableType() - && result.Type.UnwrapNullableType() == type)) - { - result = Expression.Convert(result, type); - } - - return true; - } - } - - return false; - } - - private static bool IsConvertedToNullable(Expression result, Expression original) - => result.Type.IsNullableType() - && !original.Type.IsNullableType() - && result.Type.UnwrapNullableType() == original.Type; - - private static Expression ConvertToNullable(Expression expression) - => !expression.Type.IsNullableType() - ? Expression.Convert(expression, expression.Type.MakeNullable()) - : expression; - - private static Expression ConvertToNonNullable(Expression expression) - => expression.Type.IsNullableType() - ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) - : expression; - - private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) - => entityProjectionExpression.BindProperty(property); - - private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return groupByShaperExpression.ElementSelector; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - - private Expression GetPredicate(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) { - if (methodCallExpression.Arguments.Count == 1) + var expression = Visit(memberAssignment.Expression); + if (expression == null) { return null; } - if (methodCallExpression.Arguments.Count == 2) + if (IsConvertedToNullable(expression, memberAssignment.Expression)) { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); + expression = ConvertToNonNullable(expression); } - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + return memberAssignment.Update(expression); } protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) @@ -408,18 +255,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result)) - { - return result; - } - - throw new InvalidOperationException(CoreStrings.EFPropertyCalledWithWrongPropertyName); + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type) + ?? throw new InvalidOperationException(CoreStrings.UnableToTranslateEFPropertyToServer(methodCallExpression.Print())); } // EF Indexer property if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) { - return TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result) ? result : null; + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type); } // GroupBy Aggregate case @@ -436,7 +279,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Enumerable.Min): case nameof(Enumerable.Sum): { - var translation = Translate(GetSelector(methodCallExpression, groupByShaperExpression)); + var translation = Translate(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)); if (translation == null) { return null; @@ -468,7 +311,7 @@ MethodInfo GetMethod() case nameof(Enumerable.LongCount): { var countMethod = string.Equals(methodName, nameof(Enumerable.Count)); - var predicate = GetPredicate(methodCallExpression, groupByShaperExpression); + var predicate = GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression); if (predicate == null) { return Expression.Call( @@ -511,58 +354,22 @@ MethodInfo GetMethod() return null; } - subquery.ApplyProjection(); - if (subquery.Projection.Count != 1) - { - return null; - } - - // Unwrap ResultEnumerable - var selectMethod = (MethodCallExpression)subquery.ServerQueryExpression; - var resultEnumerable = (NewExpression)selectMethod.Arguments[0]; - var resultFunc = ((LambdaExpression)resultEnumerable.Arguments[0]).Body; - // New ValueBuffer construct - if (resultFunc is NewExpression newValueBufferExpression) + if (subqueryTranslation.ShaperExpression is EntityShaperExpression entityShaperExpression) { - Expression result; - var innerExpression = ((NewArrayExpression)newValueBufferExpression.Arguments[0]).Expressions[0]; - result = innerExpression is UnaryExpression unaryExpression - && innerExpression.NodeType == ExpressionType.Convert - && innerExpression.Type == typeof(object) - ? unaryExpression.Operand - : innerExpression; - - return result.Type == methodCallExpression.Type - ? result - : Expression.Convert(result, methodCallExpression.Type); + return new EntityReferenceExpression(subqueryTranslation); } - var selector = (LambdaExpression)selectMethod.Arguments[1]; - var readValueExpression = ((NewArrayExpression)((NewExpression)selector.Body).Arguments[0]).Expressions[0]; - if (readValueExpression is UnaryExpression unaryExpression2 - && unaryExpression2.NodeType == ExpressionType.Convert - && unaryExpression2.Type == typeof(object)) +#pragma warning disable IDE0046 // Convert to conditional expression + if (!(subqueryTranslation.ShaperExpression is ProjectionBindingExpression projectionBindingExpression)) +#pragma warning restore IDE0046 // Convert to conditional expression { - readValueExpression = unaryExpression2.Operand; + return null; } - var valueBufferVariable = Expression.Variable(typeof(ValueBuffer)); - var replacedReadExpression = ReplacingExpressionVisitor.Replace( - selector.Parameters[0], - valueBufferVariable, - readValueExpression); - - replacedReadExpression = replacedReadExpression.Type == methodCallExpression.Type - ? replacedReadExpression - : Expression.Convert(replacedReadExpression, methodCallExpression.Type); - - return Expression.Block( - variables: new[] { valueBufferVariable }, - Expression.Assign(valueBufferVariable, resultFunc), - Expression.Condition( - Expression.MakeMemberAccess(valueBufferVariable, _valueBufferIsEmpty), - Expression.Default(methodCallExpression.Type), - replacedReadExpression)); + return ProcessSingleResultScalar(subquery.ServerQueryExpression, + subquery.GetMappedProjection(projectionBindingExpression.ProjectionMember), + subquery.CurrentParameter, + methodCallExpression.Type); } if (methodCallExpression.Method == _likeMethodInfo @@ -636,48 +443,6 @@ MethodInfo GetMethod() return methodCallExpression.Update(@object, arguments); } - private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; - - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) - { - Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); - - if (typeBinaryExpression.NodeType == ExpressionType.TypeIs - && Visit(typeBinaryExpression.Expression) is EntityProjectionExpression entityProjectionExpression) - { - var entityType = entityProjectionExpression.EntityType; - - if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) - { - return Expression.Constant(true); - } - - var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); - if (derivedType != null) - { - var discriminatorProperty = entityType.GetDiscriminatorProperty(); - var boundProperty = BindProperty(entityProjectionExpression, discriminatorProperty); - - var equals = Expression.Equal( - boundProperty, - Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); - - foreach (var derivedDerivedType in derivedType.GetDerivedTypes()) - { - equals = Expression.OrElse( - equals, - Expression.Equal( - boundProperty, - Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); - } - - return equals; - } - } - - return Expression.Constant(false); - } - protected override Expression VisitNew(NewExpression newExpression) { Check.NotNull(newExpression, nameof(newExpression)); @@ -726,89 +491,61 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio return newArrayExpression.Update(newExpressions); } - protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + protected override Expression VisitParameter(ParameterExpression parameterExpression) { - var expression = Visit(memberAssignment.Expression); - if (expression == null) - { - return null; - } + Check.NotNull(parameterExpression, nameof(parameterExpression)); - if (IsConvertedToNullable(expression, memberAssignment.Expression)) + if (parameterExpression.Name.StartsWith(_compiledQueryParameterPrefix, StringComparison.Ordinal)) { - expression = ConvertToNonNullable(expression); + return Expression.Call( + _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterExpression.Name)); } - return memberAssignment.Update(expression); + throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print())); } - protected override Expression VisitExtension(Expression extensionExpression) + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) { - Check.NotNull(extensionExpression, nameof(extensionExpression)); + Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); - switch (extensionExpression) + if (typeBinaryExpression.NodeType == ExpressionType.TypeIs + && Visit(typeBinaryExpression.Expression) is EntityReferenceExpression entityReferenceExpression) { - case EntityProjectionExpression _: - return extensionExpression; - - case EntityShaperExpression entityShaperExpression: - return Visit(entityShaperExpression.ValueBufferExpression); - - case ProjectionBindingExpression projectionBindingExpression: - return projectionBindingExpression.ProjectionMember != null - ? ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember) - : null; - - default: - return null; - } - } - - protected override Expression VisitListInit(ListInitExpression node) - { - Check.NotNull(node, nameof(node)); + var entityType = entityReferenceExpression.EntityType; - return null; - } - - protected override Expression VisitInvocation(InvocationExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } + if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + { + return Expression.Constant(true); + } - protected override Expression VisitLambda(Expression node) - { - Check.NotNull(node, nameof(node)); + var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); + if (derivedType != null) + { + var discriminatorProperty = entityType.GetDiscriminatorProperty(); + var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType); - return null; - } + var equals = Expression.Equal( + boundProperty, + Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); - protected override Expression VisitParameter(ParameterExpression parameterExpression) - { - Check.NotNull(parameterExpression, nameof(parameterExpression)); + foreach (var derivedDerivedType in derivedType.GetDerivedTypes()) + { + equals = Expression.OrElse( + equals, + Expression.Equal( + boundProperty, + Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); + } - if (parameterExpression.Name.StartsWith(_compiledQueryParameterPrefix, StringComparison.Ordinal)) - { - return Expression.Call( - _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(parameterExpression.Name)); + return equals; + } } - throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print())); + return Expression.Constant(false); } - private static readonly MethodInfo _getParameterValueMethodInfo - = typeof(InMemoryExpressionTranslatingExpressionVisitor) - .GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue)); - - [UsedImplicitly] - private static T GetParameterValue(QueryContext queryContext, string parameterName) - => (T)queryContext.ParameterValues[parameterName]; - protected override Expression VisitUnary(UnaryExpression unaryExpression) { Check.NotNull(unaryExpression, nameof(unaryExpression)); @@ -819,6 +556,14 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return null; } + if (newOperand is EntityReferenceExpression entityReferenceExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked + || unaryExpression.NodeType == ExpressionType.TypeAs)) + { + return entityReferenceExpression.Convert(unaryExpression.Type); + } + if (unaryExpression.NodeType == ExpressionType.Convert && newOperand.Type == unaryExpression.Type) { @@ -856,9 +601,148 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return result; } + private Expression TryBindMember(Expression source, MemberIdentity member, Type type) + { + if (!(source is EntityReferenceExpression entityReferenceExpression)) + { + return null; + } + + var entityType = entityReferenceExpression.EntityType; + + var property = member.MemberInfo != null + ? entityType.FindProperty(member.MemberInfo) + : entityType.FindProperty(member.Name); + + return property != null ? BindProperty(entityReferenceExpression, property, type) : null; + } + + private Expression BindProperty(EntityReferenceExpression entityReferenceExpression, IProperty property, Type type) + { + if (entityReferenceExpression.ParameterEntity != null) + { + var result = ((EntityProjectionExpression)Visit(entityReferenceExpression.ParameterEntity.ValueBufferExpression)) + .BindProperty(property); + + // if the result type change was just nullability change e.g from int to int? + // we want to preserve the new type for null propagation + if (result.Type != type + && !(result.Type.IsNullableType() + && !type.IsNullableType() + && result.Type.UnwrapNullableType() == type)) + { + result = Expression.Convert(result, type); + } + + return result; + } + + if (entityReferenceExpression.SubqueryEntity != null) + { + var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; + var readValueExpression = ((EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression)).BindProperty(property); + var inMemoryQueryExpression = (InMemoryQueryExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; + + return ProcessSingleResultScalar( + inMemoryQueryExpression.ServerQueryExpression, + readValueExpression, + inMemoryQueryExpression.CurrentParameter, + type); + } + + return null; + } + + private static Expression ProcessSingleResultScalar( + Expression serverQuery, Expression readValueExpression, Expression valueBufferParameter, Type type) + { + var singleResult = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; + if (readValueExpression is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert + && unaryExpression.Type == typeof(object)) + { + readValueExpression = unaryExpression.Operand; + } + + var valueBufferVariable = Expression.Variable(typeof(ValueBuffer)); + var replacedReadExpression = ReplacingExpressionVisitor.Replace( + valueBufferParameter, + valueBufferVariable, + readValueExpression); + + replacedReadExpression = replacedReadExpression.Type == type + ? replacedReadExpression + : Expression.Convert(replacedReadExpression, type); + + return Expression.Block( + variables: new[] { valueBufferVariable }, + Expression.Assign(valueBufferVariable, singleResult), + Expression.Condition( + Expression.MakeMemberAccess(valueBufferVariable, _valueBufferIsEmpty), + Expression.Default(type), + replacedReadExpression)); + } + + [UsedImplicitly] + private static T GetParameterValue(QueryContext queryContext, string parameterName) + => (T)queryContext.ParameterValues[parameterName]; + + private static bool IsConvertedToNullable(Expression result, Expression original) + => result.Type.IsNullableType() + && !original.Type.IsNullableType() + && result.Type.UnwrapNullableType() == original.Type; + + private static Expression ConvertToNullable(Expression expression) + => !expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.MakeNullable()) + : expression; + + private static Expression ConvertToNonNullable(Expression expression) + => expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) + : expression; + + private static Expression GetSelectorOnGrouping(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return groupByShaperExpression.ElementSelector; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + + private Expression GetPredicateOnGrouping(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return null; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + [DebuggerStepThrough] - private bool TranslationFailed(Expression original, Expression translation) - => original != null && (translation == null || translation is EntityProjectionExpression); + private static bool TranslationFailed(Expression original, Expression translation) + => original != null && (translation == null || translation is EntityReferenceExpression); private static bool InMemoryLike(string matchExpression, string pattern, string escapeCharacter) { @@ -940,5 +824,146 @@ var regexPattern RegexOptions.IgnoreCase | RegexOptions.Singleline, _regexTimeout); } + + private sealed class EntityReferenceFindingExpressionVisitor : ExpressionVisitor + { + private bool _found; + + public bool Find(Expression expression) + { + _found = false; + + Visit(expression); + + return _found; + } + + public override Expression Visit(Expression expression) + { + if (_found) + { + return expression; + } + + if (expression is EntityReferenceExpression) + { + _found = true; + return expression; + } + + return base.Visit(expression); + } + } + + private sealed class PropertyFindingExpressionVisitor : ExpressionVisitor + { + private readonly IModel _model; + private IProperty _property; + + public PropertyFindingExpressionVisitor(IModel model) + { + _model = model; + } + + public IProperty Find(Expression expression) + { + Visit(expression); + + return _property; + } + + protected override Expression VisitMember(MemberExpression memberExpression) + { + var entityType = FindEntityType(memberExpression.Expression); + if (entityType != null) + { + _property = GetProperty(entityType, MemberIdentity.Create(memberExpression.Member)); + } + + return memberExpression; + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName) + || methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) + { + var entityType = FindEntityType(source); + if (entityType != null) + { + _property = GetProperty(entityType, MemberIdentity.Create(propertyName)); + } + } + + return methodCallExpression; + } + + private static IProperty GetProperty(IEntityType entityType, MemberIdentity memberIdentity) + => memberIdentity.MemberInfo != null + ? entityType.FindProperty(memberIdentity.MemberInfo) + : entityType.FindProperty(memberIdentity.Name); + + private static IEntityType FindEntityType(Expression source) + { + source = source.UnwrapTypeConversion(out var convertedType); + + if (source is EntityShaperExpression entityShaperExpression) + { + var entityType = entityShaperExpression.EntityType; + if (convertedType != null) + { + entityType = entityType.GetRootType().GetDerivedTypesInclusive() + .FirstOrDefault(et => et.ClrType == convertedType); + } + + return entityType; + } + + return null; + } + } + + private sealed class EntityReferenceExpression : Expression + { + + public EntityReferenceExpression(EntityShaperExpression parameter) + { + ParameterEntity = parameter; + EntityType = parameter.EntityType; + } + + public EntityReferenceExpression(ShapedQueryExpression subquery) + { + SubqueryEntity = subquery; + EntityType = ((EntityShaperExpression)subquery.ShaperExpression).EntityType; + } + + private EntityReferenceExpression(EntityReferenceExpression entityReferenceExpression, IEntityType entityType) + { + ParameterEntity = entityReferenceExpression.ParameterEntity; + SubqueryEntity = entityReferenceExpression.SubqueryEntity; + EntityType = entityType; + } + + public EntityShaperExpression ParameterEntity { get; } + public ShapedQueryExpression SubqueryEntity { get; } + public IEntityType EntityType { get; } + + public override Type Type => EntityType.ClrType; + public override ExpressionType NodeType => ExpressionType.Extension; + + public Expression Convert(Type type) + { + if (type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface + { + return this; + } + + var derivedEntityType = EntityType.GetDerivedTypes().FirstOrDefault(et => et.ClrType == type); + + return derivedEntityType == null ? null : new EntityReferenceExpression(this, derivedEntityType); + } + } } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 4aa6ef53eb3..064fa5fb758 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -19,11 +19,10 @@ namespace Microsoft.EntityFrameworkCore.Query public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor { private readonly IModel _model; + private readonly RelationalSqlTranslatingExpressionVisitorDependencies _dependencies; private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlTypeMappingVerifyingExpressionVisitor; - protected virtual ISqlExpressionFactory SqlExpressionFactory { get; } - public RelationalSqlTranslatingExpressionVisitor( [NotNull] RelationalSqlTranslatingExpressionVisitorDependencies dependencies, [NotNull] QueryCompilationContext queryCompilationContext, @@ -33,7 +32,7 @@ public RelationalSqlTranslatingExpressionVisitor( Check.NotNull(queryCompilationContext, nameof(queryCompilationContext)); Check.NotNull(queryableMethodTranslatingExpressionVisitor, nameof(queryableMethodTranslatingExpressionVisitor)); - Dependencies = dependencies; + _dependencies = dependencies; SqlExpressionFactory = dependencies.SqlExpressionFactory; _model = queryCompilationContext.Model; @@ -41,7 +40,7 @@ public RelationalSqlTranslatingExpressionVisitor( _sqlTypeMappingVerifyingExpressionVisitor = new SqlTypeMappingVerifyingExpressionVisitor(); } - protected virtual RelationalSqlTranslatingExpressionVisitorDependencies Dependencies { get; } + protected virtual ISqlExpressionFactory SqlExpressionFactory { get; } public virtual SqlExpression Translate([NotNull] Expression expression) { @@ -224,144 +223,99 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression) sqlExpression.TypeMapping); } - private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor + protected override Expression VisitBinary(BinaryExpression binaryExpression) { - protected override Expression VisitExtension(Expression node) - { - Check.NotNull(node, nameof(node)); - - if (node is SqlExpression sqlExpression - && !(node is SqlFragmentExpression) - && !(node is SqlFunctionExpression sqlFunctionExpression - && sqlFunctionExpression.Type.IsQueryableType())) - { - if (sqlExpression.TypeMapping == null) - { - throw new InvalidOperationException(CoreStrings.NullTypeMappingInSqlTree); - } - } + Check.NotNull(binaryExpression, nameof(binaryExpression)); - return base.VisitExtension(node); + if (binaryExpression.Left.Type == typeof(AnonymousObject) + && binaryExpression.NodeType == ExpressionType.Equal) + { + return Visit(ConvertAnonymousObjectEqualityComparison(binaryExpression)); } - } - protected override Expression VisitMember(MemberExpression memberExpression) - { - Check.NotNull(memberExpression, nameof(memberExpression)); + var uncheckedNodeTypeVariant = binaryExpression.NodeType switch + { + ExpressionType.AddChecked => ExpressionType.Add, + ExpressionType.SubtractChecked => ExpressionType.Subtract, + ExpressionType.MultiplyChecked => ExpressionType.Multiply, + _ => binaryExpression.NodeType + }; - return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result) - ? result - : TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression) - ? null - : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); + var left = TryRemoveImplicitConvert(binaryExpression.Left); + var right = TryRemoveImplicitConvert(binaryExpression.Right); + + return TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) + || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight) + ? null + : uncheckedNodeTypeVariant == ExpressionType.Coalesce + ? SqlExpressionFactory.Coalesce(sqlLeft, sqlRight) + : (Expression)SqlExpressionFactory.MakeBinary( + uncheckedNodeTypeVariant, + sqlLeft, + sqlRight, + null); } - private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression) + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { - source = source.UnwrapTypeConversion(out var convertedType); - expression = null; - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - if (entityType == null) - { - return false; - } - } + Check.NotNull(conditionalExpression, nameof(conditionalExpression)); - var property = member.MemberInfo != null - ? entityType.FindProperty(member.MemberInfo) - : entityType.FindProperty(member.Name); - if (property != null - && Visit(entityShaperExpression.ValueBufferExpression) is EntityProjectionExpression entityProjectionExpression - && (entityProjectionExpression.EntityType.IsAssignableFrom(property.DeclaringEntityType) - || property.DeclaringEntityType.IsAssignableFrom(entityProjectionExpression.EntityType))) - { - expression = entityProjectionExpression.BindProperty(property); - return true; - } - } + var test = Visit(conditionalExpression.Test); + var ifTrue = Visit(conditionalExpression.IfTrue); + var ifFalse = Visit(conditionalExpression.IfFalse); - return false; + return TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse) + ? null + : SqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); } - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) + protected override Expression VisitConstant(ConstantExpression constantExpression) + => new SqlConstantExpression(Check.NotNull(constantExpression, nameof(constantExpression)), null); + + protected override Expression VisitExtension(Expression extensionExpression) { - Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); + Check.NotNull(extensionExpression, nameof(extensionExpression)); - if (typeBinaryExpression.NodeType == ExpressionType.TypeIs - && Visit(typeBinaryExpression.Expression) is EntityProjectionExpression entityProjectionExpression) + switch (extensionExpression) { - var entityType = entityProjectionExpression.EntityType; - if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) - { - return SqlExpressionFactory.Constant(true); - } + case EntityProjectionExpression _: + case SqlExpression _: + return extensionExpression; - var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); - if (derivedType != null) - { - var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); - var discriminatorColumn = entityProjectionExpression.BindProperty(entityType.GetDiscriminatorProperty()); + case EntityShaperExpression entityShaperExpression: + return new EntityReferenceExpression(entityShaperExpression); - return concreteEntityTypes.Count == 1 - ? SqlExpressionFactory.Equal( - discriminatorColumn, - SqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) - : (Expression)SqlExpressionFactory.In( - discriminatorColumn, - SqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), - negated: false); - } + case ProjectionBindingExpression projectionBindingExpression: + return projectionBindingExpression.ProjectionMember != null + ? ((SelectExpression)projectionBindingExpression.QueryExpression) + .GetMappedProjection(projectionBindingExpression.ProjectionMember) + : null; - return SqlExpressionFactory.Constant(false); + default: + return null; } - - return null; } - private Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return groupByShaperExpression.ElementSelector; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } + protected override Expression VisitInvocation(InvocationExpression node) => null; + protected override Expression VisitLambda(Expression node) => null; + protected override Expression VisitListInit(ListInitExpression node) => null; - private Expression GetPredicate(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + protected override Expression VisitMember(MemberExpression memberExpression) { - if (methodCallExpression.Arguments.Count == 1) - { - return null; - } + Check.NotNull(memberExpression, nameof(memberExpression)); - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } + var innerExpression = Visit(memberExpression.Expression); - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + return TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member)) + ?? (TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression) + ? null + : _dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type)); } + protected override Expression VisitMemberInit(MemberInitExpression node) => GetConstantOrNull(Check.NotNull(node, nameof(node))); + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { Check.NotNull(methodCallExpression, nameof(methodCallExpression)); @@ -369,18 +323,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(source, MemberIdentity.Create(propertyName), out var result)) - { - return result; - } - - throw new InvalidOperationException(CoreStrings.EFPropertyCalledWithWrongPropertyName); + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName)) + ?? throw new InvalidOperationException(CoreStrings.UnableToTranslateEFPropertyToServer(methodCallExpression.Print())); } // EF Indexer property if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) { - return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null; + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName)); } // GroupBy Aggregate case @@ -391,12 +341,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { var translatedAggregate = methodCallExpression.Method.Name switch { - nameof(Enumerable.Average) => TranslateAverage(GetSelector(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Count) => TranslateCount(GetPredicate(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicate(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Max) => TranslateMax(GetSelector(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Min) => TranslateMin(GetSelector(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Sum) => TranslateSum(GetSelector(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Average) => TranslateAverage(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Count) => TranslateCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Max) => TranslateMax(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Min) => TranslateMin(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Sum) => TranslateSum(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), _ => null }; @@ -434,8 +384,10 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return null; } - var subquery = (SelectExpression)subqueryTranslation.QueryExpression; - subquery.ApplyProjection(); + if (subqueryTranslation.ShaperExpression is EntityShaperExpression entityShaperExpression) + { + return new EntityReferenceExpression(subqueryTranslation); + } if (!(subqueryTranslation.ShaperExpression is ProjectionBindingExpression || IsAggregateResultWithCustomShaper(methodCallExpression.Method))) @@ -443,6 +395,9 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return null; } + var subquery = (SelectExpression)subqueryTranslation.QueryExpression; + subquery.ApplyProjection(); + #pragma warning disable IDE0046 // Convert to conditional expression if (subquery.Tables.Count == 0 #pragma warning restore IDE0046 // Convert to conditional expression @@ -476,7 +431,183 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) arguments[i] = sqlArgument; } - return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); + return _dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); + } + + protected override Expression VisitNew(NewExpression node) => GetConstantOrNull(Check.NotNull(node, nameof(node))); + + protected override Expression VisitNewArray(NewArrayExpression node) => null; + + protected override Expression VisitParameter(ParameterExpression parameterExpression) + => new SqlParameterExpression(Check.NotNull(parameterExpression, nameof(parameterExpression)), null); + + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) + { + Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); + + var innerExpression = Visit(typeBinaryExpression.Expression); + + if (typeBinaryExpression.NodeType == ExpressionType.TypeIs + && innerExpression is EntityReferenceExpression entityReferenceExpression) + { + var entityType = entityReferenceExpression.EntityType; + if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + { + return SqlExpressionFactory.Constant(true); + } + + var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); + if (derivedType != null) + { + var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); + var discriminatorColumn = BindProperty(entityReferenceExpression, entityType.GetDiscriminatorProperty()); + + return concreteEntityTypes.Count == 1 + ? SqlExpressionFactory.Equal( + discriminatorColumn, + SqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) + : (Expression)SqlExpressionFactory.In( + discriminatorColumn, + SqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), + negated: false); + } + + return SqlExpressionFactory.Constant(false); + } + + return null; + } + + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + Check.NotNull(unaryExpression, nameof(unaryExpression)); + + var operand = Visit(unaryExpression.Operand); + + if (operand is EntityReferenceExpression entityReferenceExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked + || unaryExpression.NodeType == ExpressionType.TypeAs)) + { + return entityReferenceExpression.Convert(unaryExpression.Type); + } + + if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) + { + return null; + } + + switch (unaryExpression.NodeType) + { + case ExpressionType.Not: + return SqlExpressionFactory.Not(sqlOperand); + + case ExpressionType.Negate: + return SqlExpressionFactory.Negate(sqlOperand); + + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.TypeAs: + // Object convert needs to be converted to explicit cast when mismatching types + if (operand.Type.IsInterface + && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + || unaryExpression.Type.UnwrapNullableType() == operand.Type.UnwrapNullableType() + || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) + { + return sqlOperand; + } + + // Introduce explicit cast only if the target type is mapped else we need to client eval + if (unaryExpression.Type == typeof(object) + || SqlExpressionFactory.FindMapping(unaryExpression.Type) != null) + { + sqlOperand = SqlExpressionFactory.ApplyDefaultTypeMapping(sqlOperand); + + return SqlExpressionFactory.Convert(sqlOperand, unaryExpression.Type); + } + + break; + + case ExpressionType.Quote: + return operand; + } + + return null; + } + + private Expression TryBindMember(Expression source, MemberIdentity member) + { + if (!(source is EntityReferenceExpression entityReferenceExpression)) + { + return null; + } + + var entityType = entityReferenceExpression.EntityType; + var property = member.MemberInfo != null + ? entityType.FindProperty(member.MemberInfo) + : entityType.FindProperty(member.Name); + + return property != null ? BindProperty(entityReferenceExpression, property) : null; + } + + private SqlExpression BindProperty(EntityReferenceExpression entityReferenceExpression, IProperty property) + { + if (entityReferenceExpression.ParameterEntity != null) + { + return ((EntityProjectionExpression)Visit(entityReferenceExpression.ParameterEntity.ValueBufferExpression)).BindProperty(property); + } + + if (entityReferenceExpression.SubqueryEntity != null) + { + var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; + var innerProjection = ((EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression)).BindProperty(property); + var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; + subSelectExpression.AddToProjection(innerProjection); + + return new ScalarSubqueryExpression(subSelectExpression); + } + + return null; + } + + private static Expression GetSelectorOnGrouping( + MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return groupByShaperExpression.ElementSelector; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + + private static Expression GetPredicateOnGrouping( + MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return null; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); } private static Expression TryRemoveImplicitConvert(Expression expression) @@ -510,7 +641,7 @@ private static Expression TryRemoveImplicitConvert(Expression expression) return expression; } - private Expression ConvertAnonymousObjectEqualityComparison(BinaryExpression binaryExpression) + private static Expression ConvertAnonymousObjectEqualityComparison(BinaryExpression binaryExpression) { var leftExpressions = ((NewArrayExpression)((NewExpression)binaryExpression.Left).Arguments[0]).Expressions; var rightExpressions = ((NewArrayExpression)((NewExpression)binaryExpression.Right).Arguments[0]).Expressions; @@ -542,49 +673,14 @@ static Expression RemoveObjectConvert(Expression expression) : expression; } - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - Check.NotNull(binaryExpression, nameof(binaryExpression)); - - if (binaryExpression.Left.Type == typeof(AnonymousObject) - && binaryExpression.NodeType == ExpressionType.Equal) - { - return Visit(ConvertAnonymousObjectEqualityComparison(binaryExpression)); - } - - var uncheckedNodeTypeVariant = binaryExpression.NodeType switch - { - ExpressionType.AddChecked => ExpressionType.Add, - ExpressionType.SubtractChecked => ExpressionType.Subtract, - ExpressionType.MultiplyChecked => ExpressionType.Multiply, - _ => binaryExpression.NodeType - }; - - var left = TryRemoveImplicitConvert(binaryExpression.Left); - var right = TryRemoveImplicitConvert(binaryExpression.Right); - - return TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) - || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight) - ? null - : uncheckedNodeTypeVariant == ExpressionType.Coalesce - ? SqlExpressionFactory.Coalesce(sqlLeft, sqlRight) - : (Expression)SqlExpressionFactory.MakeBinary( - uncheckedNodeTypeVariant, - sqlLeft, - sqlRight, - null); - } - - private SqlConstantExpression GetConstantOrNull(Expression expression) - { - if (CanEvaluate(expression)) - { - var value = Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(); - return new SqlConstantExpression(Expression.Constant(value, expression.Type), null); - } - - return null; - } + private static SqlConstantExpression GetConstantOrNull(Expression expression) + => CanEvaluate(expression) + ? new SqlConstantExpression( + Expression.Constant( + Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(), + expression.Type), + null) + : null; private static bool CanEvaluate(Expression expression) { @@ -608,162 +704,81 @@ private static bool CanEvaluate(Expression expression) } } - protected override Expression VisitNew(NewExpression node) - { - Check.NotNull(node, nameof(node)); - - return GetConstantOrNull(node); - } - - protected override Expression VisitMemberInit(MemberInitExpression node) - { - Check.NotNull(node, nameof(node)); - - return GetConstantOrNull(node); - } - - protected override Expression VisitNewArray(NewArrayExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitListInit(ListInitExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitInvocation(InvocationExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitLambda(Expression node) - { - Check.NotNull(node, nameof(node)); - - return node.Body != null ? Visit(node.Body) : null; - } - - protected override Expression VisitConstant(ConstantExpression constantExpression) - { - Check.NotNull(constantExpression, nameof(constantExpression)); - - return new SqlConstantExpression(constantExpression, null); - } - - protected override Expression VisitParameter(ParameterExpression parameterExpression) - { - Check.NotNull(parameterExpression, nameof(parameterExpression)); - - return new SqlParameterExpression(parameterExpression, null); - } - - protected override Expression VisitExtension(Expression extensionExpression) + [DebuggerStepThrough] + private static bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - switch (extensionExpression) + if (original != null + && !(translation is SqlExpression)) { - case EntityProjectionExpression _: - case SqlExpression _: - return extensionExpression; - - case EntityShaperExpression entityShaperExpression: - return Visit(entityShaperExpression.ValueBufferExpression); - - case ProjectionBindingExpression projectionBindingExpression: - return projectionBindingExpression.ProjectionMember != null - ? ((SelectExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember) - : null; - - default: - return null; + castTranslation = null; + return true; } - } - - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) - { - Check.NotNull(conditionalExpression, nameof(conditionalExpression)); - - var test = Visit(conditionalExpression.Test); - var ifTrue = Visit(conditionalExpression.IfTrue); - var ifFalse = Visit(conditionalExpression.IfFalse); - return TranslationFailed(conditionalExpression.Test, test, out var sqlTest) - || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) - || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse) - ? null - : SqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); + castTranslation = translation as SqlExpression; + return false; } - protected override Expression VisitUnary(UnaryExpression unaryExpression) + private sealed class EntityReferenceExpression : Expression { - Check.NotNull(unaryExpression, nameof(unaryExpression)); - - var operand = Visit(unaryExpression.Operand); - - if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) + public EntityReferenceExpression(EntityShaperExpression parameter) { - return null; + ParameterEntity = parameter; + EntityType = parameter.EntityType; } - switch (unaryExpression.NodeType) + public EntityReferenceExpression(ShapedQueryExpression subquery) { - case ExpressionType.Not: - return SqlExpressionFactory.Not(sqlOperand); + SubqueryEntity = subquery; + EntityType = ((EntityShaperExpression)subquery.ShaperExpression).EntityType; + } - case ExpressionType.Negate: - return SqlExpressionFactory.Negate(sqlOperand); + private EntityReferenceExpression(EntityReferenceExpression entityReferenceExpression, IEntityType entityType) + { + ParameterEntity = entityReferenceExpression.ParameterEntity; + SubqueryEntity = entityReferenceExpression.SubqueryEntity; + EntityType = entityType; + } - case ExpressionType.Convert: - case ExpressionType.ConvertChecked: - case ExpressionType.TypeAs: - // Object convert needs to be converted to explicit cast when mismatching types - if (operand.Type.IsInterface - && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) - || unaryExpression.Type.UnwrapNullableType() == operand.Type.UnwrapNullableType() - || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) - { - return sqlOperand; - } + public EntityShaperExpression ParameterEntity { get; } + public ShapedQueryExpression SubqueryEntity { get; } + public IEntityType EntityType { get; } - // Introduce explicit cast only if the target type is mapped else we need to client eval - if (unaryExpression.Type == typeof(object) - || SqlExpressionFactory.FindMapping(unaryExpression.Type) != null) - { - sqlOperand = SqlExpressionFactory.ApplyDefaultTypeMapping(sqlOperand); + public override Type Type => EntityType.ClrType; + public override ExpressionType NodeType => ExpressionType.Extension; - return SqlExpressionFactory.Convert(sqlOperand, unaryExpression.Type); - } + public Expression Convert(Type type) + { + if (type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface + { + return this; + } - break; + var derivedEntityType = EntityType.GetDerivedTypes().FirstOrDefault(et => et.ClrType == type); - case ExpressionType.Quote: - return operand; + return derivedEntityType == null ? null : new EntityReferenceExpression(this, derivedEntityType); } - - return null; } - [DebuggerStepThrough] - private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) + private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor { - if (original != null - && !(translation is SqlExpression)) + protected override Expression VisitExtension(Expression node) { - castTranslation = null; - return true; - } + Check.NotNull(node, nameof(node)); - castTranslation = translation as SqlExpression; - return false; + if (node is SqlExpression sqlExpression + && !(node is SqlFragmentExpression) + && !(node is SqlFunctionExpression sqlFunctionExpression + && sqlFunctionExpression.Type.IsQueryableType())) + { + if (sqlExpression.TypeMapping == null) + { + throw new InvalidOperationException(CoreStrings.NullTypeMappingInSqlTree); + } + } + + return base.VisitExtension(node); + } } } } diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 465ea5761da..7b01f06270d 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2153,7 +2153,7 @@ public static string PropertyWrongName([CanBeNull] object property, [CanBeNull] property, entityType, clrName); /// - /// The indexed property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name. + /// The indexer property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name. /// public static string PropertyClashingNonIndexer([CanBeNull] object property, [CanBeNull] object entityType) => string.Format( @@ -2535,10 +2535,12 @@ public static string UnsupportedBinaryOperator => GetString("UnsupportedBinaryOperator"); /// - /// EF.Property called with wrong property name. + /// Translation of '{expression}' to server failed. Either source is not an entity type or the specified property does not exist on the entity type. /// - public static string EFPropertyCalledWithWrongPropertyName - => GetString("EFPropertyCalledWithWrongPropertyName"); + public static string UnableToTranslateEFPropertyToServer([CanBeNull] object expression) + => string.Format( + GetString("UnableToTranslateEFPropertyToServer", nameof(expression)), + expression); /// /// Invalid {state} encountered. diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index 6d70da84fd6..2380ddfb41b 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1341,8 +1341,8 @@ Unsupported Binary operator type specified. - - EF.Property called with wrong property name. + + Translation of '{expression}' to server failed. Either source is not an entity type or the specified property does not exist on the entity type. Invalid {state} encountered.