Skip to content

Commit

Permalink
InMemory: Find property to get value comparer performantly
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed Mar 27, 2020
1 parent d5522fc commit 8a30073
Showing 1 changed file with 37 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,26 +99,27 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
newRight = ConvertToNullable(newRight);
}

var propertyFindingExpressionVisitor = new PropertyFindingExpressionVisitor(_model);
var property = propertyFindingExpressionVisitor.Find(binaryExpression.Left)
?? propertyFindingExpressionVisitor.Find(binaryExpression.Right);

if (property != null)
if (binaryExpression.NodeType == ExpressionType.Equal
|| binaryExpression.NodeType == ExpressionType.NotEqual)
{
var comparer = property.GetValueComparer();

if (comparer != null
&& comparer.Type.IsAssignableFrom(newLeft.Type)
&& comparer.Type.IsAssignableFrom(newRight.Type))
var property = FindProperty(newLeft) ?? FindProperty(newRight);
if (property != null)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
{
return comparer.ExtractEqualsBody(newLeft, newRight);
}
var comparer = property.GetValueComparer();

if (binaryExpression.NodeType == ExpressionType.NotEqual)
if (comparer != null
&& comparer.Type.IsAssignableFrom(newLeft.Type)
&& comparer.Type.IsAssignableFrom(newRight.Type))
{
return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight));
if (binaryExpression.NodeType == ExpressionType.Equal)
{
return comparer.ExtractEqualsBody(newLeft, newRight);
}

if (binaryExpression.NodeType == ExpressionType.NotEqual)
{
return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight));
}
}
}
}
Expand Down Expand Up @@ -742,6 +743,26 @@ private Expression GetPredicateOnGrouping(
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

private IProperty FindProperty(Expression expression)
{
if (expression.NodeType == ExpressionType.Convert
&& expression.Type.IsNullableType()
&& expression is UnaryExpression unaryExpression
&& expression.Type.UnwrapNullableType() == unaryExpression.Type)
{
expression = unaryExpression.Operand;
}

if (expression is MethodCallExpression readValueMethodCall
&& readValueMethodCall.Method.IsGenericMethod
&& readValueMethodCall.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod)
{
return (IProperty)((ConstantExpression)readValueMethodCall.Arguments[2]).Value;
}

return null;
}

[DebuggerStepThrough]
private static bool TranslationFailed(Expression original, Expression translation)
=> original != null && (translation == null || translation is EntityReferenceExpression);
Expand Down Expand Up @@ -857,74 +878,6 @@ public override Expression Visit(Expression expression)
}
}

private sealed class PropertyFindingExpressionVisitor : ExpressionVisitor
{
private readonly IModel _model;
private IProperty _property;

public PropertyFindingExpressionVisitor(IModel model)
{
_model = model;
}

public IProperty Find(Expression expression)
{
Visit(expression);

return _property;
}

protected override Expression VisitMember(MemberExpression memberExpression)
{
var entityType = FindEntityType(memberExpression.Expression);
if (entityType != null)
{
_property = GetProperty(entityType, MemberIdentity.Create(memberExpression.Member));
}

return memberExpression;
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)
|| methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName))
{
var entityType = FindEntityType(source);
if (entityType != null)
{
_property = GetProperty(entityType, MemberIdentity.Create(propertyName));
}
}

return methodCallExpression;
}

private static IProperty GetProperty(IEntityType entityType, MemberIdentity memberIdentity)
=> memberIdentity.MemberInfo != null
? entityType.FindProperty(memberIdentity.MemberInfo)
: entityType.FindProperty(memberIdentity.Name);

private static IEntityType FindEntityType(Expression source)
{
source = source.UnwrapTypeConversion(out var convertedType);

if (source is EntityShaperExpression entityShaperExpression)
{
var entityType = entityShaperExpression.EntityType;
if (convertedType != null)
{
entityType = entityType.GetRootType().GetDerivedTypesInclusive()
.FirstOrDefault(et => et.ClrType == convertedType);
}

return entityType;
}

return null;
}
}

private sealed class EntityReferenceExpression : Expression
{

Expand Down

0 comments on commit 8a30073

Please sign in to comment.