Skip to content

Commit

Permalink
Reduce allocations in ExpressionEqualityComparer
Browse files Browse the repository at this point in the history
Closes #17757
  • Loading branch information
roji committed Jun 2, 2020
1 parent f727d7a commit c5331d3
Showing 1 changed file with 43 additions and 121 deletions.
164 changes: 43 additions & 121 deletions src/EFCore/Query/ExpressionEqualityComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ void AddMemberBindingsToHash(IReadOnlyList<MemberBinding> memberBindings)

private struct ExpressionComparer
{
private ScopedDictionary<ParameterExpression, ParameterExpression> _parameterScope;
private Dictionary<ParameterExpression, ParameterExpression> _parameterScope;

public bool Compare(Expression left, Expression right)
{
Expand All @@ -311,86 +311,36 @@ public bool Compare(Expression left, Expression right)
return false;
}

switch (left)
{
case BinaryExpression leftBinary:
return CompareBinary(leftBinary, (BinaryExpression)right);

case BlockExpression leftBlock:
return CompareBlock(leftBlock, (BlockExpression)right);

case ConditionalExpression leftConditional:
return CompareConditional(leftConditional, (ConditionalExpression)right);

case ConstantExpression leftConstant:
return CompareConstant(leftConstant, (ConstantExpression)right);

case DefaultExpression _:
// Intentionally empty. No additional members
return true;

case GotoExpression leftGoto:
return CompareGoto(leftGoto, (GotoExpression)right);

case IndexExpression leftIndex:
return CompareIndex(leftIndex, (IndexExpression)right);

case InvocationExpression leftInvocation:
return CompareInvocation(leftInvocation, (InvocationExpression)right);

case LabelExpression leftLabel:
return CompareLabel(leftLabel, (LabelExpression)right);

case LambdaExpression leftLambda:
return CompareLambda(leftLambda, (LambdaExpression)right);

case ListInitExpression leftListInit:
return CompareListInit(leftListInit, (ListInitExpression)right);

case LoopExpression leftLoop:
return CompareLoop(leftLoop, (LoopExpression)right);

case MemberExpression leftMember:
return CompareMember(leftMember, (MemberExpression)right);

case MemberInitExpression leftMemberInit:
return CompareMemberInit(leftMemberInit, (MemberInitExpression)right);

case MethodCallExpression leftMethodCall:
return CompareMethodCall(leftMethodCall, (MethodCallExpression)right);

case NewArrayExpression leftNewArray:
return CompareNewArray(leftNewArray, (NewArrayExpression)right);

case NewExpression leftNew:
return CompareNew(leftNew, (NewExpression)right);

case ParameterExpression leftParameter:
return CompareParameter(leftParameter, (ParameterExpression)right);

case RuntimeVariablesExpression leftRuntimeVariables:
return CompareRuntimeVariables(leftRuntimeVariables, (RuntimeVariablesExpression)right);

case SwitchExpression leftSwitch:
return CompareSwitch(leftSwitch, (SwitchExpression)right);

case TryExpression leftTry:
return CompareTry(leftTry, (TryExpression)right);

case TypeBinaryExpression leftTypeBinary:
return CompareTypeBinary(leftTypeBinary, (TypeBinaryExpression)right);

case UnaryExpression leftUnary:
return CompareUnary(leftUnary, (UnaryExpression)right);

default:
if (left.NodeType == ExpressionType.Extension)
{
return left.Equals(right);
}

throw new NotImplementedException(CoreStrings.UnhandledExpressionNode(left.NodeType));
}
return left switch
{
BinaryExpression leftBinary => CompareBinary(leftBinary, (BinaryExpression)right),
BlockExpression leftBlock => CompareBlock(leftBlock, (BlockExpression)right),
ConditionalExpression leftConditional => CompareConditional(leftConditional, (ConditionalExpression)right),
ConstantExpression leftConstant => CompareConstant(leftConstant, (ConstantExpression)right),
DefaultExpression _ => true, // Intentionally empty. No additional members
GotoExpression leftGoto => CompareGoto(leftGoto, (GotoExpression)right),
IndexExpression leftIndex => CompareIndex(leftIndex, (IndexExpression)right),
InvocationExpression leftInvocation => CompareInvocation(leftInvocation, (InvocationExpression)right),
LabelExpression leftLabel => CompareLabel(leftLabel, (LabelExpression)right),
LambdaExpression leftLambda => CompareLambda(leftLambda, (LambdaExpression)right),
ListInitExpression leftListInit => CompareListInit(leftListInit, (ListInitExpression)right),
LoopExpression leftLoop => CompareLoop(leftLoop, (LoopExpression)right),
MemberExpression leftMember => CompareMember(leftMember, (MemberExpression)right),
MemberInitExpression leftMemberInit => CompareMemberInit(leftMemberInit, (MemberInitExpression)right),
MethodCallExpression leftMethodCall => CompareMethodCall(leftMethodCall, (MethodCallExpression)right),
NewArrayExpression leftNewArray => CompareNewArray(leftNewArray, (NewArrayExpression)right),
NewExpression leftNew => CompareNew(leftNew, (NewExpression)right),
ParameterExpression leftParameter => CompareParameter(leftParameter, (ParameterExpression)right),
RuntimeVariablesExpression leftRuntimeVariables => CompareRuntimeVariables(leftRuntimeVariables, (RuntimeVariablesExpression)right),
SwitchExpression leftSwitch => CompareSwitch(leftSwitch, (SwitchExpression)right),
TryExpression leftTry => CompareTry(leftTry, (TryExpression)right),
TypeBinaryExpression leftTypeBinary => CompareTypeBinary(leftTypeBinary, (TypeBinaryExpression)right),
UnaryExpression leftUnary => CompareUnary(leftUnary, (UnaryExpression)right),

_ => left.NodeType == ExpressionType.Extension
? left.Equals(right)
: throw new InvalidOperationException(CoreStrings.UnhandledExpressionNode(left.NodeType))
};
}

private bool CompareBinary(BinaryExpression a, BinaryExpression b)
Expand Down Expand Up @@ -440,31 +390,32 @@ private bool CompareLambda(LambdaExpression a, LambdaExpression b)
return false;
}

// all must have same type
_parameterScope ??= new Dictionary<ParameterExpression, ParameterExpression>();

for (var i = 0; i < n; i++)
{
if (a.Parameters[i].Type != b.Parameters[i].Type)
{
for (var j = 0; j < i; j++)
{
_parameterScope.Remove(a.Parameters[j]);
}
return false;
}
}

var save = _parameterScope;

_parameterScope = new ScopedDictionary<ParameterExpression, ParameterExpression>(_parameterScope);
_parameterScope.Add(a.Parameters[i], b.Parameters[i]);
}

try
{
for (var i = 0; i < n; i++)
{
_parameterScope.Add(a.Parameters[i], b.Parameters[i]);
}

return Compare(a.Body, b.Body);
}
finally
{
_parameterScope = save;
for (var i = 0; i < n; i++)
{
_parameterScope.Remove(a.Parameters[i]);
}
}
}

Expand Down Expand Up @@ -733,35 +684,6 @@ private bool CompareCatchBlock(CatchBlock a, CatchBlock b)
&& Compare(a.Body, b.Body)
&& Compare(a.Filter, b.Filter)
&& Compare(a.Variable, b.Variable);

private sealed class ScopedDictionary<TKey, TValue>
{
private readonly ScopedDictionary<TKey, TValue> _previous;
private readonly Dictionary<TKey, TValue> _map;

public ScopedDictionary(ScopedDictionary<TKey, TValue> previous)
{
_previous = previous;
_map = new Dictionary<TKey, TValue>();
}

public void Add(TKey key, TValue value) => _map.Add(key, value);

public bool TryGetValue(TKey key, out TValue value)
{
for (var scope = this; scope != null; scope = scope._previous)
{
if (scope._map.TryGetValue(key, out value))
{
return true;
}
}

value = default;

return false;
}
}
}
}
}

0 comments on commit c5331d3

Please sign in to comment.