Skip to content

Commit

Permalink
Query: Key comparison should use object.Equals internally in query (#…
Browse files Browse the repository at this point in the history
…21742)

And associated changes to support translation.

Resolves #19407
  • Loading branch information
smitpatel committed Jul 22, 2020
1 parent 2cf65e9 commit b083ed3
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 211 deletions.
6 changes: 4 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
&& right != null)
{
return left.Type.UnwrapNullableType() == right.Type.UnwrapNullableType()
? (SqlExpression)_sqlExpressionFactory.Equal(left, right)
: _sqlExpressionFactory.Constant(false);
|| (right.Type == typeof(object) && right is SqlParameterExpression)
|| (left.Type == typeof(object) && left is SqlParameterExpression)
? _sqlExpressionFactory.Equal(left, right)
: (SqlExpression)_sqlExpressionFactory.Constant(false);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
&& binaryExpression.Left is NewArrayExpression
&& binaryExpression.NodeType == ExpressionType.Equal)
{
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression));
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

var newLeft = Visit(binaryExpression.Left);
Expand Down Expand Up @@ -557,6 +557,13 @@ MethodInfo GetMethod()
&& methodCallExpression.Object == null
&& methodCallExpression.Arguments.Count == 2)
{
if (methodCallExpression.Arguments[0].Type == typeof(object[])
&& methodCallExpression.Arguments[0] is NewArrayExpression)
{
return Visit(ConvertObjectArrayEqualityComparison(
methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));
}

var left = Visit(methodCallExpression.Arguments[0]);
var right = Visit(methodCallExpression.Arguments[1]);

Expand Down Expand Up @@ -1262,10 +1269,10 @@ private static bool CanEvaluate(Expression expression)
}
}

private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression)
private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right)
{
var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions;
var rightExpressions = ((NewArrayExpression)binaryExpression.Right).Expressions;
var leftExpressions = ((NewArrayExpression)left).Expressions;
var rightExpressions = ((NewArrayExpression)right).Expressions;

return leftExpressions.Zip(
rightExpressions,
Expand Down
11 changes: 5 additions & 6 deletions src/EFCore.Relational/Query/Internal/EqualsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
if (left != null
&& right != null)
{
if (left.Type == right.Type)
{
return _sqlExpressionFactory.Equal(left, right);
}

return _sqlExpressionFactory.Constant(false);
return left.Type == right.Type
|| (right.Type == typeof(object) && right is SqlParameterExpression)
|| (left.Type == typeof(object) && left is SqlParameterExpression)
? _sqlExpressionFactory.Equal(left, right)
: (SqlExpression)_sqlExpressionFactory.Constant(false);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
&& binaryExpression.Left is NewArrayExpression
&& binaryExpression.NodeType == ExpressionType.Equal)
{
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression));
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
Expand Down Expand Up @@ -624,6 +624,13 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method)
&& methodCallExpression.Object == null
&& methodCallExpression.Arguments.Count == 2)
{
if (methodCallExpression.Arguments[0].Type == typeof(object[])
&& methodCallExpression.Arguments[0] is NewArrayExpression)
{
return Visit(ConvertObjectArrayEqualityComparison(
methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));
}

var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0]));
var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1]));

Expand Down Expand Up @@ -1000,10 +1007,10 @@ private static Expression RemoveObjectConvert(Expression expression)
? unaryExpression.Operand
: expression;

private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression)
private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right)
{
var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions;
var rightExpressions = ((NewArrayExpression)binaryExpression.Right).Expressions;
var leftExpressions = ((NewArrayExpression)left).Expressions;
var rightExpressions = ((NewArrayExpression)right).Expressions;

return leftExpressions.Zip(
rightExpressions,
Expand Down
47 changes: 33 additions & 14 deletions src/EFCore/Internal/EntityFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ namespace Microsoft.EntityFrameworkCore.Internal
public class EntityFinder<TEntity> : IEntityFinder<TEntity>
where TEntity : class
{
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private readonly IStateManager _stateManager;
private readonly IDbSetSource _setSource;
private readonly IDbSetCache _setCache;
Expand Down Expand Up @@ -354,34 +357,50 @@ private static IQueryable<TResult> Select<TSource, TResult>(
parameter));
}

private static BinaryExpression BuildPredicate(
private static Expression BuildPredicate(
IReadOnlyList<IProperty> keyProperties,
ValueBuffer keyValues,
ParameterExpression entityParameter)
{
var keyValuesConstant = Expression.Constant(keyValues);

var predicate = GenerateEqualExpression(keyProperties[0], 0);
var predicate = GenerateEqualExpression(entityParameter, keyValuesConstant, keyProperties[0], 0);

for (var i = 1; i < keyProperties.Count; i++)
{
predicate = Expression.AndAlso(predicate, GenerateEqualExpression(keyProperties[i], i));
predicate = Expression.AndAlso(predicate, GenerateEqualExpression(entityParameter, keyValuesConstant, keyProperties[i], i));
}

return predicate;

BinaryExpression GenerateEqualExpression(IProperty property, int i) =>
Expression.Equal(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(property.ClrType),
entityParameter,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
static Expression GenerateEqualExpression(
Expression entityParameterExpression, Expression keyValuesConstantExpression, IProperty property, int i)
=> property.ClrType.IsValueType
&& property.ClrType.UnwrapNullableType() is Type nonNullableType
&& !(nonNullableType == typeof(bool) || nonNullableType.IsNumeric() || nonNullableType.IsEnum)
? Expression.Call(
_objectEqualsMethodInfo,
Expression.Call(
keyValuesConstant,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
property.ClrType));
EF.PropertyMethod.MakeGenericMethod(typeof(object)),
entityParameterExpression,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
typeof(object)))
: (Expression)Expression.Equal(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(property.ClrType),
entityParameterExpression,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
property.ClrType));
}

private static Expression<Func<object, object[]>> BuildProjection(IEntityType entityType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public partial class NavigationExpandingExpressionVisitor
/// </summary>
private class ExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private readonly NavigationExpandingExpressionVisitor _navigationExpandingExpressionVisitor;
private readonly NavigationExpansionExpression _source;

Expand Down Expand Up @@ -393,7 +396,7 @@ outerKey is NewArrayExpression newArrayExpression
})
.Aggregate((l, r) => Expression.AndAlso(l, r))
: Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)),
Expression.Equal(outerKey, innerKey));
Expression.Call(_objectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey)));

// Caller should take care of wrapping MaterializeCollectionNavigation
return Expression.Call(
Expand Down Expand Up @@ -455,6 +458,11 @@ outerKey is NewArrayExpression newArrayExpression

return innerSource.PendingSelector;
}

static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}

/// <summary>
Expand Down
Loading

0 comments on commit b083ed3

Please sign in to comment.