From 94eb47659dac30691534a1b6aa5cea10e19384e2 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 14 Apr 2020 12:29:09 -0700 Subject: [PATCH] Query: Remove object convert for Contains Resolves #20624 --- .../CosmosSqlTranslatingExpressionVisitor.cs | 14 +++++++++-- .../Query/Internal/ContainsTranslator.cs | 12 ++++++++-- .../Query/NorthwindWhereQueryCosmosTest.cs | 20 ++++++++++++++++ .../Query/NorthwindWhereQueryTestBase.cs | 24 +++++++++++++++++++ .../Query/NorthwindWhereQuerySqlServerTest.cs | 20 ++++++++++++++++ 5 files changed, 86 insertions(+), 4 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 837d28d4460..34bb6b21c1c 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -33,6 +33,9 @@ public class CosmosSqlTranslatingExpressionVisitor : ExpressionVisitor private static readonly MethodInfo _parameterListValueExtractor = typeof(CosmosSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor)); + private static readonly MethodInfo _concatMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.Concat), new[] { typeof(object), typeof(object) }); + private readonly QueryCompilationContext _queryCompilationContext; private readonly IModel _model; private readonly ISqlExpressionFactory _sqlExpressionFactory; @@ -128,6 +131,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return result; } + if (binaryExpression.Method == _concatMethodInfo) + { + return null; + } + var uncheckedNodeTypeVariant = binaryExpression.NodeType switch { ExpressionType.AddChecked => ExpressionType.Add, @@ -474,11 +482,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) case ExpressionType.Convert: case ExpressionType.ConvertChecked: - // Object convert needs to be converted to explicit cast when mismatching types if (operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) || unaryExpression.Type.UnwrapNullableType() == operand.Type - || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) + || unaryExpression.Type.UnwrapNullableType() == typeof(Enum) + // Object convert needs to be converted to explicit cast when mismatching types + // But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types + || unaryExpression.Type == typeof(object)) { return sqlOperand; } diff --git a/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs b/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs index 68c8b9a8c28..dddbb8b28f9 100644 --- a/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; +using System.Linq.Expressions; using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -27,14 +28,14 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains) && ValidateValues(arguments[0])) { - return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false); + return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[1]), arguments[0], negated: false); } if (arguments.Count == 1 && method.IsContainsMethod() && ValidateValues(instance)) { - return _sqlExpressionFactory.In(arguments[0], instance, negated: false); + return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[0]), instance, negated: false); } return null; @@ -42,5 +43,12 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method private bool ValidateValues(SqlExpression values) => values is SqlConstantExpression || values is SqlParameterExpression; + + private SqlExpression RemoveObjectConvert(SqlExpression expression) + => expression is SqlUnaryExpression sqlUnaryExpression + && sqlUnaryExpression.OperatorType == ExpressionType.Convert + && sqlUnaryExpression.Type == typeof(object) + ? sqlUnaryExpression.Operand + : expression; } } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs index da476fbc348..5e90bca8708 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs @@ -2059,6 +2059,26 @@ public override Task Where_Queryable_AsEnumerable_Contains_negated(bool async) return base.Where_Queryable_AsEnumerable_Contains_negated(async); } + public override async Task Where_list_object_contains_over_value_type(bool async) + { + await base.Where_list_object_contains_over_value_type(async); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))"); + } + + public override async Task Where_array_of_object_contains_over_value_type(bool async) + { + await base.Where_array_of_object_contains_over_value_type(async); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs index bb8e5fde3e4..7029721de67 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs @@ -2248,5 +2248,29 @@ public virtual Task Where_collection_navigation_ToArray_Length_member(bool async assertOrder: true, elementAsserter: (e, a) => AssertCollection(e, a)); } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Where_list_object_contains_over_value_type(bool async) + { + var orderIds = new List { 10248, 10249 }; + return AssertQuery( + async, + ss => ss.Set() + .Where(o => orderIds.Contains(o.OrderID)), + entryCount: 2); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Where_array_of_object_contains_over_value_type(bool async) + { + var orderIds = new object[] { 10248, 10249 }; + return AssertQuery( + async, + ss => ss.Set() + .Where(o => orderIds.Contains(o.OrderID)), + entryCount: 2); + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs index 10f220cc74d..760912f2986 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs @@ -2041,6 +2041,26 @@ FROM [Order Details] AS [o1] ORDER BY [o].[OrderID], [o0].[OrderID], [o0].[ProductID]"); } + public override async Task Where_list_object_contains_over_value_type(bool async) + { + await base.Where_list_object_contains_over_value_type(async); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +WHERE [o].[OrderID] IN (10248, 10249)"); + } + + public override async Task Where_array_of_object_contains_over_value_type(bool async) + { + await base.Where_array_of_object_contains_over_value_type(async); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +WHERE [o].[OrderID] IN (10248, 10249)"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);