Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Remove object convert for Contains #20628

Merged
merged 1 commit into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ public class CosmosSqlTranslatingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo _parameterListValueExtractor =
typeof(CosmosSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor));

private static readonly MethodInfo _concatMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Concat), new[] { typeof(object), typeof(object) });
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should move to a shared internal class to avoid ad-hoc proliferation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #20631


private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -128,6 +131,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return result;
}

if (binaryExpression.Method == _concatMethodInfo)
{
return null;
}

var uncheckedNodeTypeVariant = binaryExpression.NodeType switch
{
ExpressionType.AddChecked => ExpressionType.Add,
Expand Down Expand Up @@ -474,11 +482,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)

case ExpressionType.Convert:
case ExpressionType.ConvertChecked:
// Object convert needs to be converted to explicit cast when mismatching types
if (operand.Type.IsInterface
&& unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
|| unaryExpression.Type.UnwrapNullableType() == operand.Type
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum))
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum)
// Object convert needs to be converted to explicit cast when mismatching types
// But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
|| unaryExpression.Type == typeof(object))
{
return sqlOperand;
}
Expand Down
12 changes: 10 additions & 2 deletions src/EFCore.Relational/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
Expand All @@ -27,20 +28,27 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& ValidateValues(arguments[0]))
{
return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false);
return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[1]), arguments[0], negated: false);
}

if (arguments.Count == 1
&& method.IsContainsMethod()
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, negated: false);
return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[0]), instance, negated: false);
}

return null;
}

private bool ValidateValues(SqlExpression values)
=> values is SqlConstantExpression || values is SqlParameterExpression;

private SqlExpression RemoveObjectConvert(SqlExpression expression)
=> expression is SqlUnaryExpression sqlUnaryExpression
&& sqlUnaryExpression.OperatorType == ExpressionType.Convert
&& sqlUnaryExpression.Type == typeof(object)
? sqlUnaryExpression.Operand
: expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,26 @@ public override Task Where_Queryable_AsEnumerable_Contains_negated(bool async)
return base.Where_Queryable_AsEnumerable_Contains_negated(async);
}

public override async Task Where_list_object_contains_over_value_type(bool async)
{
await base.Where_list_object_contains_over_value_type(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

public override async Task Where_array_of_object_contains_over_value_type(bool async)
{
await base.Where_array_of_object_contains_over_value_type(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2248,5 +2248,29 @@ public virtual Task Where_collection_navigation_ToArray_Length_member(bool async
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_list_object_contains_over_value_type(bool async)
{
var orderIds = new List<object> { 10248, 10249 };
return AssertQuery(
async,
ss => ss.Set<Order>()
.Where(o => orderIds.Contains(o.OrderID)),
entryCount: 2);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_array_of_object_contains_over_value_type(bool async)
{
var orderIds = new object[] { 10248, 10249 };
return AssertQuery(
async,
ss => ss.Set<Order>()
.Where(o => orderIds.Contains(o.OrderID)),
entryCount: 2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2041,6 +2041,26 @@ FROM [Order Details] AS [o1]
ORDER BY [o].[OrderID], [o0].[OrderID], [o0].[ProductID]");
}

public override async Task Where_list_object_contains_over_value_type(bool async)
{
await base.Where_list_object_contains_over_value_type(async);

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

public override async Task Where_array_of_object_contains_over_value_type(bool async)
{
await base.Where_array_of_object_contains_over_value_type(async);

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down