Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Merge query optimizing expression visitors #20489

Merged
merged 1 commit into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

31 changes: 12 additions & 19 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ protected override Expression VisitMember(MemberExpression memberExpression)
if (innerQueryable.Type.TryGetElementType(typeof(IQueryable<>)) != null)
{
return Visit(
Expression.Call(
QueryableMethods.CountWithoutPredicate.MakeGenericMethod(innerQueryable.Type.TryGetSequenceType()),
innerQueryable));
Expression.Call(
QueryableMethods.CountWithoutPredicate.MakeGenericMethod(innerQueryable.Type.TryGetSequenceType()),
innerQueryable));
}
}

Expand Down Expand Up @@ -528,13 +528,8 @@ when QueryableMethods.IsSumWithSelector(method):
&& (method.GetGenericMethodDefinition() == EnumerableMethods.ToList
|| method.GetGenericMethodDefinition() == EnumerableMethods.ToArray))
{
var argument = Visit(methodCallExpression.Arguments[0]);
if (argument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
{
argument = materializeCollectionNavigationExpression.Subquery;
}

return methodCallExpression.Update(null, new[] { argument });
return methodCallExpression.Update(
null, new[] { UnwrapCollectionMaterialization(Visit(methodCallExpression.Arguments[0])) });
}

return ProcessUnknownMethod(methodCallExpression);
Expand Down Expand Up @@ -1584,16 +1579,14 @@ private LambdaExpression GenerateLambda(Expression body, ParameterExpression cur

private Expression UnwrapCollectionMaterialization(Expression expression)
{
if (expression is MethodCallExpression innerMethodCall
&& innerMethodCall.Method.IsGenericMethod)
while (expression is MethodCallExpression innerMethodCall
&& innerMethodCall.Method.IsGenericMethod
&& innerMethodCall.Method.GetGenericMethodDefinition() is MethodInfo innerMethod
&& (innerMethod == EnumerableMethods.AsEnumerable
|| innerMethod == EnumerableMethods.ToList
|| innerMethod == EnumerableMethods.ToArray))
{
var innerGenericMethod = innerMethodCall.Method.GetGenericMethodDefinition();
if (innerGenericMethod == EnumerableMethods.AsEnumerable
|| innerGenericMethod == EnumerableMethods.ToList
|| innerGenericMethod == EnumerableMethods.ToArray)
{
expression = innerMethodCall.Arguments[0];
}
expression = innerMethodCall.Arguments[0];
}

if (expression is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,51 @@

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class AllAnyContainsRewritingExpressionVisitor : ExpressionVisitor
public class QueryOptimizingExpressionVisitor : ExpressionVisitor
{
private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2)
=> type.IsGenericType
&& type.GetGenericArguments().Length == funcGenericArgs;
private static readonly MethodInfo _stringCompareWithComparisonMethod =
typeof(string).GetRuntimeMethod(nameof(string.Compare), new[] { typeof(string), typeof(string), typeof(StringComparison) });
private static readonly MethodInfo _stringCompareWithoutComparisonMethod =
typeof(string).GetRuntimeMethod(nameof(string.Compare), new[] { typeof(string), typeof(string) });
private static readonly MethodInfo _startsWithMethodInfo =
typeof(string).GetRuntimeMethod(nameof(string.StartsWith), new[] { typeof(string) });
private static readonly MethodInfo _endsWithMethodInfo =
typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) });

private static readonly Expression _constantNullString = Expression.Constant(null, typeof(string));

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
Check.NotNull(methodCallExpression, nameof(methodCallExpression));

if (_startsWithMethodInfo.Equals(methodCallExpression.Method)
|| _endsWithMethodInfo.Equals(methodCallExpression.Method))
{
if (methodCallExpression.Arguments[0] is ConstantExpression constantArgument
&& (string)constantArgument.Value == string.Empty)
{
// every string starts/ends with empty string.
return Expression.Constant(true);
}

var newObject = Visit(methodCallExpression.Object);
var newArgument = Visit(methodCallExpression.Arguments[0]);

var result = Expression.AndAlso(
Expression.NotEqual(newObject, _constantNullString),
Expression.AndAlso(
Expression.NotEqual(newArgument, _constantNullString),
methodCallExpression.Update(newObject, new[] { newArgument })));

return newArgument is ConstantExpression
? result
: Expression.OrElse(
Expression.Equal(
newArgument,
Expression.Constant(string.Empty)),
result);
}

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo methodInfo
&& (methodInfo.Equals(EnumerableMethods.AnyWithPredicate) || methodInfo.Equals(EnumerableMethods.All))
Expand Down Expand Up @@ -46,9 +81,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo containsMethodInfo
&& containsMethodInfo.Equals(QueryableMethods.Contains)
// special case Queryable.Contains(byte_array, byte) - we don't want those to be rewritten
&& methodCallExpression.Arguments[1].Type != typeof(byte))
&& containsMethodInfo.Equals(QueryableMethods.Contains))
{
var typeArgument = methodCallExpression.Method.GetGenericArguments()[0];
var anyMethod = QueryableMethods.AnyWithPredicate.MakeGenericMethod(typeArgument);
Expand All @@ -63,7 +96,67 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return Expression.Call(null, anyMethod, new[] { methodCallExpression.Arguments[0], anyLambda });
}

return base.VisitMethodCall(methodCallExpression);
var visited = (MethodCallExpression)base.VisitMethodCall(methodCallExpression);

// In VB.NET, comparison operators between strings (equality, greater-than, less-than) yield
// calls to a VB-specific CompareString method. Normalize that to string.Compare.
if (visited.Method.Name == "CompareString"
&& visited.Method.DeclaringType?.Name == "Operators"
&& visited.Method.DeclaringType?.Namespace == "Microsoft.VisualBasic.CompilerServices"
&& visited.Object == null
&& visited.Arguments.Count == 3
&& visited.Arguments[2] is ConstantExpression textCompareConstantExpression)
{
return (bool)textCompareConstantExpression.Value
? Expression.Call(
_stringCompareWithComparisonMethod,
visited.Arguments[0],
visited.Arguments[1],
Expression.Constant(StringComparison.OrdinalIgnoreCase))
: Expression.Call(
_stringCompareWithoutComparisonMethod,
visited.Arguments[0],
visited.Arguments[1]);
}

return visited;
}

protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
Check.NotNull(unaryExpression, nameof(unaryExpression));

if (unaryExpression.NodeType == ExpressionType.Not
&& unaryExpression.Operand is MethodCallExpression innerMethodCall
&& (_startsWithMethodInfo.Equals(innerMethodCall.Method)
|| _endsWithMethodInfo.Equals(innerMethodCall.Method)))
{
if (innerMethodCall.Arguments[0] is ConstantExpression constantArgument
&& (string)constantArgument.Value == string.Empty)
{
// every string starts/ends with empty string.
return Expression.Constant(false);
}

var newObject = Visit(innerMethodCall.Object);
var newArgument = Visit(innerMethodCall.Arguments[0]);

var result = Expression.AndAlso(
Expression.NotEqual(newObject, _constantNullString),
Expression.AndAlso(
Expression.NotEqual(newArgument, _constantNullString),
Expression.Not(innerMethodCall.Update(newObject, new[] { newArgument }))));

return newArgument is ConstantExpression
? result
: Expression.AndAlso(
Expression.NotEqual(
newArgument,
Expression.Constant(string.Empty)),
result);
}

return base.VisitUnary(unaryExpression);
}

private bool TryExtractEqualityOperands(Expression expression, out Expression left, out Expression right, out bool negated)
Expand Down
54 changes: 0 additions & 54 deletions src/EFCore/Query/Internal/VBToCSharpConvertingExpressionVisitor.cs

This file was deleted.

6 changes: 1 addition & 5 deletions src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,12 @@ public virtual Expression Process([NotNull] Expression query)
Check.NotNull(query, nameof(query));

query = new InvocationExpressionRemovingExpressionVisitor().Visit(query);

query = NormalizeQueryableMethodCall(query);

query = new VBToCSharpConvertingExpressionVisitor().Visit(query);
query = new AllAnyContainsRewritingExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new SubqueryMemberPushdownExpressionVisitor(QueryCompilationContext.Model).Visit(query);
query = new NavigationExpandingExpressionVisitor(this, QueryCompilationContext, Dependencies.EvaluatableExpressionFilter)
.Expand(query);
query = new FunctionPreprocessingExpressionVisitor().Visit(query);
query = new QueryOptimizingExpressionVisitor().Visit(query);

return query;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2176,7 +2176,7 @@ public virtual Task Where_collection_navigation_ToArray_Count(bool async)
elementAsserter: (e, a) => AssertCollection(e, a));
}

[ConditionalTheory(Skip = "Issue#19433")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_collection_navigation_ToArray_Contains(bool async)
{
Expand All @@ -2185,7 +2185,7 @@ public virtual Task Where_collection_navigation_ToArray_Contains(bool async)
return AssertQuery(
async,
ss => ss.Set<Customer>()
.Select(c => c.Orders.ToArray())
.Select(c => c.Orders.AsEnumerable().ToArray())
.Where(e => e.Contains(order)),
entryCount: 5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2190,9 +2190,13 @@ public override async Task Contains_with_subquery_optional_navigation_and_consta
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id]
WHERE EXISTS (
SELECT DISTINCT 1
FROM [LevelThree] AS [l1]
WHERE ([l0].[Id] IS NOT NULL AND ([l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id])) AND ([l1].[Id] = 1))");
SELECT 1
FROM (
SELECT DISTINCT [l1].[Id], [l1].[Level2_Optional_Id], [l1].[Level2_Required_Id], [l1].[Name], [l1].[OneToMany_Optional_Inverse3Id], [l1].[OneToMany_Optional_Self_Inverse3Id], [l1].[OneToMany_Required_Inverse3Id], [l1].[OneToMany_Required_Self_Inverse3Id], [l1].[OneToOne_Optional_PK_Inverse3Id], [l1].[OneToOne_Optional_Self3Id]
FROM [LevelThree] AS [l1]
WHERE [l0].[Id] IS NOT NULL AND ([l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id])
) AS [t]
WHERE [t].[Id] = 1)");
}

public override async Task Contains_with_subquery_optional_navigation_scalar_distinct_and_constant_item(bool async)
Expand Down
Loading