Skip to content

Commit

Permalink
Query: Convert Enumerable methods to queryable specified after ToList…
Browse files Browse the repository at this point in the history
…/ToArray/AsEnumerable

- Convert all enumerable methods to queryable unless specified on query parameter (coming from closure)
- Convert ICollection<T>.Count to Queryable.Count<T>()
- Simplify (Queryable/CollectionNav/OwnedCollectionNav).(ToList/ToArray/AsEnumerable).AsQueryable to underlying queryable
- Remove MaterializeCollectionNavigation when ToList/ToArray is called on collection navigation.
- Convert Array.Length to Queryable.Count<T>()

Resolves #19059
Resolves #19060
  • Loading branch information
smitpatel committed Jan 1, 2020
1 parent 447f48e commit 961fbc6
Show file tree
Hide file tree
Showing 15 changed files with 852 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ private static bool ClientSource(Expression expression)
=> expression is ConstantExpression
|| expression is MemberInitExpression
|| expression is NewExpression
|| expression is ParameterExpression;
|| (expression is ParameterExpression parameter
&& parameter.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal));

private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType)
{
Expand All @@ -198,19 +199,8 @@ private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type qu
enumerableType = enumerableType.GetGenericTypeDefinition();
queryableType = queryableType.GetGenericTypeDefinition();

if (enumerableType == typeof(IEnumerable<>)
&& queryableType == typeof(IQueryable<>))
{
return true;
}

if (enumerableType == typeof(IOrderedEnumerable<>)
&& queryableType == typeof(IOrderedQueryable<>))
{
return true;
}

return false;
return enumerableType == typeof(IEnumerable<>) && queryableType == typeof(IQueryable<>)
|| enumerableType == typeof(IOrderedEnumerable<>) && queryableType == typeof(IOrderedQueryable<>);
}
}
}
130 changes: 88 additions & 42 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -150,24 +151,18 @@ protected override Expression VisitMember(MemberExpression memberExpression)

var innerExpression = Visit(memberExpression.Expression);

// Convert CollectionNavigation.Count to subquery.Count()
if (innerExpression is MaterializeCollectionNavigationExpression materializeCollectionNavigation
&& memberExpression.Member.Name == nameof(List<int>.Count))
// Convert ICollection<T>.Count to Count<T>()
if (memberExpression.Expression != null
&& memberExpression.Member.Name == nameof(ICollection<int>.Count)
&& memberExpression.Expression.Type.GetInterfaces().Append(memberExpression.Expression.Type)
.Any(e => e.IsGenericType && e.GetGenericTypeDefinition() == typeof(ICollection<>)))
{
var subquery = materializeCollectionNavigation.Subquery;
var elementType = subquery.Type.TryGetSequenceType();
if (subquery is OwnedNavigationReference ownedNavigationReference
&& ownedNavigationReference.Navigation.IsCollection())
{
subquery = Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(elementType),
subquery);
}
var innerQueryable = UnwrapCollectionMaterialization(innerExpression);

return Visit(
Expression.Call(
QueryableMethods.CountWithoutPredicate.MakeGenericMethod(elementType),
subquery));
QueryableMethods.CountWithoutPredicate.MakeGenericMethod(innerQueryable.Type.TryGetSequenceType()),
innerQueryable));
}

var updatedExpression = (Expression)memberExpression.Update(innerExpression);
Expand Down Expand Up @@ -493,30 +488,7 @@ when QueryableMethods.IsSumWithSelector(method):

if (genericMethod == QueryableMethods.AsQueryable)
{
if (firstArgument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
{
var subquery = materializeCollectionNavigationExpression.Subquery;

return subquery is OwnedNavigationReference innerOwnedNavigationReference
&& innerOwnedNavigationReference.Navigation.IsCollection()
? Visit(
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(subquery.Type.TryGetSequenceType()),
subquery))
: subquery;
}

if (firstArgument is OwnedNavigationReference ownedNavigationReference
&& ownedNavigationReference.Navigation.IsCollection())
{
var parameterName = GetParameterName("o");
var entityReference = ownedNavigationReference.EntityReference;
var currentTree = new NavigationTreeExpression(entityReference);

return new NavigationExpansionExpression(methodCallExpression, currentTree, currentTree, parameterName);
}

return firstArgument;
return UnwrapCollectionMaterialization(firstArgument);
}

if (firstArgument.Type.TryGetElementType(typeof(IQueryable<>)) == null)
Expand All @@ -531,8 +503,10 @@ when QueryableMethods.IsSumWithSelector(method):
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

// Remove MaterializeCollectionNavigationExpression when applying ToList/ToArray
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition() == EnumerableMethods.ToList)
&& (method.GetGenericMethodDefinition() == EnumerableMethods.ToList
|| method.GetGenericMethodDefinition() == EnumerableMethods.ToArray))
{
var argument = Visit(methodCallExpression.Arguments[0]);
if (argument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
Expand All @@ -543,6 +517,21 @@ when QueryableMethods.IsSumWithSelector(method):
return methodCallExpression.Update(null, new[] { argument });
}

// Remove MaterializeCollectionNavigationExpression when applying instance.ToArray for specific collection type
if (typeof(IEnumerable).IsAssignableFrom(method.DeclaringType)
&& string.Equals("ToArray", method.Name, StringComparison.Ordinal))
{
var argument = Visit(methodCallExpression.Object);
if (argument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
{
argument = materializeCollectionNavigationExpression.Subquery;
}

return Expression.Call(
EnumerableMethods.ToArray.MakeGenericMethod(argument.Type.TryGetSequenceType()),
argument);
}

if (method.IsGenericMethod
&& method.Name == "FromSqlOnQueryable"
&& methodCallExpression.Arguments.Count == 3
Expand All @@ -564,6 +553,24 @@ when QueryableMethods.IsSumWithSelector(method):
return ProcessUnknownMethod(methodCallExpression);
}

protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var operand = Visit(unaryExpression.Operand);
// Convert Array.Length to Count()
if (unaryExpression.Operand.Type.IsArray
&& unaryExpression.NodeType == ExpressionType.ArrayLength)
{
var innerQueryable = UnwrapCollectionMaterialization(operand);
// Only if inner is queryable as array properties could also have Length access
if (innerQueryable.Type.TryGetElementType(typeof(IQueryable<>)) is Type elementType)
{
return Visit(Expression.Call(QueryableMethods.CountWithoutPredicate.MakeGenericMethod(elementType), innerQueryable));
}
}

return unaryExpression.Update(operand);
}

private Expression ProcessAllAnyCountLongCount(
NavigationExpansionExpression source, MethodInfo genericMethod, LambdaExpression predicate)
{
Expand Down Expand Up @@ -757,7 +764,7 @@ private NavigationExpansionExpression ProcessGroupBy(
source.Source,
Expression.Quote(keySelector),
Expression.Quote(elementSelector),
Expression.Quote(resultSelector));
Expression.Quote(Visit(resultSelector)));

var navigationTree = new NavigationTreeExpression(Expression.Default(result.Type.TryGetSequenceType()));
var parameterName = GetParameterName("e");
Expand Down Expand Up @@ -1278,7 +1285,7 @@ private MethodCallExpression ConvertToEnumerable(MethodInfo queryableMethod, IEn
var enumerableMethod = EnumerableMethods.GetMinWithSelector(resultType);

enumerableMethod = IsNumericType(resultType)
? enumerableMethod.MakeGenericMethod(resultType)
? enumerableMethod.MakeGenericMethod(genericTypeArguments[0])
: enumerableMethod.MakeGenericMethod(genericTypeArguments);

return Expression.Call(enumerableMethod, enumerableArguments);
Expand Down Expand Up @@ -1306,7 +1313,7 @@ private MethodCallExpression ConvertToEnumerable(MethodInfo queryableMethod, IEn
var enumerableMethod = EnumerableMethods.GetMaxWithSelector(resultType);

enumerableMethod = IsNumericType(resultType)
? enumerableMethod.MakeGenericMethod(resultType)
? enumerableMethod.MakeGenericMethod(genericTypeArguments[0])
: enumerableMethod.MakeGenericMethod(genericTypeArguments);

return Expression.Call(enumerableMethod, enumerableArguments);
Expand Down Expand Up @@ -1376,6 +1383,16 @@ private NavigationExpansionExpression CreateNavigationExpansionExpression(Expres
return new NavigationExpansionExpression(sourceExpression, currentTree, currentTree, parameterName);
}

private NavigationExpansionExpression CreateNavigationExpansionExpression(
Expression sourceExpression, OwnedNavigationReference ownedNavigationReference)
{
var parameterName = GetParameterName("o");
var entityReference = ownedNavigationReference.EntityReference;
var currentTree = new NavigationTreeExpression(entityReference);

return new NavigationExpansionExpression(sourceExpression, currentTree, currentTree, parameterName);
}

private Expression ExpandNavigationsForSource(NavigationExpansionExpression source, Expression expression)
{
expression = new ExpandingExpressionVisitor(this, source).Visit(expression);
Expand Down Expand Up @@ -1412,6 +1429,35 @@ private static IEnumerable<INavigation> FindNavigations(IEntityType entityType,
private LambdaExpression GenerateLambda(Expression body, ParameterExpression currentParameter)
=> Expression.Lambda(Reduce(body), currentParameter);

private Expression UnwrapCollectionMaterialization(Expression expression)
{
if (expression is MethodCallExpression innerMethodCall
&& innerMethodCall.Method.IsGenericMethod)
{
var innerGenericMethod = innerMethodCall.Method.GetGenericMethodDefinition();
if (innerGenericMethod == EnumerableMethods.AsEnumerable
|| innerGenericMethod == EnumerableMethods.ToList
|| innerGenericMethod == EnumerableMethods.ToArray)
{
expression = innerMethodCall.Arguments[0];
}
}

if (expression is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
{
expression = materializeCollectionNavigationExpression.Subquery;
}

return expression is OwnedNavigationReference ownedNavigationReference
&& ownedNavigationReference.Navigation.IsCollection()
? CreateNavigationExpansionExpression(
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(ownedNavigationReference.Type.TryGetSequenceType()),
ownedNavigationReference),
ownedNavigationReference)
: expression;
}

private string GetParameterName(string prefix)
{
var uniqueName = prefix;
Expand Down
11 changes: 4 additions & 7 deletions src/EFCore/Query/QueryableMethodTranslatingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -526,15 +526,12 @@ protected override Expression VisitExtension(Expression node)
{
Check.NotNull(node, nameof(node));

if (node is ProjectionBindingExpression projectionBindingExpression)
{
return new ProjectionBindingExpression(
return node is ProjectionBindingExpression projectionBindingExpression
? new ProjectionBindingExpression(
_queryExpression,
projectionBindingExpression.ProjectionMember.Prepend(_memberShift),
projectionBindingExpression.Type);
}

return base.VisitExtension(node);
projectionBindingExpression.Type)
: base.VisitExtension(node);
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/Shared/EnumerableMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal static class EnumerableMethods
public static MethodInfo Contains { get; }

public static MethodInfo ToList { get; }
public static MethodInfo ToArray { get; }

public static MethodInfo Concat { get; }
public static MethodInfo Except { get; }
Expand Down Expand Up @@ -153,6 +154,8 @@ static EnumerableMethods()

ToList = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.ToList) && mi.GetParameters().Length == 1);
ToArray = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.ToArray) && mi.GetParameters().Length == 1);

Concat = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.Concat) && mi.GetParameters().Length == 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,12 @@ FROM root c
ORDER BY c[""EmployeeID""]");
}

[ConditionalTheory(Skip = "Issue#17246")]
public override Task Projection_AsEnumerable_projection(bool async)
{
return base.Projection_AsEnumerable_projection(async);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1947,6 +1947,102 @@ public override async Task Using_same_parameter_twice_in_query_generates_one_sql
AssertSql(" ");
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_ToList_Count(bool async)
{
return base.Where_Queryable_ToList_Count(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_ToList_Contains(bool async)
{
return base.Where_Queryable_ToList_Contains(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_ToArray_Count(bool async)
{
return base.Where_Queryable_ToArray_Count(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_ToArray_Contains(bool async)
{
return base.Where_Queryable_ToArray_Contains(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_AsEnumerable_Count(bool async)
{
return base.Where_Queryable_AsEnumerable_Count(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_AsEnumerable_Contains(bool async)
{
return base.Where_Queryable_AsEnumerable_Contains(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_ToList_Count_member(bool async)
{
return base.Where_Queryable_ToList_Count_member(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_Queryable_ToArray_Length_member(bool async)
{
return base.Where_Queryable_ToArray_Length_member(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_ToList_Count(bool async)
{
return base.Where_collection_navigation_ToList_Count(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_ToList_Contains(bool async)
{
return base.Where_collection_navigation_ToList_Contains(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_ToArray_Count(bool async)
{
return base.Where_collection_navigation_ToArray_Count(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_ToArray_Contains(bool async)
{
return base.Where_collection_navigation_ToArray_Contains(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_AsEnumerable_Count(bool async)
{
return base.Where_collection_navigation_AsEnumerable_Count(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_AsEnumerable_Contains(bool async)
{
return base.Where_collection_navigation_AsEnumerable_Contains(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_ToList_Count_member(bool async)
{
return base.Where_collection_navigation_ToList_Count_member(async);
}

[ConditionalTheory(Skip = "Issue #17246")]
public override Task Where_collection_navigation_ToArray_Length_member(bool async)
{
return base.Where_collection_navigation_ToArray_Length_member(async);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Loading

0 comments on commit 961fbc6

Please sign in to comment.