Skip to content

Commit

Permalink
Query: Rewrite Entity Equality during translation phase
Browse files Browse the repository at this point in the history
Resolves #15080
Implemented behavior:
- If any part of composite key is null then key is null.
- If comparing entity with null then check if "any" key value is null.
- If comparing entity with non-null then check if "all" key values are non null.

Resolves #20344
Resolves #19431
Resolves #13568
Resolves #13655
Since we already convert property access to nullable, if entity from client is null, make key value as null.

Resolves #19676
Clr type mismatch between proxy type and entity type is ignored.

Resolves #20164
Rewrites entity equality during translation

Part of #18923
  • Loading branch information
smitpatel committed Mar 31, 2020
1 parent 7b48952 commit bb88898
Show file tree
Hide file tree
Showing 35 changed files with 1,748 additions and 1,530 deletions.
8 changes: 2 additions & 6 deletions src/EFCore.Cosmos/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
return _sqlExpressionFactory.In(arguments[1], arguments[0], false);
}

if (method.Name == nameof(IList.Contains)
&& arguments.Count == 1
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any(
t => t == typeof(IList)
|| (t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
if (arguments.Count == 1
&& method.IsContainsMethod()
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor(
_model = queryCompilationContext.Model;
_sqlExpressionFactory = sqlExpressionFactory;
_sqlTranslator = new CosmosSqlTranslatingExpressionVisitor(
_model,
queryCompilationContext,
sqlExpressionFactory,
memberTranslatorProvider,
methodCallTranslatorProvider);
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,18 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.InMemory.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class InMemoryQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private static readonly MethodInfo _efPropertyMethod = typeof(EF).GetTypeInfo().GetDeclaredMethod(nameof(EF.Property));

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly InMemoryProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
Expand All @@ -32,7 +29,7 @@ public InMemoryQueryableMethodTranslatingExpressionVisitor(
[NotNull] QueryCompilationContext queryCompilationContext)
: base(dependencies, subquery: false)
{
_expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(this, queryCompilationContext.Model);
_expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(queryCompilationContext, this);
_weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_expressionTranslator);
_projectionBindingExpressionVisitor = new InMemoryProjectionBindingExpressionVisitor(this, _expressionTranslator);
_model = queryCompilationContext.Model;
Expand Down Expand Up @@ -402,16 +399,14 @@ protected override ShapedQueryExpression TranslateJoin(
Check.NotNull(inner, nameof(inner));
Check.NotNull(resultSelector, nameof(resultSelector));

outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector);
innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector);
(outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector);

if (outerKeySelector == null
|| innerKeySelector == null)
{
return null;
}

(outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector);

var transparentIdentifierType = TransparentIdentifierFactory.Create(
resultSelector.Parameters[0].Type,
resultSelector.Parameters[1].Type);
Expand All @@ -429,6 +424,71 @@ protected override ShapedQueryExpression TranslateJoin(
transparentIdentifierType);
}

private (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) ProcessJoinKeySelector(
ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector)
{
var left = RemapLambdaBody(outer, outerKeySelector);
var right = RemapLambdaBody(inner, innerKeySelector);

var joinCondition = TranslateExpression(Expression.Equal(left, right));

var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition);

if (outerKeyBody == null
|| innerKeyBody == null)
{
return (null, null);
}

outerKeySelector = Expression.Lambda(outerKeyBody, ((InMemoryQueryExpression)outer.QueryExpression).CurrentParameter);
innerKeySelector = Expression.Lambda(innerKeyBody, ((InMemoryQueryExpression)inner.QueryExpression).CurrentParameter);

return AlignKeySelectorTypes(outerKeySelector, innerKeySelector);
}

private static (Expression, Expression) DecomposeJoinCondition(Expression joinCondition)
{
var leftExpressions = new List<Expression>();
var rightExpressions = new List<Expression>();

return ProcessJoinCondition(joinCondition, leftExpressions, rightExpressions)
? leftExpressions.Count == 1
? (leftExpressions[0], rightExpressions[0])
: (CreateAnonymousObject(leftExpressions), CreateAnonymousObject(rightExpressions))
: (null, null);

static Expression CreateAnonymousObject(List<Expression> expressions)
=> Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
typeof(object),
expressions.Select(e => Expression.Convert(e, typeof(object)))));
}


private static bool ProcessJoinCondition(
Expression joinCondition, List<Expression> leftExpressions, List<Expression> rightExpressions)
{
if (joinCondition is BinaryExpression binaryExpression)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
{
leftExpressions.Add(binaryExpression.Left);
rightExpressions.Add(binaryExpression.Right);

return true;
}

if (binaryExpression.NodeType == ExpressionType.AndAlso)
{
return ProcessJoinCondition(binaryExpression.Left, leftExpressions, rightExpressions)
&& ProcessJoinCondition(binaryExpression.Right, leftExpressions, rightExpressions);
}
}

return false;
}

private static (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector)
AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector)
{
Expand Down Expand Up @@ -477,15 +537,14 @@ protected override ShapedQueryExpression TranslateLeftJoin(
Check.NotNull(inner, nameof(inner));
Check.NotNull(resultSelector, nameof(resultSelector));

outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector);
innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector);
(outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector);

if (outerKeySelector == null
|| innerKeySelector == null)
{
return null;
}

(outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector);

var transparentIdentifierType = TransparentIdentifierFactory.Create(
resultSelector.Parameters[0].Type,
Expand Down Expand Up @@ -579,22 +638,16 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var parameter = Expression.Parameter(entityType.ClrType);

var callEFProperty = Expression.Call(
_efPropertyMethod.MakeGenericMethod(
discriminatorProperty.ClrType),
parameter,
Expression.Constant(discriminatorProperty.Name));

var equals = Expression.Equal(
callEFProperty,
parameter.CreateEFPropertyExpression(discriminatorProperty),
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
callEFProperty,
parameter.CreateEFPropertyExpression(discriminatorProperty),
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}

Expand Down
8 changes: 2 additions & 6 deletions src/EFCore.Relational/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false);
}

if (method.Name == nameof(IList.Contains)
&& arguments.Count == 1
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any(
t => t == typeof(IList)
|| (t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
if (arguments.Count == 1
&& method.IsContainsMethod()
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, negated: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe
private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly RelationalProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _subquery;
Expand Down Expand Up @@ -54,7 +55,7 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor(
: base(parentVisitor.Dependencies, subquery: true)
{
RelationalDependencies = parentVisitor.RelationalDependencies;
_model = parentVisitor._model;
_queryCompilationContext = parentVisitor._queryCompilationContext;
_sqlTranslator = parentVisitor._sqlTranslator;
_weakEntityExpandingExpressionVisitor = parentVisitor._weakEntityExpandingExpressionVisitor;
_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
Expand Down Expand Up @@ -116,7 +117,7 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen
{
Check.NotNull(elementType, nameof(elementType));

var entityType = _model.FindEntityType(elementType);
var entityType = _queryCompilationContext.Model.FindEntityType(elementType);
var queryExpression = _sqlExpressionFactory.Select(entityType);

return CreateShapedQueryExpression(entityType, queryExpression);
Expand Down
Loading

0 comments on commit bb88898

Please sign in to comment.