From c5331d3b509a4893bcb208279f9935d5f2cdbf1b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 2 Jun 2020 17:15:06 +0200 Subject: [PATCH] Reduce allocations in ExpressionEqualityComparer Closes #17757 --- .../Query/ExpressionEqualityComparer.cs | 164 +++++------------- 1 file changed, 43 insertions(+), 121 deletions(-) diff --git a/src/EFCore/Query/ExpressionEqualityComparer.cs b/src/EFCore/Query/ExpressionEqualityComparer.cs index 2712a9bb8a0..c99cce31fa7 100644 --- a/src/EFCore/Query/ExpressionEqualityComparer.cs +++ b/src/EFCore/Query/ExpressionEqualityComparer.cs @@ -286,7 +286,7 @@ void AddMemberBindingsToHash(IReadOnlyList memberBindings) private struct ExpressionComparer { - private ScopedDictionary _parameterScope; + private Dictionary _parameterScope; public bool Compare(Expression left, Expression right) { @@ -311,86 +311,36 @@ public bool Compare(Expression left, Expression right) return false; } - switch (left) - { - case BinaryExpression leftBinary: - return CompareBinary(leftBinary, (BinaryExpression)right); - - case BlockExpression leftBlock: - return CompareBlock(leftBlock, (BlockExpression)right); - - case ConditionalExpression leftConditional: - return CompareConditional(leftConditional, (ConditionalExpression)right); - - case ConstantExpression leftConstant: - return CompareConstant(leftConstant, (ConstantExpression)right); - - case DefaultExpression _: - // Intentionally empty. No additional members - return true; - - case GotoExpression leftGoto: - return CompareGoto(leftGoto, (GotoExpression)right); - - case IndexExpression leftIndex: - return CompareIndex(leftIndex, (IndexExpression)right); - - case InvocationExpression leftInvocation: - return CompareInvocation(leftInvocation, (InvocationExpression)right); - - case LabelExpression leftLabel: - return CompareLabel(leftLabel, (LabelExpression)right); - - case LambdaExpression leftLambda: - return CompareLambda(leftLambda, (LambdaExpression)right); - - case ListInitExpression leftListInit: - return CompareListInit(leftListInit, (ListInitExpression)right); - - case LoopExpression leftLoop: - return CompareLoop(leftLoop, (LoopExpression)right); - - case MemberExpression leftMember: - return CompareMember(leftMember, (MemberExpression)right); - - case MemberInitExpression leftMemberInit: - return CompareMemberInit(leftMemberInit, (MemberInitExpression)right); - - case MethodCallExpression leftMethodCall: - return CompareMethodCall(leftMethodCall, (MethodCallExpression)right); - - case NewArrayExpression leftNewArray: - return CompareNewArray(leftNewArray, (NewArrayExpression)right); - - case NewExpression leftNew: - return CompareNew(leftNew, (NewExpression)right); - - case ParameterExpression leftParameter: - return CompareParameter(leftParameter, (ParameterExpression)right); - - case RuntimeVariablesExpression leftRuntimeVariables: - return CompareRuntimeVariables(leftRuntimeVariables, (RuntimeVariablesExpression)right); - - case SwitchExpression leftSwitch: - return CompareSwitch(leftSwitch, (SwitchExpression)right); - - case TryExpression leftTry: - return CompareTry(leftTry, (TryExpression)right); - - case TypeBinaryExpression leftTypeBinary: - return CompareTypeBinary(leftTypeBinary, (TypeBinaryExpression)right); - - case UnaryExpression leftUnary: - return CompareUnary(leftUnary, (UnaryExpression)right); - - default: - if (left.NodeType == ExpressionType.Extension) - { - return left.Equals(right); - } - - throw new NotImplementedException(CoreStrings.UnhandledExpressionNode(left.NodeType)); - } + return left switch + { + BinaryExpression leftBinary => CompareBinary(leftBinary, (BinaryExpression)right), + BlockExpression leftBlock => CompareBlock(leftBlock, (BlockExpression)right), + ConditionalExpression leftConditional => CompareConditional(leftConditional, (ConditionalExpression)right), + ConstantExpression leftConstant => CompareConstant(leftConstant, (ConstantExpression)right), + DefaultExpression _ => true, // Intentionally empty. No additional members + GotoExpression leftGoto => CompareGoto(leftGoto, (GotoExpression)right), + IndexExpression leftIndex => CompareIndex(leftIndex, (IndexExpression)right), + InvocationExpression leftInvocation => CompareInvocation(leftInvocation, (InvocationExpression)right), + LabelExpression leftLabel => CompareLabel(leftLabel, (LabelExpression)right), + LambdaExpression leftLambda => CompareLambda(leftLambda, (LambdaExpression)right), + ListInitExpression leftListInit => CompareListInit(leftListInit, (ListInitExpression)right), + LoopExpression leftLoop => CompareLoop(leftLoop, (LoopExpression)right), + MemberExpression leftMember => CompareMember(leftMember, (MemberExpression)right), + MemberInitExpression leftMemberInit => CompareMemberInit(leftMemberInit, (MemberInitExpression)right), + MethodCallExpression leftMethodCall => CompareMethodCall(leftMethodCall, (MethodCallExpression)right), + NewArrayExpression leftNewArray => CompareNewArray(leftNewArray, (NewArrayExpression)right), + NewExpression leftNew => CompareNew(leftNew, (NewExpression)right), + ParameterExpression leftParameter => CompareParameter(leftParameter, (ParameterExpression)right), + RuntimeVariablesExpression leftRuntimeVariables => CompareRuntimeVariables(leftRuntimeVariables, (RuntimeVariablesExpression)right), + SwitchExpression leftSwitch => CompareSwitch(leftSwitch, (SwitchExpression)right), + TryExpression leftTry => CompareTry(leftTry, (TryExpression)right), + TypeBinaryExpression leftTypeBinary => CompareTypeBinary(leftTypeBinary, (TypeBinaryExpression)right), + UnaryExpression leftUnary => CompareUnary(leftUnary, (UnaryExpression)right), + + _ => left.NodeType == ExpressionType.Extension + ? left.Equals(right) + : throw new InvalidOperationException(CoreStrings.UnhandledExpressionNode(left.NodeType)) + }; } private bool CompareBinary(BinaryExpression a, BinaryExpression b) @@ -440,31 +390,32 @@ private bool CompareLambda(LambdaExpression a, LambdaExpression b) return false; } - // all must have same type + _parameterScope ??= new Dictionary(); + for (var i = 0; i < n; i++) { if (a.Parameters[i].Type != b.Parameters[i].Type) { + for (var j = 0; j < i; j++) + { + _parameterScope.Remove(a.Parameters[j]); + } return false; } - } - var save = _parameterScope; - - _parameterScope = new ScopedDictionary(_parameterScope); + _parameterScope.Add(a.Parameters[i], b.Parameters[i]); + } try { - for (var i = 0; i < n; i++) - { - _parameterScope.Add(a.Parameters[i], b.Parameters[i]); - } - return Compare(a.Body, b.Body); } finally { - _parameterScope = save; + for (var i = 0; i < n; i++) + { + _parameterScope.Remove(a.Parameters[i]); + } } } @@ -733,35 +684,6 @@ private bool CompareCatchBlock(CatchBlock a, CatchBlock b) && Compare(a.Body, b.Body) && Compare(a.Filter, b.Filter) && Compare(a.Variable, b.Variable); - - private sealed class ScopedDictionary - { - private readonly ScopedDictionary _previous; - private readonly Dictionary _map; - - public ScopedDictionary(ScopedDictionary previous) - { - _previous = previous; - _map = new Dictionary(); - } - - public void Add(TKey key, TValue value) => _map.Add(key, value); - - public bool TryGetValue(TKey key, out TValue value) - { - for (var scope = this; scope != null; scope = scope._previous) - { - if (scope._map.TryGetValue(key, out value)) - { - return true; - } - } - - value = default; - - return false; - } - } } } }