From bb7edb6e005fe497c004177f76bba7e03397b4d1 Mon Sep 17 00:00:00 2001 From: Paul Middleton Date: Fri, 20 Sep 2019 15:14:39 -0500 Subject: [PATCH] Add support for querable functions --- .../RelationalEntityTypeBuilderExtensions.cs | 31 + .../RelationalEntityTypeExtensions.cs | 3 + ...ntityFrameworkRelationalServicesBuilder.cs | 1 + .../RelationalModelValidator.cs | 12 +- src/EFCore.Relational/Metadata/IDbFunction.cs | 5 + .../Metadata/Internal/DbFunction.cs | 44 +- .../Metadata/RelationalAnnotationNames.cs | 5 + .../Properties/RelationalStrings.Designer.cs | 14 + .../Properties/RelationalStrings.resx | 6 + .../Query/ISqlExpressionFactory.cs | 1 + ...NullSemanticsRewritingExpressionVisitor.cs | 7 + ...nalNavigationExpandingExpressionVisitor.cs | 27 + ...gationExpandingExpressionVisitorFactory.cs | 14 + ...ionalProjectionBindingExpressionVisitor.cs | 11 +- ...onalQueryTranslationPreprocessorFactory.cs | 7 +- ...lityBasedSqlProcessingExpressionVisitor.cs | 3 + .../Query/QuerySqlGenerator.cs | 11 + .../RelationalEvaluatableExpressionFilter.cs | 18 +- .../RelationalQueryTranslationPreprocessor.cs | 5 +- ...yableMethodTranslatingExpressionVisitor.cs | 21 +- ...lationalSqlTranslatingExpressionVisitor.cs | 53 +- .../Query/SqlExpressionFactory.cs | 8 + .../Query/SqlExpressionVisitor.cs | 4 + .../QuerableSqlFunctionExpression.cs | 45 + .../Query/SqlExpressions/SelectExpression.cs | 15 + ...rchConditionConvertingExpressionVisitor.cs | 7 + src/EFCore/DbContext.cs | 17 + .../EntityFrameworkServicesBuilder.cs | 5 +- .../Query/EvaluatableExpressionFilter.cs | 2 + .../Query/IEvaluatableExpressionFilter.cs | 8 + ...gationExpandingExpressionVisitorFactory.cs | 14 + .../NavigationExpandingExpressionVisitor.cs | 4 +- ...gationExpandingExpressionVisitorFactory.cs | 18 + .../ParameterExtractingExpressionVisitor.cs | 30 +- .../QueryTranslationPreprocessorFactory.cs | 7 +- .../Query/QueryTranslationPreprocessor.cs | 9 +- src/EFCore/Query/ShapedQueryExpression.cs | 2 +- .../Query/UdfDbFunctionTestBase.cs | 960 +++++++++++++++++- .../Metadata/DbFunctionMetadataTests.cs | 60 ++ .../Internal/MigrationsModelDifferTest.cs | 9 + .../NorthwindMiscellaneousQueryTestBase.cs | 26 +- .../Query/UdfDbFunctionSqlServerTests.cs | 410 +++++++- 42 files changed, 1827 insertions(+), 132 deletions(-) create mode 100644 src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs create mode 100644 src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs create mode 100644 src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs create mode 100644 src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs create mode 100644 src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs diff --git a/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs b/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs index c7524b9a479..d6223ac8c20 100644 --- a/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs @@ -264,6 +264,37 @@ public static bool CanSetSchema( return entityTypeBuilder.CanSetAnnotation(RelationalAnnotationNames.Schema, schema, fromDataAnnotation); } + /// + /// Configures the entity as a result of a queryable function. Prevents a table from being created for this entity. + /// + /// The builder for the entity type being configured. + /// The same builder instance so that multiple calls can be chained. + public static EntityTypeBuilder ToQueryableFunctionResultType( + [NotNull] this EntityTypeBuilder entityTypeBuilder) + { + Check.NotNull(entityTypeBuilder, nameof(entityTypeBuilder)); + + entityTypeBuilder.Metadata.SetAnnotation(RelationalAnnotationNames.QueryableFunctionResultType, null); + + return entityTypeBuilder; + } + + /// + /// Configures the entity as a result of a queryable function. Prevents a table from being created for this entity. + /// + /// The builder for the entity type being configured. + /// The same builder instance so that multiple calls can be chained. + public static EntityTypeBuilder ToQueryableFunctionResultType( + [NotNull] this EntityTypeBuilder entityTypeBuilder) + where TEntity : class + { + Check.NotNull(entityTypeBuilder, nameof(entityTypeBuilder)); + + entityTypeBuilder.Metadata.SetAnnotation(RelationalAnnotationNames.QueryableFunctionResultType, null); + + return entityTypeBuilder; + } + /// /// Configures the view that the entity type maps to when targeting a relational database. /// diff --git a/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs b/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs index 45c32cf6d4c..a9300130df3 100644 --- a/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs @@ -292,6 +292,9 @@ public static bool IsIgnoredByMigrations([NotNull] this IEntityType entityType) return true; } + if (entityType.FindAnnotation(RelationalAnnotationNames.QueryableFunctionResultType) != null) + return true; + var viewDefinition = entityType.FindAnnotation(RelationalAnnotationNames.ViewDefinition); if (viewDefinition == null) { diff --git a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs index e70312a710f..132edc53ae8 100644 --- a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs +++ b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs @@ -169,6 +169,7 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(); TryAdd(); TryAdd(); + TryAdd(); ServiceCollectionMap.GetInfrastructure() .AddDependencySingleton() diff --git a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs index d00c17a598b..964e33baffe 100644 --- a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs +++ b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs @@ -80,7 +80,17 @@ protected virtual void ValidateDbFunctions( RelationalStrings.DbFunctionNameEmpty(methodInfo.DisplayName())); } - if (dbFunction.TypeMapping == null) + if (dbFunction.IsIQueryable) + { + if(model.FindEntityType(dbFunction.MethodInfo.ReturnType.GetGenericArguments()[0]) == null) + { + throw new InvalidOperationException( + RelationalStrings.DbFunctionInvalidReturnType( + methodInfo.DisplayName(), + methodInfo.ReturnType.ShortDisplayName())); + } + } + else if (dbFunction.TypeMapping == null) { throw new InvalidOperationException( RelationalStrings.DbFunctionInvalidReturnType( diff --git a/src/EFCore.Relational/Metadata/IDbFunction.cs b/src/EFCore.Relational/Metadata/IDbFunction.cs index af3d76d90d2..8b129911289 100644 --- a/src/EFCore.Relational/Metadata/IDbFunction.cs +++ b/src/EFCore.Relational/Metadata/IDbFunction.cs @@ -34,6 +34,11 @@ public interface IDbFunction /// MethodInfo MethodInfo { get; } + /// + /// Whether this method returns IQueryable + /// + bool IsIQueryable { get; } + /// /// The configured store type string /// diff --git a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs index 6818016b6df..cd0d3d6267d 100644 --- a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs +++ b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; @@ -75,6 +76,19 @@ public DbFunction( methodInfo.DisplayName(), methodInfo.ReturnType.ShortDisplayName())); } + if (methodInfo.ReturnType.IsGenericType + && methodInfo.ReturnType.GetGenericTypeDefinition() == typeof(IQueryable<>)) + { + IsIQueryable = true; + + //todo - if the generic argument is not usuable as an entitytype should we throw here? IE IQueryable + //the built in entitytype will throw is the type is not a class + if (model.FindEntityType(methodInfo.ReturnType.GetGenericArguments()[0]) == null) + { + model.AddEntityType(methodInfo.ReturnType.GetGenericArguments()[0]).SetAnnotation(RelationalAnnotationNames.QueryableFunctionResultType, null); + } + } + MethodInfo = methodInfo; _model = model; @@ -310,6 +324,14 @@ public virtual Func, SqlExpression> Translati set => SetTranslation(value, ConfigurationSource.Explicit); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual bool IsIQueryable { get; } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -345,7 +367,27 @@ private void UpdateTranslationConfigurationSource(ConfigurationSource configurat public static DbFunction FindDbFunction( [NotNull] IModel model, [NotNull] MethodInfo methodInfo) - => model[BuildAnnotationName(methodInfo)] as DbFunction; + { + var dbFunction = model[BuildAnnotationName(methodInfo)] as DbFunction; + + if(dbFunction == null + && methodInfo.GetParameters().Any(p => p.ParameterType.IsGenericType && p.ParameterType.GetGenericTypeDefinition() == typeof(Expression<>))) + { + var parameters = methodInfo.GetParameters().Select(p => p.ParameterType.IsGenericType + && p.ParameterType.GetGenericTypeDefinition() == typeof(Expression<>) + && p.ParameterType.GetGenericArguments()[0].GetGenericTypeDefinition() == typeof(Func<>) + ? p.ParameterType.GetGenericArguments()[0].GetGenericArguments()[0] + : p.ParameterType).ToArray(); + + var nonExpressionMethod = methodInfo.DeclaringType.GetMethod(methodInfo.Name, parameters); + + dbFunction = nonExpressionMethod != null + ? model[BuildAnnotationName(nonExpressionMethod)] as DbFunction + : null; + } + + return dbFunction; + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs b/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs index fafdb4f054c..4d712d842a4 100644 --- a/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs +++ b/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs @@ -98,5 +98,10 @@ public static class RelationalAnnotationNames /// The definition of a database view. /// public const string ViewDefinition = Prefix + "ViewDefinition"; + + /// + /// The definition of a Queryable Function Result Type. + /// + public const string QueryableFunctionResultType = Prefix + "QueryableFunctionResultType"; } } diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs index 316be0dac26..18e29b83574 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs +++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs @@ -488,6 +488,20 @@ public static string DbFunctionInvalidInstanceType([CanBeNull] object function, GetString("DbFunctionInvalidInstanceType", nameof(function), nameof(type)), function, type); + /// + /// Queryable Db Functions used in projections cannot return IQueryable. IQueryable must be converted to a collection type such as List or Array. + /// + public static string DbFunctionCantProjectIQueryable() + => GetString("DbFunctionCantProjectIQueryable"); + + /// + /// Return type of a queryable function '{functionName}' which is used in a projected collection must define a primary key. + /// + public static string DbFunctionProjectedCollectionMustHavePK([CanBeNull] string functionName) + => string.Format( + GetString("DbFunctionProjectedCollectionMustHavePK", nameof(functionName)), + functionName); + /// /// An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection. /// diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx index df2095d5dc5..dfe238ab682 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.resx +++ b/src/EFCore.Relational/Properties/RelationalStrings.resx @@ -435,6 +435,12 @@ The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported. + + Queryable Db Functions used in projections cannot return IQueryable. IQueryable must be converted to a collection type such as List or Array. + + + Return type of a queryable function '{functionName}' which is used in a projected collection must define a primary key. + An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection. diff --git a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs index 955e4857bbe..2b3e19bbea1 100644 --- a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs @@ -141,5 +141,6 @@ SqlFunctionExpression Function( SelectExpression Select([CanBeNull] SqlExpression projection); SelectExpression Select([NotNull] IEntityType entityType); SelectExpression Select([NotNull] IEntityType entityType, [NotNull] string sql, [NotNull] Expression sqlArguments); + SelectExpression Select([NotNull] IEntityType entityType, [NotNull] SqlFunctionExpression expression); } } diff --git a/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs index 69bf5410321..30947b4e134 100644 --- a/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs @@ -656,6 +656,13 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction return sqlFunctionExpression.Update(newInstance, newArguments); } + protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression) + { + Check.NotNull(queryableFunctionExpression, nameof(queryableFunctionExpression)); + + return queryableFunctionExpression; + } + protected override Expression VisitSqlParameter(SqlParameterExpression sqlParameterExpression) { Check.NotNull(sqlParameterExpression, nameof(sqlParameterExpression)); diff --git a/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs new file mode 100644 index 00000000000..25cd5833908 --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class RelationalNavigationExpandingExpressionVisitor : NavigationExpandingExpressionVisitor + { + public RelationalNavigationExpandingExpressionVisitor( + [NotNull] QueryCompilationContext queryCompilationContext, + [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter) + : base(queryCompilationContext, evaluatableExpressionFilter) + { + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + var dbFunction = QueryCompilationContext.Model.FindDbFunction(methodCallExpression.Method); + + return dbFunction?.IsIQueryable == true + ? CreateNavigationExpansionExpression(methodCallExpression, QueryCompilationContext.Model.FindEntityType(dbFunction.MethodInfo.ReturnType.GetGenericArguments()[0])) + : base.VisitMethodCall(methodCallExpression); + } + } +} diff --git a/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs new file mode 100644 index 00000000000..6f2ffd343e9 --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class RelationalNavigationExpandingExpressionVisitorFactory : INavigationExpandingExpressionVisitorFactory + { + public virtual NavigationExpandingExpressionVisitor Create( + QueryCompilationContext queryCompilationContext, IEvaluatableExpressionFilter evaluatableExpressionFilter) + { + return new RelationalNavigationExpandingExpressionVisitor(queryCompilationContext, evaluatableExpressionFilter); + } + } +} diff --git a/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs index 694ea4513c1..560ff880095 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs @@ -9,6 +9,7 @@ using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; @@ -22,6 +23,7 @@ public class RelationalProjectionBindingExpressionVisitor : ExpressionVisitor private SelectExpression _selectExpression; private bool _clientEval; + private readonly IModel _model; private readonly IDictionary _projectionMapping = new Dictionary(); @@ -30,10 +32,12 @@ private readonly IDictionary _projectionMapping public RelationalProjectionBindingExpressionVisitor( [NotNull] RelationalQueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor, - [NotNull] RelationalSqlTranslatingExpressionVisitor sqlTranslatingExpressionVisitor) + [NotNull] RelationalSqlTranslatingExpressionVisitor sqlTranslatingExpressionVisitor, + [NotNull] IModel model) { _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; _sqlTranslator = sqlTranslatingExpressionVisitor; + _model = model; } public virtual Expression Translate([NotNull] SelectExpression selectExpression, [NotNull] Expression expression) @@ -242,6 +246,11 @@ protected override Expression VisitNew(NewExpression newExpression) return null; } + if (newExpression.Arguments.Any(arg => arg is MethodCallExpression methodCallExp && _model.FindDbFunction(methodCallExp.Method)?.IsIQueryable == true)) + { + throw new InvalidOperationException(RelationalStrings.DbFunctionCantProjectIQueryable()); + } + var newArguments = new Expression[newExpression.Arguments.Count]; for (var i = 0; i < newArguments.Length; i++) { diff --git a/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs index 8ba74997331..a6e026532ab 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs @@ -21,20 +21,23 @@ public class RelationalQueryTranslationPreprocessorFactory : IQueryTranslationPr { private readonly QueryTranslationPreprocessorDependencies _dependencies; private readonly RelationalQueryTranslationPreprocessorDependencies _relationalDependencies; + private readonly INavigationExpandingExpressionVisitorFactory _navigationExpandingExpressionVisitorFactory; public RelationalQueryTranslationPreprocessorFactory( [NotNull] QueryTranslationPreprocessorDependencies dependencies, - [NotNull] RelationalQueryTranslationPreprocessorDependencies relationalDependencies) + [NotNull] RelationalQueryTranslationPreprocessorDependencies relationalDependencies, + [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory) { _dependencies = dependencies; _relationalDependencies = relationalDependencies; + _navigationExpandingExpressionVisitorFactory = navigationExpandingExpressionVisitorFactory; } public virtual QueryTranslationPreprocessor Create(QueryCompilationContext queryCompilationContext) { Check.NotNull(queryCompilationContext, nameof(queryCompilationContext)); - return new RelationalQueryTranslationPreprocessor(_dependencies, _relationalDependencies, queryCompilationContext); + return new RelationalQueryTranslationPreprocessor(_dependencies, _relationalDependencies, queryCompilationContext, _navigationExpandingExpressionVisitorFactory); } } } diff --git a/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs b/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs index 0c5e3a5aa42..6d5bec985d1 100644 --- a/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs @@ -371,6 +371,9 @@ protected override Expression VisitProjection(ProjectionExpression projectionExp VisitInternal(projectionExpression.Expression).ResultExpression); } + protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression) + => Check.NotNull(queryableFunctionExpression, nameof(queryableFunctionExpression)); + protected override Expression VisitRowNumber(RowNumberExpression rowNumberExpression) { Check.NotNull(rowNumberExpression, nameof(rowNumberExpression)); diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 5e9f7314ffe..04e9013e0f3 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -270,6 +270,17 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction return sqlFunctionExpression; } + protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression) + { + Visit(queryableFunctionExpression.SqlFunctionExpression); + + _relationalCommandBuilder + .Append(AliasSeparator) + .Append(_sqlGenerationHelper.DelimitIdentifier(queryableFunctionExpression.Alias)); + + return queryableFunctionExpression; + } + protected override Expression VisitColumn(ColumnExpression columnExpression) { Check.NotNull(columnExpression, nameof(columnExpression)); diff --git a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs index 786419d6ab5..2aa9b4f7b42 100644 --- a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs +++ b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs @@ -47,11 +47,11 @@ public RelationalEvaluatableExpressionFilter( /// protected virtual RelationalEvaluatableExpressionFilterDependencies RelationalDependencies { get; } - /// - /// Checks whether the given expression can be evaluated. - /// - /// The expression. - /// The model. + /// + /// Checks whether the given expression can be evaluated. + /// + /// The expression. + /// The model. /// True if the expression can be evaluated; false otherwise. public override bool IsEvaluatableExpression(Expression expression, IModel model) { @@ -59,12 +59,16 @@ public override bool IsEvaluatableExpression(Expression expression, IModel model Check.NotNull(model, nameof(model)); if (expression is MethodCallExpression methodCallExpression - && model.FindDbFunction(methodCallExpression.Method) != null) + && model.FindDbFunction(methodCallExpression.Method) is IDbFunction func) { - return false; + return func?.IsIQueryable ?? true; } return base.IsEvaluatableExpression(expression, model); } + + public override bool IsQueryableFunction(Expression expression, IModel model) => + expression is MethodCallExpression methodCallExpression + && model.FindDbFunction(methodCallExpression.Method)?.IsIQueryable == true; } } diff --git a/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs b/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs index 721660089da..f0f441f5fcf 100644 --- a/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs @@ -11,8 +11,9 @@ public class RelationalQueryTranslationPreprocessor : QueryTranslationPreprocess public RelationalQueryTranslationPreprocessor( [NotNull] QueryTranslationPreprocessorDependencies dependencies, [NotNull] RelationalQueryTranslationPreprocessorDependencies relationalDependencies, - [NotNull] QueryCompilationContext queryCompilationContext) - : base(dependencies, queryCompilationContext) + [NotNull] QueryCompilationContext queryCompilationContext, + [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory) + : base(dependencies, queryCompilationContext, navigationExpandingExpressionVisitorFactory) { Check.NotNull(relationalDependencies, nameof(relationalDependencies)); diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 04e2820db79..f6435f98779 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -42,7 +42,7 @@ public RelationalQueryableMethodTranslatingExpressionVisitor( var sqlExpressionFactory = relationalDependencies.SqlExpressionFactory; _sqlTranslator = relationalDependencies.RelationalSqlTranslatingExpressionVisitorFactory.Create(model, this); _weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_sqlTranslator, sqlExpressionFactory); - _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator); + _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator, model); _model = model; _sqlExpressionFactory = sqlExpressionFactory; _subquery = false; @@ -58,7 +58,7 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor( _model = parentVisitor._model; _sqlTranslator = parentVisitor._sqlTranslator; _weakEntityExpandingExpressionVisitor = parentVisitor._weakEntityExpandingExpressionVisitor; - _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator); + _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator, _model); _sqlExpressionFactory = parentVisitor._sqlExpressionFactory; _subquery = true; } @@ -75,12 +75,29 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return CreateShapedQueryExpression(queryable.ElementType, sql, methodCallExpression.Arguments[2]); } + var dbFunction = this._model.FindDbFunction(methodCallExpression.Method); + if (dbFunction != null && dbFunction.IsIQueryable) + { + return CreateShapedQueryExpression(methodCallExpression); + } + return base.VisitMethodCall(methodCallExpression); } protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVisitor() => new RelationalQueryableMethodTranslatingExpressionVisitor(this); + protected virtual ShapedQueryExpression CreateShapedQueryExpression([NotNull] MethodCallExpression methodCallExpression) + { + var sqlFuncExpression = _sqlTranslator.TranslateMethodCall(methodCallExpression) as SqlFunctionExpression; + + var elementType = methodCallExpression.Method.ReturnType.GetGenericArguments()[0]; + var entityType =_model.FindEntityType(elementType); + var queryExpression = _sqlExpressionFactory.Select(entityType, sqlFuncExpression); + + return CreateShapedQueryExpression(entityType, queryExpression); + } + protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType) { Check.NotNull(elementType, nameof(elementType)); diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 1bc21038d94..50a0d0391bf 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -9,6 +9,7 @@ using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -183,6 +184,30 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression) "SUM", new[] { sqlExpression }, inputType, sqlExpression.TypeMapping); } + public virtual Expression TranslateMethodCall([NotNull] MethodCallExpression methodCallExpression) + { + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + + if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) + { + return null; + } + + var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + var argument = methodCallExpression.Arguments[i]; + if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + { + return null; + } + + arguments[i] = sqlArgument; + } + + return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); + } + private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor { protected override Expression VisitExtension(Expression node) @@ -190,7 +215,9 @@ protected override Expression VisitExtension(Expression node) Check.NotNull(node, nameof(node)); if (node is SqlExpression sqlExpression - && !(node is SqlFragmentExpression)) + && !(node is SqlFragmentExpression) + && !(node is SqlFunctionExpression sqlFunctionExpression + && sqlFunctionExpression.Type.IsQueryableType())) { if (sqlExpression.TypeMapping == null) { @@ -417,24 +444,7 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) } // MethodCall translators - if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) - { - return null; - } - - var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; - for (var i = 0; i < arguments.Length; i++) - { - var argument = methodCallExpression.Arguments[i]; - if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) - { - return null; - } - - arguments[i] = sqlArgument; - } - - return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); + return TranslateMethodCall(methodCallExpression); } private static Expression TryRemoveImplicitConvert(Expression expression) @@ -597,7 +607,7 @@ protected override Expression VisitLambda(Expression node) { Check.NotNull(node, nameof(node)); - return null; + return node.Body != null ? Visit(node.Body) : null; } protected override Expression VisitConstant(ConstantExpression constantExpression) @@ -693,6 +703,9 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) } break; + + case ExpressionType.Quote: + return operand; } return null; diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 9703701c969..bc1cf9d9c5a 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -670,6 +670,14 @@ public virtual SelectExpression Select(IEntityType entityType, string sql, Expre return selectExpression; } + public virtual SelectExpression Select(IEntityType entityType, SqlFunctionExpression expression) + { + var selectExpression = new SelectExpression(entityType, expression); + AddConditions(selectExpression, entityType); + + return selectExpression; + } + private void AddConditions( SelectExpression selectExpression, IEntityType entityType, diff --git a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs index 36b26110980..20226c66f15 100644 --- a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs @@ -65,6 +65,9 @@ protected override Expression VisitExtension(Expression extensionExpression) case ProjectionExpression projectionExpression: return VisitProjection(projectionExpression); + case QuerableSqlFunctionExpression queryableFunctionExpression: + return VisitQueryableSqlFunctionExpression(queryableFunctionExpression); + case RowNumberExpression rowNumberExpression: return VisitRowNumber(rowNumberExpression); @@ -117,6 +120,7 @@ protected override Expression VisitExtension(Expression extensionExpression) protected abstract Expression VisitOrdering([NotNull] OrderingExpression orderingExpression); protected abstract Expression VisitOuterApply([NotNull] OuterApplyExpression outerApplyExpression); protected abstract Expression VisitProjection([NotNull] ProjectionExpression projectionExpression); + protected abstract Expression VisitQueryableSqlFunctionExpression([NotNull] QuerableSqlFunctionExpression queryableFunctionExpression); protected abstract Expression VisitRowNumber([NotNull] RowNumberExpression rowNumberExpression); protected abstract Expression VisitScalarSubquery([NotNull] ScalarSubqueryExpression scalarSubqueryExpression); protected abstract Expression VisitSelect([NotNull] SelectExpression selectExpression); diff --git a/src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs new file mode 100644 index 00000000000..5ae37ce65fe --- /dev/null +++ b/src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions +{ + /// + /// Represents a SQL Table Valued Fuction in the sql generation tree. + /// + public class QuerableSqlFunctionExpression : TableExpressionBase + { + public QuerableSqlFunctionExpression([NotNull] SqlFunctionExpression expression, [CanBeNull] string alias) + : base(alias) + { + Check.NotNull(expression, nameof(expression)); + + SqlFunctionExpression = expression; + } + + public virtual SqlFunctionExpression SqlFunctionExpression { get; } + + public override void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Append("("); + expressionPrinter.Visit(SqlFunctionExpression); + expressionPrinter.AppendLine() + .AppendLine($") AS {Alias}"); + } + + public override bool Equals(object obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is QuerableSqlFunctionExpression queryableExpression + && Equals(queryableExpression)); + + private bool Equals(QuerableSqlFunctionExpression queryableExpression) + => base.Equals(queryableExpression) + && SqlFunctionExpression.Equals(queryableExpression.SqlFunctionExpression); + + public override int GetHashCode() => HashCode.Combine(base.GetHashCode(), SqlFunctionExpression); + } +} diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 70497e09b13..ff5f6addc1a 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -7,6 +7,7 @@ using System.Linq.Expressions; using System.Reflection; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Utilities; @@ -79,6 +80,13 @@ internal SelectExpression(IEntityType entityType, string sql, Expression argumen { } + internal SelectExpression(IEntityType entityType, SqlFunctionExpression expression) + : this( + entityType, new QuerableSqlFunctionExpression(expression, + entityType.GetTableName().ToLower().Substring(0, 1))) + { + } + private SelectExpression(IEntityType entityType, TableExpressionBase tableExpression) : base(null) { @@ -927,6 +935,13 @@ public Expression ApplyCollectionJoin( var parentIdentifier = GetIdentifierAccessor(_identifier); var outerIdentifier = GetIdentifierAccessor(_identifier.Concat(_childIdentifiers)); innerSelectExpression.ApplyProjection(); + + if (innerSelectExpression._identifier.Count == 0 && innerSelectExpression.Tables.FirstOrDefault( + t => t is QuerableSqlFunctionExpression expression && expression.SqlFunctionExpression.Arguments.Count != 0) is QuerableSqlFunctionExpression queryableFunctionExpression) + { + throw new InvalidOperationException(RelationalStrings.DbFunctionProjectedCollectionMustHavePK(queryableFunctionExpression.SqlFunctionExpression.Name)); + } + var selfIdentifier = innerSelectExpression.GetIdentifierAccessor(innerSelectExpression._identifier); if (collectionIndex == 0) diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs index a90cf2b6d58..8feaf62b939 100644 --- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs @@ -341,6 +341,13 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction return ApplyConversion(newFunction, condition); } + protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression) + { + Check.NotNull(queryableFunctionExpression, nameof(queryableFunctionExpression)); + + return queryableFunctionExpression; + } + protected override Expression VisitSqlParameter(SqlParameterExpression sqlParameterExpression) { Check.NotNull(sqlParameterExpression, nameof(sqlParameterExpression)); diff --git a/src/EFCore/DbContext.cs b/src/EFCore/DbContext.cs index 722ec554f45..5ec49ed8aed 100644 --- a/src/EFCore/DbContext.cs +++ b/src/EFCore/DbContext.cs @@ -6,6 +6,7 @@ using System.ComponentModel; using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; @@ -1656,6 +1657,22 @@ public virtual ValueTask FindAsync([CanBeNull] object[] keyVal /// IServiceProvider IInfrastructure.Instance => InternalServiceProvider; + /// + /// Creates a query expression, which represents a function call, against the query store. + /// + /// The result type of the query expression + /// The query expression to create. + /// An IQueryable representing the query. + protected virtual IQueryable CreateQuery([NotNull] Expression>> expression) + { + //should we add this method as an extension in relational? That would require making DbContextDependencies public. + //Is there a 3rd way? + + Check.NotNull(expression, nameof(expression)); + + return DbContextDependencies.QueryProvider.CreateQuery(expression.Body); + } + #region Hidden System.Object members /// diff --git a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs index 7f4db6e8fb9..107ad57a62c 100644 --- a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs +++ b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs @@ -140,7 +140,9 @@ public static readonly IDictionary CoreServices }, { typeof(ISingletonOptions), new ServiceCharacteristics(ServiceLifetime.Singleton, multipleRegistrations: true) }, { typeof(IConventionSetPlugin), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) }, - { typeof(IResettableService), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) } + { typeof(IResettableService), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) }, + { typeof(INavigationExpandingExpressionVisitorFactory), new ServiceCharacteristics(ServiceLifetime.Singleton) }, + }; /// @@ -265,6 +267,7 @@ public virtual EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(); TryAdd(); TryAdd(); + TryAdd(); TryAdd(p => p.GetService()?.FindExtension()?.DbContextLogger ?? new NullDbContextLogger()); diff --git a/src/EFCore/Query/EvaluatableExpressionFilter.cs b/src/EFCore/Query/EvaluatableExpressionFilter.cs index 395610dbdcb..51a2d66f860 100644 --- a/src/EFCore/Query/EvaluatableExpressionFilter.cs +++ b/src/EFCore/Query/EvaluatableExpressionFilter.cs @@ -118,5 +118,7 @@ public virtual bool IsEvaluatableExpression(Expression expression, IModel model) return true; } + + public virtual bool IsQueryableFunction(Expression expression, IModel model) => false; } } diff --git a/src/EFCore/Query/IEvaluatableExpressionFilter.cs b/src/EFCore/Query/IEvaluatableExpressionFilter.cs index d461a93e17e..e39a4146e05 100644 --- a/src/EFCore/Query/IEvaluatableExpressionFilter.cs +++ b/src/EFCore/Query/IEvaluatableExpressionFilter.cs @@ -27,5 +27,13 @@ public interface IEvaluatableExpressionFilter /// The model. /// True if the expression can be evaluated; false otherwise. bool IsEvaluatableExpression([NotNull] Expression expression, [NotNull] IModel model); + + /// + /// Checks whether the given expression is a queryable function + /// + /// The expression. + /// The model. + /// True if the expression is a queryable function + bool IsQueryableFunction([NotNull] Expression expression, [NotNull] IModel model); } } diff --git a/src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs b/src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs new file mode 100644 index 00000000000..ce22ad612bd --- /dev/null +++ b/src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.EntityFrameworkCore.Query.Internal; + +namespace Microsoft.EntityFrameworkCore.Query +{ + public interface INavigationExpandingExpressionVisitorFactory + { + NavigationExpandingExpressionVisitor Create([NotNull] QueryCompilationContext queryCompilationContext, + [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter); + } +} diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index f7a34f72d00..4fd77759ee5 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -48,6 +48,8 @@ private readonly Dictionary _parameterizedQueryFi private readonly Parameters _parameters = new Parameters(); + protected virtual QueryCompilationContext QueryCompilationContext => _queryCompilationContext; + public NavigationExpandingExpressionVisitor( [NotNull] QueryCompilationContext queryCompilationContext, [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter) @@ -1356,7 +1358,7 @@ static bool IsNumericType(Type type) } } - private NavigationExpansionExpression CreateNavigationExpansionExpression(Expression sourceExpression, IEntityType entityType) + protected virtual NavigationExpansionExpression CreateNavigationExpansionExpression([NotNull] Expression sourceExpression, [NotNull] IEntityType entityType) { var entityReference = new EntityReference(entityType); PopulateEagerLoadedNavigations(entityReference.IncludePaths); diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs new file mode 100644 index 00000000000..39774cf1d1b --- /dev/null +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class NavigationExpandingExpressionVisitorFactory : INavigationExpandingExpressionVisitorFactory + { + public virtual NavigationExpandingExpressionVisitor Create( + QueryCompilationContext queryCompilationContext, IEvaluatableExpressionFilter evaluatableExpressionFilter) + { + return new NavigationExpandingExpressionVisitor(queryCompilationContext, evaluatableExpressionFilter); + } + } +} diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs index 2b2ecf1c48c..582489c1a2c 100644 --- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs @@ -3,15 +3,14 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; -using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Storage; -using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Query.Internal { @@ -55,12 +54,10 @@ public ParameterExtractingExpressionVisitor( { _evaluatableExpressionFindingExpressionVisitor = new EvaluatableExpressionFindingExpressionVisitor(evaluatableExpressionFilter, model); - _parameterValues = parameterValues; _logger = logger; _parameterize = parameterize; _generateContextAccessors = generateContextAccessors; - if (_generateContextAccessors) { _contextParameterReplacingExpressionVisitor = new ContextParameterReplacingExpressionVisitor(contextType); @@ -153,8 +150,6 @@ private bool PreserveConvertNode(Expression expression) /// protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { - Check.NotNull(conditionalExpression, nameof(conditionalExpression)); - var newTestExpression = TryGetConstantValue(conditionalExpression.Test) ?? Visit(conditionalExpression.Test); if (newTestExpression is ConstantExpression constantTestExpression @@ -179,8 +174,6 @@ protected override Expression VisitConditional(ConditionalExpression conditional /// protected override Expression VisitBinary(BinaryExpression binaryExpression) { - Check.NotNull(binaryExpression, nameof(binaryExpression)); - switch (binaryExpression.NodeType) { case ExpressionType.Coalesce: @@ -251,8 +244,6 @@ private static bool ShortCircuitLogicalExpression(Expression expression, Express /// protected override Expression VisitConstant(ConstantExpression constantExpression) { - Check.NotNull(constantExpression, nameof(constantExpression)); - if (constantExpression.Value is IDetachableContext detachableContext) { var queryProvider = ((IQueryable)constantExpression.Value).Provider; @@ -522,8 +513,6 @@ public override Expression Visit(Expression expression) protected override Expression VisitLambda(Expression lambdaExpression) { - Check.NotNull(lambdaExpression, nameof(lambdaExpression)); - var oldInLambda = _inLambda; _inLambda = true; @@ -537,8 +526,6 @@ protected override Expression VisitLambda(Expression lambdaExpression) protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) { - Check.NotNull(memberInitExpression, nameof(memberInitExpression)); - Visit(memberInitExpression.Bindings, VisitMemberBinding); // Cannot make parameter for NewExpression if Bindings cannot be evaluated @@ -552,8 +539,6 @@ protected override Expression VisitMemberInit(MemberInitExpression memberInitExp protected override Expression VisitListInit(ListInitExpression listInitExpression) { - Check.NotNull(listInitExpression, nameof(listInitExpression)); - Visit(listInitExpression.Initializers, VisitElementInit); // Cannot make parameter for NewExpression if Initializers cannot be evaluated @@ -567,7 +552,11 @@ protected override Expression VisitListInit(ListInitExpression listInitExpressio protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + if (_evaluatableExpressionFilter.IsQueryableFunction(methodCallExpression, _model) + && !_inLambda) + { + _evaluatable = false; + } Visit(methodCallExpression.Object); var parameterInfos = methodCallExpression.Method.GetParameters(); @@ -608,18 +597,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp protected override Expression VisitMember(MemberExpression memberExpression) { - Check.NotNull(memberExpression, nameof(memberExpression)); - _containsClosure = memberExpression.Expression != null || !(memberExpression.Member is FieldInfo fieldInfo && fieldInfo.IsInitOnly); - return base.VisitMember(memberExpression); } protected override Expression VisitParameter(ParameterExpression parameterExpression) { - Check.NotNull(parameterExpression, nameof(parameterExpression)); - _evaluatable = _allowedParameters.Contains(parameterExpression); return base.VisitParameter(parameterExpression); @@ -627,8 +611,6 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres protected override Expression VisitConstant(ConstantExpression constantExpression) { - Check.NotNull(constantExpression, nameof(constantExpression)); - _evaluatable = !(constantExpression.Value is IDetachableContext) && !(constantExpression.Value is IQueryable); diff --git a/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs b/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs index 0ec15437e55..34ff2067b4a 100644 --- a/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs +++ b/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs @@ -20,17 +20,20 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal public class QueryTranslationPreprocessorFactory : IQueryTranslationPreprocessorFactory { private readonly QueryTranslationPreprocessorDependencies _dependencies; + private readonly INavigationExpandingExpressionVisitorFactory _navigationExpandingExpressionVisitorFactory; - public QueryTranslationPreprocessorFactory([NotNull] QueryTranslationPreprocessorDependencies dependencies) + public QueryTranslationPreprocessorFactory([NotNull] QueryTranslationPreprocessorDependencies dependencies, + [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory) { _dependencies = dependencies; + _navigationExpandingExpressionVisitorFactory = navigationExpandingExpressionVisitorFactory; } public virtual QueryTranslationPreprocessor Create(QueryCompilationContext queryCompilationContext) { Check.NotNull(queryCompilationContext, nameof(queryCompilationContext)); - return new QueryTranslationPreprocessor(_dependencies, queryCompilationContext); + return new QueryTranslationPreprocessor(_dependencies, queryCompilationContext, _navigationExpandingExpressionVisitorFactory); } } } diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs index fc284dc4d24..f827da997f9 100644 --- a/src/EFCore/Query/QueryTranslationPreprocessor.cs +++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs @@ -11,16 +11,19 @@ namespace Microsoft.EntityFrameworkCore.Query public class QueryTranslationPreprocessor { private readonly QueryCompilationContext _queryCompilationContext; + private readonly INavigationExpandingExpressionVisitorFactory _navigationExpandingExpressionVisitorFactory; public QueryTranslationPreprocessor( [NotNull] QueryTranslationPreprocessorDependencies dependencies, - [NotNull] QueryCompilationContext queryCompilationContext) + [NotNull] QueryCompilationContext queryCompilationContext, + [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory) { Check.NotNull(dependencies, nameof(dependencies)); Check.NotNull(queryCompilationContext, nameof(queryCompilationContext)); Dependencies = dependencies; _queryCompilationContext = queryCompilationContext; + _navigationExpandingExpressionVisitorFactory = navigationExpandingExpressionVisitorFactory; } protected virtual QueryTranslationPreprocessorDependencies Dependencies { get; } @@ -37,8 +40,8 @@ public virtual Expression Process([NotNull] Expression query) query = new NullCheckRemovingExpressionVisitor().Visit(query); query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query); query = new SubqueryMemberPushdownExpressionVisitor(_queryCompilationContext.Model).Visit(query); - query = new NavigationExpandingExpressionVisitor(_queryCompilationContext, Dependencies.EvaluatableExpressionFilter).Expand( - query); + query = _navigationExpandingExpressionVisitorFactory.Create(_queryCompilationContext, Dependencies.EvaluatableExpressionFilter) + .Expand(query); query = new FunctionPreprocessingExpressionVisitor().Visit(query); return query; diff --git a/src/EFCore/Query/ShapedQueryExpression.cs b/src/EFCore/Query/ShapedQueryExpression.cs index c851a881f48..ccca2d2a120 100644 --- a/src/EFCore/Query/ShapedQueryExpression.cs +++ b/src/EFCore/Query/ShapedQueryExpression.cs @@ -33,7 +33,7 @@ private ShapedQueryExpression( public virtual Expression ShaperExpression { get; } public override Type Type => ResultCardinality == ResultCardinality.Enumerable - ? typeof(IQueryable<>).MakeGenericType(ShaperExpression.Type) + ? typeof(IQueryable<>).MakeGenericType(ShaperExpression.Type) : ShaperExpression.Type; public sealed override ExpressionType NodeType => ExpressionType.Extension; diff --git a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs index 4e089347708..d0d17d5dad7 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; +using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; using Microsoft.EntityFrameworkCore.TestUtilities; @@ -28,14 +30,46 @@ public class Customer public string FirstName { get; set; } public string LastName { get; set; } public List Orders { get; set; } + public List
Addresses { get; set; } } public class Order { public int Id { get; set; } public string Name { get; set; } - public int ItemCount { get; set; } public DateTime OrderDate { get; set; } + + public int CustomerId { get; set; } + + public Customer Customer { get; set; } + public List Items { get; set; } + } + + public class LineItem + { + public int Id { get; set; } + public int OrderId { get; set; } + public int ProductId { get; set; } + public int Quantity { get; set; } + + public Order Order { get; set; } + public Product Product { get; set; } + } + + public class Product + { + public int Id { get; set; } + public string Name { get; set; } + } + + public class Address + { + public int Id { get; set; } + public string Street { get; set; } + public string City { get; set; } + public string State { get; set; } + + public int CustomerId { get; set; } public Customer Customer { get; set; } } @@ -45,6 +79,8 @@ protected class UDFSqlContext : PoolableDbContext public DbSet Customers { get; set; } public DbSet Orders { get; set; } + public DbSet Products { get; set; } + public DbSet
Addresses { get; set; } #endregion @@ -103,6 +139,77 @@ public enum ReportingPeriod [DbFunction(Schema = "dbo")] public static string IdentityString(string s) => throw new Exception(); + public int AddValues(int a, int b) + { + throw new NotImplementedException(); + } + + public int AddValues(Expression> a, int b) + { + throw new NotImplementedException(); + } + + #region Querable Functions + + public class OrderByYear + { + public int? CustomerId { get; set; } + public int? Count { get; set; } + public int? Year { get; set; } + } + + public class MultProductOrders + { + //public Order Order { get; set; } + public int OrderId { get; set; } + + public Customer Customer { get; set; } + public int CustomerId { get; set; } + + public DateTime OrderDate { get; set; } + } + + public IQueryable GetCustomerOrderCountByYear(int customerId) + { + return CreateQuery(() => GetCustomerOrderCountByYear(customerId)); + } + + public IQueryable GetCustomerOrderCountByYear(Expression> customerId2) + { + return CreateQuery(() => GetCustomerOrderCountByYear(customerId2)); + } + + public class TopSellingProduct + { + public Product Product { get; set; } + public int? ProductId { get; set; } + + public int? AmountSold { get; set; } + } + + public IQueryable GetTopTwoSellingProducts() + { + return CreateQuery(() => GetTopTwoSellingProducts()); + } + + public IQueryable GetTopSellingProductsForCustomer(int customerId) + { + return CreateQuery(() => GetTopSellingProductsForCustomer(customerId)); + } + + public IQueryable GetTopTwoSellingProductsCustomTranslation() + { + return CreateQuery(() => GetTopTwoSellingProductsCustomTranslation()); + } + + public IQueryable GetOrdersWithMultipleProducts(int customerId) + { + return CreateQuery(() => GetOrdersWithMultipleProducts(customerId)); + } + + + #endregion + #endregion public UDFSqlContext(DbContextOptions options) @@ -133,6 +240,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) modelBuilder.HasDbFunction(methodInfo) .HasTranslation(args => SqlFunctionExpression.Create("len", args, methodInfo.ReturnType, null)); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(AddValues), new[] { typeof(int), typeof(int) })); + //Instance modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountInstance))) .HasName("CustomerOrderCount"); @@ -154,6 +263,23 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) modelBuilder.HasDbFunction(methodInfo2) .HasTranslation(args => SqlFunctionExpression.Create("len", args, methodInfo2.ReturnType, null)); + + modelBuilder.Entity().ToQueryableFunctionResultType().HasKey(mpo => mpo.OrderId); + //modelBuilder.Entity().HasOne(mpo => mpo.Order).WithOne().HasForeignKey(o => o.Id); + + modelBuilder.Entity().ToQueryableFunctionResultType().HasNoKey(); + modelBuilder.Entity().ToQueryableFunctionResultType().HasNoKey(); + + //Table + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerOrderCountByYear), new[] { typeof(int) })); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerOrderCountByYear), new[] { typeof(Expression>) })); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopTwoSellingProducts))); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopSellingProductsForCustomer))); + + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetOrdersWithMultipleProducts))); + + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopTwoSellingProductsCustomTranslation))) + .HasTranslation(args => SqlFunctionExpression.Create("dbo", "GetTopTwoSellingProducts", args, typeof(TopSellingProduct), null)); } } @@ -170,68 +296,105 @@ protected override void Seed(DbContext context) { context.Database.EnsureCreatedResiliently(); + var product1 = new Product { Name = "Product1" }; + var product2 = new Product { Name = "Product2" }; + var product3 = new Product { Name = "Product3" }; + var product4 = new Product { Name = "Product4" }; + var product5 = new Product { Name = "Product5" }; + var order11 = new Order { - Name = "Order11", - ItemCount = 4, - OrderDate = new DateTime(2000, 1, 20) + Name = "Order11", OrderDate = new DateTime(2000, 1, 20), + Items = new List + { + new LineItem { Quantity = 5, Product = product1}, + new LineItem { Quantity = 15, Product = product3} + } }; - var order12 = new Order - { - Name = "Order12", - ItemCount = 8, - OrderDate = new DateTime(2000, 2, 21) + + var order12 = new Order { Name = "Order12", OrderDate = new DateTime(2000, 2, 21), + Items = new List + { + new LineItem { Quantity = 1, Product = product1}, + new LineItem { Quantity = 6, Product = product2}, + new LineItem { Quantity = 200, Product = product3} + } }; - var order13 = new Order - { - Name = "Order13", - ItemCount = 15, - OrderDate = new DateTime(2000, 3, 20) + + var order13 = new Order { Name = "Order13", OrderDate = new DateTime(2001, 3, 20), + Items = new List + { + new LineItem { Quantity = 50, Product = product4}, + } }; - var order21 = new Order - { - Name = "Order21", - ItemCount = 16, - OrderDate = new DateTime(2000, 4, 21) + + var order21 = new Order { Name = "Order21", OrderDate = new DateTime(2000, 4, 21), + Items = new List + { + new LineItem { Quantity = 1, Product = product1}, + new LineItem { Quantity = 34, Product = product4}, + new LineItem { Quantity = 100, Product = product5} + } }; - var order22 = new Order - { - Name = "Order22", - ItemCount = 23, - OrderDate = new DateTime(2000, 5, 20) + + var order22 = new Order { Name = "Order22", OrderDate = new DateTime(2000, 5, 20), + Items = new List + { + new LineItem { Quantity = 34, Product = product3}, + new LineItem { Quantity = 100, Product = product4} + } }; - var order31 = new Order - { - Name = "Order31", - ItemCount = 42, - OrderDate = new DateTime(2000, 6, 21) + + var order31 = new Order { Name = "Order31", OrderDate = new DateTime(2001, 6, 21), + Items = new List + { + new LineItem { Quantity = 5, Product = product5} + } }; + var address11 = new Address { Street = "1600 Pennsylvania Avenue", City = "Washington", State = "DC" }; + var address12 = new Address { Street = "742 Evergreen Terrace", City = "SpringField", State = "" }; + var address21 = new Address { Street = "Apartment 5A, 129 West 81st Street", City = "New York", State = "NY" }; + var address31 = new Address { Street = "425 Grove Street, Apt 20", City = "New York", State = "NY" }; + var address32 = new Address { Street = "342 GravelPit Terrace", City = "BedRock", State = "" }; + var address41 = new Address { Street = "4222 Clinton Way", City = "Los Angles", State = "CA" }; + var address42 = new Address { Street = "1060 West Addison Street", City = "Chicago", State = "IL" }; + var address43 = new Address { Street = "112 ½ Beacon Street", City = "Boston", State = "MA" }; + var customer1 = new Customer { FirstName = "Customer", LastName = "One", - Orders = new List - { - order11, - order12, - order13 - } + Orders = new List { order11, order12, order13 }, + Addresses = new List
{ address11, address12 } }; + var customer2 = new Customer { FirstName = "Customer", LastName = "Two", - Orders = new List { order21, order22 } + Orders = new List { order21, order22 }, + Addresses = new List
{ address21 } }; + var customer3 = new Customer { FirstName = "Customer", LastName = "Three", - Orders = new List { order31 } + Orders = new List { order31 }, + Addresses = new List
{ address31, address32 } + }; + + var customer4 = new Customer + { + FirstName = "Customer", + LastName = "Four", + Addresses = new List
{ address41, address42, address43 } }; - ((UDFSqlContext)context).Customers.AddRange(customer1, customer2, customer3); + ((UDFSqlContext)context).Products.AddRange(product1, product2, product3, product4, product5); + ((UDFSqlContext)context).Addresses.AddRange(address11, address12, address21, address31, address32, address41, address42, address43); + ((UDFSqlContext)context).Customers.AddRange(customer1, customer2, customer3, customer4); ((UDFSqlContext)context).Orders.AddRange(order11, order12, order13, order21, order22, order31); } } @@ -249,7 +412,7 @@ public virtual void Scalar_Function_Extension_Method_Static() var len = context.Customers.Count(c => UDFSqlContext.IsDateStatic(c.FirstName) == false); - Assert.Equal(3, len); + Assert.Equal(4, len); } [ConditionalFact] @@ -286,7 +449,7 @@ public virtual void Scalar_Function_Constant_Parameter_Static() var custs = context.Customers.Select(c => UDFSqlContext.CustomerOrderCountStatic(customerId)).ToList(); - Assert.Equal(3, custs.Count); + Assert.Equal(4, custs.Count); } [ConditionalFact] @@ -494,8 +657,8 @@ public virtual void Scalar_Nested_Function_Unwind_Client_Eval_Select_Static() orderby c.Id select UDFSqlContext.AddOneStatic(c.Id)).ToList(); - Assert.Equal(3, results.Count); - Assert.True(results.SequenceEqual(Enumerable.Range(2, 3))); + Assert.Equal(4, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(2, 4))); } [ConditionalFact] @@ -679,7 +842,7 @@ public virtual void Scalar_Function_Extension_Method_Instance() var len = context.Customers.Count(c => context.IsDateInstance(c.FirstName) == false); - Assert.Equal(3, len); + Assert.Equal(4, len); } [ConditionalFact] @@ -714,7 +877,7 @@ public virtual void Scalar_Function_Constant_Parameter_Instance() var custs = context.Customers.Select(c => context.CustomerOrderCountInstance(customerId)).ToList(); - Assert.Equal(3, custs.Count); + Assert.Equal(4, custs.Count); } [ConditionalFact] @@ -921,8 +1084,8 @@ public virtual void Scalar_Nested_Function_Unwind_Client_Eval_Select_Instance() orderby c.Id select context.AddOneInstance(c.Id)).ToList(); - Assert.Equal(3, results.Count); - Assert.True(results.SequenceEqual(Enumerable.Range(2, 3))); + Assert.Equal(4, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(2, 4))); } [ConditionalFact] @@ -1070,9 +1233,706 @@ public virtual void Scalar_Nested_Function_UDF_BCL_Instance() #endregion - private void AssertTranslationFailed(Action testCode) - => Assert.Contains( - CoreStrings.TranslationFailed("").Substring(21), - Assert.Throws(testCode).Message); + #region QueryableFunction + + [Fact] + public virtual void QF_Anonymous_Collection_No_PK_Throws() + { + using (var context = CreateContext()) + { + var query = from c in context.Customers + select new { c.Id, products = context.GetTopSellingProductsForCustomer(c.Id).ToList() }; + + Assert.Contains( + RelationalStrings.DbFunctionProjectedCollectionMustHavePK("GetTopSellingProductsForCustomer"), + Assert.Throws(() => query.ToList()).Message); + } + } + + [Fact] + public virtual void QF_Anonymous_Collection_No_IQueryable_In_Projection_Throws() + { + using (var context = CreateContext()) + { + var query = (from c in context.Customers + select new { c.Id, orders = context.GetCustomerOrderCountByYear(c.Id) }); + + Assert.Contains( + RelationalStrings.DbFunctionCantProjectIQueryable(), + Assert.Throws(() => query.ToList()).Message); + } + } + + [Fact] + public virtual void QF_Stand_Alone() + { + using (var context = CreateContext()) + { + var products = (from t in context.GetTopTwoSellingProducts() + orderby t.ProductId + select t).ToList(); + + Assert.Equal(2, products.Count); + Assert.Equal(3, products[0].ProductId); + Assert.Equal(249, products[0].AmountSold); + Assert.Equal(4, products[1].ProductId); + Assert.Equal(184, products[1].AmountSold); + } + } + + [Fact] + public virtual void QF_Stand_Alone_With_Translation() + { + using (var context = CreateContext()) + { + var products = (from t in context.GetTopTwoSellingProductsCustomTranslation() + orderby t.ProductId + select t).ToList(); + + Assert.Equal(2, products.Count); + Assert.Equal(3, products[0].ProductId); + Assert.Equal(249, products[0].AmountSold); + Assert.Equal(4, products[1].ProductId); + Assert.Equal(184, products[1].AmountSold); + } + } + + [Fact] + public virtual void QF_Stand_Alone_Parameter() + { + using (var context = CreateContext()) + { + var orders = (from c in context.GetCustomerOrderCountByYear(1) + orderby c.Count descending + select c).ToList(); + + Assert.Equal(2, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2001, orders[1].Year); + } + } + + [Fact] + public virtual void QF_Stand_Alone_Nested() + { + using (var context = CreateContext()) + { + var orders = (from r in context.GetCustomerOrderCountByYear(() => context.AddValues(-2, 3)) + orderby r.Count descending + select r).ToList(); + + Assert.Equal(2, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2001, orders[1].Year); + } + } + + [Fact] + public virtual void QF_CrossApply_Correlated_Select_Anonymous() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id) + orderby c.Id, r.Year + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Equal(4, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2001, orders[1].Year); + Assert.Equal(2000, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + Assert.Equal(1, orders[0].Id); + Assert.Equal(1, orders[1].Id); + Assert.Equal(2, orders[2].Id); + Assert.Equal(3, orders[3].Id); + } + } + + [Fact] + public virtual void QF_Select_Direct_In_Anonymous() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + select new + { + c.Id, + Prods = context.GetTopTwoSellingProducts().ToList(), + }).ToList(); + + Assert.Equal(4, results.Count); + Assert.Equal(2, results[0].Prods.Count); + Assert.Equal(2, results[1].Prods.Count); + Assert.Equal(2, results[2].Prods.Count); + Assert.Equal(2, results[3].Prods.Count); + } + } + + [Fact] + public virtual void QF_Select_Correlated_Direct_With_Function_Query_Parameter_Correlated_In_Anonymous() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + where c.Id == 1 + select new + { + c.Id, + Orders = context.GetOrdersWithMultipleProducts(context.AddValues(c.Id, 1)).ToList() + }).ToList(); + + Assert.Single(cust); + + Assert.Equal(1, cust[0].Id); + Assert.Equal(4, cust[0].Orders[0].OrderId); + Assert.Equal(5, cust[0].Orders[1].OrderId); + Assert.Equal(new DateTime(2000, 4, 21), cust[0].Orders[0].OrderDate); + Assert.Equal(new DateTime(2000, 5, 20), cust[0].Orders[1].OrderDate); + } + } + + [Fact] + public virtual void QF_Select_Correlated_Subquery_In_Anonymous() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + select new + { + c.Id, + OrderCountYear = context.GetOrdersWithMultipleProducts(c.Id).Where(o => o.OrderDate.Day == 21).ToList() + }).ToList(); + + Assert.Equal(4, results.Count); + Assert.Equal(1, results[0].Id); + Assert.Equal(2, results[1].Id); + Assert.Equal(3, results[2].Id); + Assert.Equal(4, results[3].Id); + Assert.Single(results[0].OrderCountYear); + Assert.Single(results[1].OrderCountYear); + Assert.Empty(results[2].OrderCountYear); + Assert.Empty(results[3].OrderCountYear); + } + } + + [Fact] + public virtual void QF_Select_Correlated_Subquery_In_Anonymous_Nested_With_QF() + { + using (var context = CreateContext()) + { + var results = (from o in context.Orders + join osub in (from c in context.Customers + from a in context.GetOrdersWithMultipleProducts(c.Id) + select a.OrderId + ) on o.Id equals osub + select new { o.CustomerId, o.OrderDate }).ToList(); + + Assert.Equal(4, results.Count); + + Assert.Equal(1, results[0].CustomerId); + Assert.Equal(new DateTime(2000, 1, 20), results[0].OrderDate); + + Assert.Equal(1, results[1].CustomerId); + Assert.Equal(new DateTime(2000, 2, 21), results[1].OrderDate); + + Assert.Equal(2, results[2].CustomerId); + Assert.Equal(new DateTime(2000, 4, 21), results[2].OrderDate); + + Assert.Equal(2, results[3].CustomerId); + Assert.Equal(new DateTime(2000, 5, 20), results[3].OrderDate); + } + } + + [Fact] + public virtual void QF_Select_Correlated_Subquery_In_Anonymous_Nested() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + select new + { + c.Id, + OrderCountYear = context.GetOrdersWithMultipleProducts(c.Id).Where(o => o.OrderDate.Day == 21).Select(o => new + { + OrderCountYearNested = context.GetOrdersWithMultipleProducts(o.CustomerId).ToList(), + Prods = context.GetTopTwoSellingProducts().ToList(), + }).ToList() + }).ToList(); + + Assert.Equal(4, results.Count); + + Assert.Single(results[0].OrderCountYear); + Assert.Equal(2, results[0].OrderCountYear[0].Prods.Count); + Assert.Equal(2, results[0].OrderCountYear[0].OrderCountYearNested.Count); + + Assert.Single(results[1].OrderCountYear); + Assert.Equal(2, results[1].OrderCountYear[0].Prods.Count); + Assert.Equal(2, results[1].OrderCountYear[0].OrderCountYearNested.Count); + + Assert.Empty(results[2].OrderCountYear); + + Assert.Empty(results[3].OrderCountYear); + } + } + + [Fact] + public virtual void QF_Select_Correlated_Subquery_In_Anonymous_MultipleCollections() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + select new + { + c.Id, + Prods = context.GetTopTwoSellingProducts().Where(p => p.AmountSold == 249).Select(p => p.ProductId).ToList(), + Addresses = c.Addresses.Where(a => a.State == "NY").ToList() + }).ToList(); + + Assert.Equal(4, results.Count); + Assert.Equal(3, results[0].Prods[0]); + Assert.Equal(3, results[1].Prods[0]); + Assert.Equal(3, results[2].Prods[0]); + Assert.Equal(3, results[3].Prods[0]); + + Assert.Empty(results[0].Addresses); + Assert.Equal("Apartment 5A, 129 West 81st Street", results[1].Addresses[0].Street); + Assert.Equal("425 Grove Street, Apt 20", results[2].Addresses[0].Street); + Assert.Empty(results[3].Addresses); + } + } + + [Fact] + public virtual void QF_Select_NonCorrelated_Subquery_In_Anonymous() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + select new + { + c.Id, + Prods = context.GetTopTwoSellingProducts().Where(p => p.AmountSold == 249).Select(p => p.ProductId).ToList(), + }).ToList(); + + Assert.Equal(4, results.Count); + Assert.Equal(3, results[0].Prods[0]); + Assert.Equal(3, results[1].Prods[0]); + Assert.Equal(3, results[2].Prods[0]); + Assert.Equal(3, results[3].Prods[0]); + } + } + + [Fact] + public virtual void QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter() + { + using (var context = CreateContext()) + { + var amount = 27; + + var results = (from c in context.Customers + select new + { + c.Id, + Prods = context.GetTopTwoSellingProducts().Where(p => p.AmountSold == amount).Select(p => p.ProductId).ToList(), + }).ToList(); + + Assert.Equal(4, results.Count); + Assert.Single(results[0].Prods); + Assert.Single(results[1].Prods); + Assert.Single(results[2].Prods); + Assert.Single(results[3].Prods); + } + } + + [Fact] + public virtual void QF_Correlated_Select_In_Anonymous() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + orderby c.Id + select new + { + c.Id, + c.LastName, + Orders = context.GetOrdersWithMultipleProducts(c.Id).ToList() + }).ToList(); + + Assert.Equal(4, cust.Count); + + Assert.Equal(1, cust[0].Id); + Assert.Equal(2, cust[0].Orders.Count); + Assert.Equal(1, cust[0].Orders[0].OrderId); + Assert.Equal(2, cust[0].Orders[1].OrderId); + Assert.Equal(new DateTime(2000, 1, 20), cust[0].Orders[0].OrderDate); + Assert.Equal(new DateTime(2000, 2, 21), cust[0].Orders[1].OrderDate); + + Assert.Equal(2, cust[1].Id); + Assert.Equal(2, cust[1].Orders.Count); + Assert.Equal(4, cust[1].Orders[0].OrderId); + Assert.Equal(5, cust[1].Orders[1].OrderId); + Assert.Equal(new DateTime(2000, 4, 21), cust[1].Orders[0].OrderDate); + Assert.Equal(new DateTime(2000, 5, 20), cust[1].Orders[1].OrderDate); + + Assert.Equal(3, cust[2].Id); + Assert.Empty(cust[2].Orders); + + Assert.Equal(4, cust[3].Id); + Assert.Empty(cust[3].Orders); + } + } + + [Fact] + public virtual void QF_CrossApply_Correlated_Select_Result() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id) + orderby r.Count descending, r.Year descending + select r).ToList(); + + Assert.Equal(4, orders.Count); + + Assert.Equal(4, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(2, orders[1].Count); + Assert.Equal(1, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2000, orders[1].Year); + Assert.Equal(2001, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + } + } + + [Fact] + public virtual void QF_CrossJoin_Not_Correlated() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(2) + where c.Id == 2 + orderby r.Count + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Single(orders); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + } + } + + [Fact] + public virtual void QF_CrossJoin_Parameter() + { + using (var context = CreateContext()) + { + var custId = 2; + + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(custId) + where c.Id == custId + orderby r.Count + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Single(orders); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + } + } + + [Fact] + public virtual void QF_Join() + { + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId + select new + { + p.Id, + p.Name, + r.AmountSold + }).OrderBy(p => p.Id).ToList(); + + Assert.Equal(2, products.Count); + Assert.Equal(3, products[0].Id); + Assert.Equal("Product3", products[0].Name); + Assert.Equal(249, products[0].AmountSold); + Assert.Equal(4, products[1].Id); + Assert.Equal("Product4", products[1].Name); + Assert.Equal(184, products[1].AmountSold); + } + } + + [Fact] + public virtual void QF_LeftJoin_Select_Anonymous() + { + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId into joinTable + from j in joinTable.DefaultIfEmpty() + orderby p.Id descending + select new + { + p.Id, + p.Name, + j.AmountSold + }).ToList(); + + Assert.Equal(5, products.Count); + Assert.Equal(5, products[0].Id); + Assert.Equal("Product5", products[0].Name); + Assert.Null(products[0].AmountSold); + + Assert.Equal(4, products[1].Id); + Assert.Equal("Product4", products[1].Name); + Assert.Equal(184, products[1].AmountSold); + + Assert.Equal(3, products[2].Id); + Assert.Equal("Product3", products[2].Name); + Assert.Equal(249, products[2].AmountSold); + + Assert.Equal(2, products[3].Id); + Assert.Equal("Product2", products[3].Name); + Assert.Null(products[3].AmountSold); + + Assert.Equal(1, products[4].Id); + Assert.Equal("Product1", products[4].Name); + Assert.Null(products[4].AmountSold); + } + } + + [Fact] + public virtual void QF_LeftJoin_Select_Result() + { + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId into joinTable + from j in joinTable.DefaultIfEmpty() + orderby p.Id descending + select j).ToList(); + + Assert.Equal(5, products.Count); + Assert.Null(products[0]); + Assert.Equal(4, products[1].ProductId); + Assert.Equal(184, products[1].AmountSold); + Assert.Equal(3, products[2].ProductId); + Assert.Equal(249, products[2].AmountSold); + Assert.Null(products[3]); + Assert.Null(products[4]); + } + } + + [Fact] + public virtual void QF_OuterApply_Correlated_Select_TVF() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty() + orderby c.Id, r.Year + select r).ToList(); + + Assert.Equal(5, orders.Count); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Null(orders[4]); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2001, orders[1].Year); + Assert.Equal(2000, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + Assert.Null(orders[4]); + Assert.Equal(1, orders[0].CustomerId); + Assert.Equal(1, orders[1].CustomerId); + Assert.Equal(2, orders[2].CustomerId); + Assert.Equal(3, orders[3].CustomerId); + Assert.Null(orders[4]); + } + } + + [Fact] + public virtual void QF_OuterApply_Correlated_Select_DbSet() + { + using (var context = CreateContext()) + { + var custs = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty() + orderby c.Id, r.Year + select c).ToList(); + + Assert.Equal(5, custs.Count); + + Assert.Equal(1, custs[0].Id); + Assert.Equal(1, custs[1].Id); + Assert.Equal(2, custs[2].Id); + Assert.Equal(3, custs[3].Id); + Assert.Equal(4, custs[4].Id); + Assert.Equal("One", custs[0].LastName); + Assert.Equal("One", custs[1].LastName); + Assert.Equal("Two", custs[2].LastName); + Assert.Equal("Three", custs[3].LastName); + Assert.Equal("Four", custs[4].LastName); + } + } + + [Fact] + public virtual void QF_OuterApply_Correlated_Select_Anonymous() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty() + orderby c.Id, r.Year + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Equal(5, orders.Count); + + Assert.Equal(1, orders[0].Id); + Assert.Equal(1, orders[1].Id); + Assert.Equal(2, orders[2].Id); + Assert.Equal(3, orders[3].Id); + Assert.Equal(4, orders[4].Id); + Assert.Equal("One", orders[0].LastName); + Assert.Equal("One", orders[1].LastName); + Assert.Equal("Two", orders[2].LastName); + Assert.Equal("Three", orders[3].LastName); + Assert.Equal("Four", orders[4].LastName); + Assert.Equal(2, orders[0].Count); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Null(orders[4].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2001, orders[1].Year); + Assert.Equal(2000, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + } + } + + [Fact] + public virtual void QF_Nested() + { + using (var context = CreateContext()) + { + var custId = 2; + + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(context.AddValues(1, 1)) + where c.Id == custId + orderby r.Year + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Single(orders); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + } + } + + + [Fact] + public virtual void QF_Correlated_Nested_Func_Call() + { + var custId = 2; + + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(context.AddValues(c.Id, 1)) + where c.Id == custId + select new + { + c.Id, + r.Count, + r.Year + }).ToList(); + + Assert.Single(orders); + + Assert.Equal(1, orders[0].Count); + Assert.Equal(2001, orders[0].Year); + } + } + + [Fact] + public virtual void QF_Correlated_Func_Call_With_Navigation() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + orderby c.Id + select new + { + c.Id, + Orders = context.GetOrdersWithMultipleProducts(c.Id).Select(mpo => new + { + //how to I setup the PK/FK combo properly for this? Is it even possible? + //OrderName = mpo.Order.Name, + CustomerName = mpo.Customer.LastName + }).ToList() + }).ToList(); + + Assert.Equal(4, cust.Count); + Assert.Equal(2, cust[0].Orders.Count); + Assert.Equal("One", cust[0].Orders[0].CustomerName); + Assert.Equal(2, cust[1].Orders.Count); + Assert.Equal("Two", cust[1].Orders[0].CustomerName); + } + } + + #endregion + + + private void AssertTranslationFailed(Action testCode) + => Assert.Contains( + CoreStrings.TranslationFailed("").Substring(21), + Assert.Throws(testCode).Message); } } diff --git a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs index ed44eec80c2..1ab59196ab5 100644 --- a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs +++ b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Linq; using System.Linq.Expressions; using System.Reflection; using Microsoft.EntityFrameworkCore.Diagnostics; @@ -20,6 +21,12 @@ namespace Microsoft.EntityFrameworkCore.Metadata { public class DbFunctionMetadataTests { + public class Foo + { + public int I { get; set; } + public int J { get; set; } + } + public class MyNonDbContext { public int NonStatic() @@ -160,6 +167,21 @@ public static int DuplicateNameTest() [DbFunction] public override int VirtualBase() => throw new Exception(); + + [DbFunction] + public IQueryable QueryableNoParams() => throw new Exception(); + + [DbFunction] + public IQueryable QueryableSingleParam(int i) => throw new Exception(); + + public IQueryable QueryableSingleParam(Expression> i) => throw new Exception(); + + [DbFunction] + public IQueryable QueryableMultiParam(int i, double j) => throw new Exception(); + + public IQueryable QueryableMultiParam(Expression> i, double j) => throw new Exception(); + + public IQueryable QueryableMultiParam(Expression> i, Expression> j) => throw new Exception(); } public static MethodInfo MethodAmi = typeof(TestMethods).GetRuntimeMethod( @@ -173,6 +195,8 @@ public static int DuplicateNameTest() public static MethodInfo MethodHmi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodH)); + public static MethodInfo MethodJmi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodJ)); + public class TestMethods { public static int Foo => 1; @@ -211,6 +235,11 @@ public static int MethodI() { throw new Exception(); } + + public static IQueryable MethodJ() + { + throw new Exception(); + } } public static class OuterA @@ -640,6 +669,37 @@ public void DbFunction_Annotation_FullName() Assert.NotEqual(funcA.Metadata.Name, funcB.Metadata.Name); } + [ConditionalFact] + public void Find_Queryable_Single_Expression_Overload() + { + var modelBuilder = GetModelBuilder(); + + var funcA = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableSingleParam), new Type[] { typeof(int) })); + var funcB = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableSingleParam), new Type[] { typeof(Expression>) })); + + Assert.Equal("QueryableSingleParam", funcA.Metadata.Name); + Assert.Equal("QueryableSingleParam", funcB.Metadata.Name); + Assert.Equal(funcA.Metadata, funcB.Metadata); + } + + [ConditionalFact] + public void Find_Queryable_Multiple_Expression_Overload() + { + var modelBuilder = GetModelBuilder(); + + var funcA = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableMultiParam), new Type[] { typeof(int), typeof(double) })); + var funcB = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableMultiParam), new Type[] { typeof(Expression>), typeof(double) })); + var funcC = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableMultiParam), new Type[] { typeof(Expression>), typeof(Expression>) })); + + Assert.Equal("QueryableMultiParam", funcA.Metadata.Name); + Assert.Equal("QueryableMultiParam", funcB.Metadata.Name); + Assert.Equal("QueryableMultiParam", funcC.Metadata.Name); + + Assert.Equal(funcA.Metadata, funcB.Metadata); + Assert.Equal(funcA.Metadata, funcC.Metadata); + Assert.Equal(funcB.Metadata, funcC.Metadata); + } + private ModelBuilder GetModelBuilder(DbContext dbContext = null) { var conventionSet = new ConventionSet(); diff --git a/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs b/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs index 89b06e8e5bf..cbcd44c75c4 100644 --- a/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs +++ b/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs @@ -25,6 +25,15 @@ private class TestQueryType public string Something { get; set; } } + [ConditionalFact] + public void Model_differ_does_not_detect_queryable_function_result_type() + { + Execute( + _ => { }, + modelBuilder => modelBuilder.Entity().ToQueryableFunctionResultType(), + result => Assert.Equal(0, result.Count)); + } + [ConditionalFact] public void Model_differ_does_not_detect_views() { diff --git a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs index 63bfed90b3d..80838cc765c 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs @@ -1996,19 +1996,19 @@ public virtual Task OrderBy_scalar_primitive(bool async) public virtual Task SelectMany_mixed(bool async) { return AssertTranslationFailed( - () => AssertQuery( - async, - ss => from e1 in ss.Set().OrderBy(e => e.EmployeeID).Take(2) - from s in new[] { "a", "b" } - from c in ss.Set().OrderBy(c => c.CustomerID).Take(2) - select new - { - e1, - s, - c - }, - e => (e.e1.EmployeeID, e.c.CustomerID), - entryCount: 4)); + () => AssertQuery( + async, + ss => from e1 in ss.Set().OrderBy(e => e.EmployeeID).Take(2) + from s in new[] { "a", "b" } + from c in ss.Set().OrderBy(c => c.CustomerID).Take(2) + select new + { + e1, + s, + c + }, + e => (e.e1.EmployeeID, e.c.CustomerID), + entryCount: 4)); } [ConditionalTheory] diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs index 1533e0164e0..e8bf90bd222 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs @@ -215,7 +215,7 @@ public override void Nullable_navigation_property_access_preserves_schema_for_sq AssertSql( @"SELECT TOP(1) [dbo].[IdentityString]([c].[FirstName]) FROM [Orders] AS [o] -LEFT JOIN [Customers] AS [c] ON [o].[CustomerId] = [c].[Id] +INNER JOIN [Customers] AS [c] ON [o].[CustomerId] = [c].[Id] ORDER BY [o].[Id]"); } @@ -444,6 +444,322 @@ FROM [Customers] AS [c] #endregion + #region Queryable Function Tests + + public override void QF_Stand_Alone() + { + base.QF_Stand_Alone(); + + AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId] +FROM [dbo].[GetTopTwoSellingProducts]() AS [t] +ORDER BY [t].[ProductId]"); + } + + public override void QF_Stand_Alone_With_Translation() + { + base.QF_Stand_Alone_With_Translation(); + + AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId] +FROM [dbo].[GetTopTwoSellingProducts]() AS [t] +ORDER BY [t].[ProductId]"); + } + + public override void QF_Stand_Alone_Parameter() + { + base.QF_Stand_Alone_Parameter(); + + AssertSql(@"@__customerId_0='1' + +SELECT [o].[Count], [o].[CustomerId], [o].[Year] +FROM [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o] +ORDER BY [o].[Count] DESC"); + } + + public override void QF_Stand_Alone_Nested() + { + base.QF_Stand_Alone_Nested(); + + AssertSql(@"SELECT [o].[Count], [o].[CustomerId], [o].[Year] +FROM [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues](-2, 3)) AS [o] +ORDER BY [o].[Count] DESC"); + } + + public override void QF_CrossApply_Correlated_Select_Anonymous() + { + base.QF_CrossApply_Correlated_Select_Anonymous(); + + AssertSql(@"SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] +FROM [Customers] AS [c] +CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o] +ORDER BY [c].[Id], [o].[Year]"); + } + + + public override void QF_Select_Direct_In_Anonymous() + { + base.QF_Select_Direct_In_Anonymous(); + + AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId] +FROM [dbo].[GetTopTwoSellingProducts]() AS [t]", + +@"SELECT [c].[Id] +FROM [Customers] AS [c]"); + } + + public override void QF_Select_Correlated_Direct_With_Function_Query_Parameter_Correlated_In_Anonymous() + { + base.QF_Select_Correlated_Direct_With_Function_Query_Parameter_Correlated_In_Anonymous(); + + AssertSql(@"SELECT [c].[Id], [m].[OrderId], [m].[CustomerId], [m].[OrderDate] +FROM [Customers] AS [c] +OUTER APPLY [dbo].[GetOrdersWithMultipleProducts]([dbo].[AddValues]([c].[Id], 1)) AS [m] +WHERE [c].[Id] = 1 +ORDER BY [c].[Id], [m].[OrderId]"); + } + + public override void QF_Select_Correlated_Subquery_In_Anonymous() + { + base.QF_Select_Correlated_Subquery_In_Anonymous(); + + AssertSql(@"SELECT [c].[Id], [t].[OrderId], [t].[CustomerId], [t].[OrderDate] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [m].[OrderId], [m].[CustomerId], [m].[OrderDate] + FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m] + WHERE DATEPART(day, [m].[OrderDate]) = 21 +) AS [t] +ORDER BY [c].[Id], [t].[OrderId]"); + } + + public override void QF_Select_Correlated_Subquery_In_Anonymous_Nested_With_QF() + { + base.QF_Select_Correlated_Subquery_In_Anonymous_Nested_With_QF(); + + AssertSql(@"SELECT [o].[CustomerId], [o].[OrderDate] +FROM [Orders] AS [o] +INNER JOIN ( + SELECT [c].[Id], [c].[FirstName], [c].[LastName], [m].[OrderId], [m].[CustomerId], [m].[OrderDate] + FROM [Customers] AS [c] + CROSS APPLY [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m] +) AS [t] ON [o].[Id] = [t].[OrderId]"); + } + + public override void QF_Select_Correlated_Subquery_In_Anonymous_Nested() + { + base.QF_Select_Correlated_Subquery_In_Anonymous_Nested(); + + AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId] +FROM [dbo].[GetTopTwoSellingProducts]() AS [t]", + + @"SELECT [c].[Id], [t].[OrderId], [t].[OrderId0], [t].[CustomerId], [t].[OrderDate] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [m].[OrderId], [m0].[OrderId] AS [OrderId0], [m0].[CustomerId], [m0].[OrderDate] + FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m] + OUTER APPLY [dbo].[GetOrdersWithMultipleProducts]([m].[CustomerId]) AS [m0] + WHERE DATEPART(day, [m].[OrderDate]) = 21 +) AS [t] +ORDER BY [c].[Id], [t].[OrderId], [t].[OrderId0]"); + } + + public override void QF_Select_Correlated_Subquery_In_Anonymous_MultipleCollections() + { + base.QF_Select_Correlated_Subquery_In_Anonymous_MultipleCollections(); + + AssertSql(@"SELECT [c].[Id], [t0].[ProductId], [t1].[Id], [t1].[City], [t1].[CustomerId], [t1].[State], [t1].[Street] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [t].[ProductId] + FROM [dbo].[GetTopTwoSellingProducts]() AS [t] + WHERE [t].[AmountSold] = 249 +) AS [t0] +LEFT JOIN ( + SELECT [a].[Id], [a].[City], [a].[CustomerId], [a].[State], [a].[Street] + FROM [Addresses] AS [a] + WHERE [a].[State] = N'NY' +) AS [t1] ON [c].[Id] = [t1].[CustomerId] +ORDER BY [c].[Id], [t1].[Id]"); + } + + public override void QF_Select_NonCorrelated_Subquery_In_Anonymous() + { + base.QF_Select_NonCorrelated_Subquery_In_Anonymous(); + + AssertSql(@"SELECT [c].[Id], [t0].[ProductId] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [t].[ProductId] + FROM [dbo].[GetTopTwoSellingProducts]() AS [t] + WHERE [t].[AmountSold] = 249 +) AS [t0] +ORDER BY [c].[Id]"); + } + + public override void QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter() + { + base.QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter(); + + AssertSql( + @"@__amount_0='27' (Nullable = true) + +SELECT [c].[Id], [t0].[ProductId] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [t].[ProductId] + FROM [dbo].[GetTopTwoSellingProducts]() AS [t] + WHERE [t].[AmountSold] = @__amount_0 +) AS [t0] +ORDER BY [c].[Id]"); + } + + public override void QF_Correlated_Select_In_Anonymous() + { + base.QF_Correlated_Select_In_Anonymous(); + + AssertSql(@"SELECT [c].[Id], [c].[LastName], [m].[OrderId], [m].[CustomerId], [m].[OrderDate] +FROM [Customers] AS [c] +OUTER APPLY [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m] +ORDER BY [c].[Id], [m].[OrderId]"); + } + + public override void QF_CrossApply_Correlated_Select_Result() + { + base.QF_CrossApply_Correlated_Select_Result(); + + AssertSql(@"SELECT [o].[Count], [o].[CustomerId], [o].[Year] +FROM [Customers] AS [c] +CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o] +ORDER BY [o].[Count] DESC, [o].[Year] DESC"); + } + + public override void QF_CrossJoin_Not_Correlated() + { + base.QF_CrossJoin_Not_Correlated(); + + AssertSql(@"@__customerId_0='2' + +SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] +FROM [Customers] AS [c] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o] +WHERE [c].[Id] = 2 +ORDER BY [o].[Count]"); + } + + public override void QF_CrossJoin_Parameter() + { + base.QF_CrossJoin_Parameter(); + + AssertSql(@"@__customerId_0='2' +@__custId_1='2' + +SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] +FROM [Customers] AS [c] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o] +WHERE [c].[Id] = @__custId_1 +ORDER BY [o].[Count]"); + } + + public override void QF_Join() + { + base.QF_Join(); + + AssertSql(@"SELECT [p].[Id], [p].[Name], [t].[AmountSold] +FROM [Products] AS [p] +INNER JOIN [dbo].[GetTopTwoSellingProducts]() AS [t] ON [p].[Id] = [t].[ProductId] +ORDER BY [p].[Id]"); + } + + public override void QF_LeftJoin_Select_Anonymous() + { + base.QF_LeftJoin_Select_Anonymous(); + + AssertSql(@"SELECT [p].[Id], [p].[Name], [t].[AmountSold] +FROM [Products] AS [p] +LEFT JOIN [dbo].[GetTopTwoSellingProducts]() AS [t] ON [p].[Id] = [t].[ProductId] +ORDER BY [p].[Id] DESC"); + } + + public override void QF_LeftJoin_Select_Result() + { + base.QF_LeftJoin_Select_Result(); + + AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId] +FROM [Products] AS [p] +LEFT JOIN [dbo].[GetTopTwoSellingProducts]() AS [t] ON [p].[Id] = [t].[ProductId] +ORDER BY [p].[Id] DESC"); + } + + public override void QF_OuterApply_Correlated_Select_TVF() + { + base.QF_OuterApply_Correlated_Select_TVF(); + + AssertSql(@"SELECT [o].[Count], [o].[CustomerId], [o].[Year] +FROM [Customers] AS [c] +OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o] +ORDER BY [c].[Id], [o].[Year]"); + } + + public override void QF_OuterApply_Correlated_Select_DbSet() + { + base.QF_OuterApply_Correlated_Select_DbSet(); + + AssertSql(@"SELECT [c].[Id], [c].[FirstName], [c].[LastName] +FROM [Customers] AS [c] +OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o] +ORDER BY [c].[Id], [o].[Year]"); + } + + public override void QF_OuterApply_Correlated_Select_Anonymous() + { + base.QF_OuterApply_Correlated_Select_Anonymous(); + + AssertSql(@"SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] +FROM [Customers] AS [c] +OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o] +ORDER BY [c].[Id], [o].[Year]"); + } + + public override void QF_Nested() + { + base.QF_Nested(); + + AssertSql(@"@__custId_1='2' + +SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] +FROM [Customers] AS [c] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues](1, 1)) AS [o] +WHERE [c].[Id] = @__custId_1 +ORDER BY [o].[Year]"); + } + + public override void QF_Correlated_Nested_Func_Call() + { + base.QF_Correlated_Nested_Func_Call(); + + AssertSql(@"@__custId_1='2' + +SELECT [c].[Id], [o].[Count], [o].[Year] +FROM [Customers] AS [c] +CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues]([c].[Id], 1)) AS [o] +WHERE [c].[Id] = @__custId_1"); + } + + public override void QF_Correlated_Func_Call_With_Navigation() + { + base.QF_Correlated_Func_Call_With_Navigation(); + + AssertSql(@"SELECT [c].[Id], [t].[LastName], [t].[OrderId], [t].[Id] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [c0].[LastName], [m].[OrderId], [c0].[Id] + FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m] + INNER JOIN [Customers] AS [c0] ON [m].[CustomerId] = [c0].[Id] +) AS [t] +ORDER BY [c].[Id], [t].[OrderId], [t].[Id]"); + } + + #endregion + public class SqlServer : UdfFixtureBase { protected override string StoreName { get; } = "UDFDbFunctionSqlServerTests"; @@ -516,6 +832,98 @@ returns nvarchar(max) return @customerName; end"); + context.Database.ExecuteSqlRaw( + @"create function [dbo].GetCustomerOrderCountByYear(@customerId int) + returns @reports table + ( + CustomerId int not null, + Count int not null, + Year int not null + ) + as + begin + + insert into @reports + select @customerId, count(id), year(orderDate) + from orders + where customerId = @customerId + group by customerId, year(orderDate) + order by year(orderDate) + + return + end"); + + context.Database.ExecuteSqlRaw( + @"create function [dbo].GetTopTwoSellingProducts() + returns @products table + ( + ProductId int not null, + AmountSold int + ) + as + begin + + insert into @products + select top 2 ProductID, sum(Quantity) as totalSold + from lineItem + group by ProductID + order by totalSold desc + return + end"); + + context.Database.ExecuteSqlRaw( + @"create function [dbo].GetTopSellingProductsForCustomer(@customerId int) + returns @products table + ( + ProductId int not null, + AmountSold int + ) + as + begin + + insert into @products + select ProductID, sum(Quantity) as totalSold + from lineItem li + join orders o on o.id = li.orderId + where o.customerId = @customerId + group by ProductID + + return + end"); + + context.Database.ExecuteSqlRaw( + @"create function [dbo].GetOrdersWithMultipleProducts(@customerId int) + returns @orders table + ( + OrderId int not null, + CustomerId int not null, + OrderDate dateTime2 + ) + as + begin + + insert into @orders + select o.id, @customerId, OrderDate + from orders o + join lineItem li on o.id = li.orderId + where o.customerId = @customerId + group by o.id, OrderDate + having count(productId) > 1 + return + end"); + + context.Database.ExecuteSqlRaw( + @"create function [dbo].[AddValues] (@a int, @b int) + returns int + as + begin + return @a + @b; + end"); + + + context.Database.ExecuteSqlRaw( + @"create view [dbo].[vOrderQuery] as select * from orders"); + context.SaveChanges(); } }