diff --git a/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs b/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs
index c7524b9a479..d6223ac8c20 100644
--- a/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs
+++ b/src/EFCore.Relational/Extensions/RelationalEntityTypeBuilderExtensions.cs
@@ -264,6 +264,37 @@ public static bool CanSetSchema(
return entityTypeBuilder.CanSetAnnotation(RelationalAnnotationNames.Schema, schema, fromDataAnnotation);
}
+ ///
+ /// Configures the entity as a result of a queryable function. Prevents a table from being created for this entity.
+ ///
+ /// The builder for the entity type being configured.
+ /// The same builder instance so that multiple calls can be chained.
+ public static EntityTypeBuilder ToQueryableFunctionResultType(
+ [NotNull] this EntityTypeBuilder entityTypeBuilder)
+ {
+ Check.NotNull(entityTypeBuilder, nameof(entityTypeBuilder));
+
+ entityTypeBuilder.Metadata.SetAnnotation(RelationalAnnotationNames.QueryableFunctionResultType, null);
+
+ return entityTypeBuilder;
+ }
+
+ ///
+ /// Configures the entity as a result of a queryable function. Prevents a table from being created for this entity.
+ ///
+ /// The builder for the entity type being configured.
+ /// The same builder instance so that multiple calls can be chained.
+ public static EntityTypeBuilder ToQueryableFunctionResultType(
+ [NotNull] this EntityTypeBuilder entityTypeBuilder)
+ where TEntity : class
+ {
+ Check.NotNull(entityTypeBuilder, nameof(entityTypeBuilder));
+
+ entityTypeBuilder.Metadata.SetAnnotation(RelationalAnnotationNames.QueryableFunctionResultType, null);
+
+ return entityTypeBuilder;
+ }
+
///
/// Configures the view that the entity type maps to when targeting a relational database.
///
diff --git a/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs b/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs
index 45c32cf6d4c..a9300130df3 100644
--- a/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs
+++ b/src/EFCore.Relational/Extensions/RelationalEntityTypeExtensions.cs
@@ -292,6 +292,9 @@ public static bool IsIgnoredByMigrations([NotNull] this IEntityType entityType)
return true;
}
+ if (entityType.FindAnnotation(RelationalAnnotationNames.QueryableFunctionResultType) != null)
+ return true;
+
var viewDefinition = entityType.FindAnnotation(RelationalAnnotationNames.ViewDefinition);
if (viewDefinition == null)
{
diff --git a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs
index e70312a710f..132edc53ae8 100644
--- a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs
+++ b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs
@@ -169,6 +169,7 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices()
TryAdd();
TryAdd();
TryAdd();
+ TryAdd();
ServiceCollectionMap.GetInfrastructure()
.AddDependencySingleton()
diff --git a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs
index d00c17a598b..0bee6512073 100644
--- a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs
+++ b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs
@@ -80,7 +80,8 @@ protected virtual void ValidateDbFunctions(
RelationalStrings.DbFunctionNameEmpty(methodInfo.DisplayName()));
}
- if (dbFunction.TypeMapping == null)
+ if (dbFunction.TypeMapping == null &&
+ !(dbFunction.IsIQueryable && model.FindEntityType(dbFunction.MethodInfo.ReturnType.GetGenericArguments()[0]) != null))
{
throw new InvalidOperationException(
RelationalStrings.DbFunctionInvalidReturnType(
diff --git a/src/EFCore.Relational/Metadata/IDbFunction.cs b/src/EFCore.Relational/Metadata/IDbFunction.cs
index af3d76d90d2..8b129911289 100644
--- a/src/EFCore.Relational/Metadata/IDbFunction.cs
+++ b/src/EFCore.Relational/Metadata/IDbFunction.cs
@@ -34,6 +34,11 @@ public interface IDbFunction
///
MethodInfo MethodInfo { get; }
+ ///
+ /// Whether this method returns IQueryable
+ ///
+ bool IsIQueryable { get; }
+
///
/// The configured store type string
///
diff --git a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs
index 6818016b6df..9a503831b20 100644
--- a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs
+++ b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
@@ -75,6 +76,19 @@ public DbFunction(
methodInfo.DisplayName(), methodInfo.ReturnType.ShortDisplayName()));
}
+ if (methodInfo.ReturnType.IsGenericType
+ && methodInfo.ReturnType.GetGenericTypeDefinition() == typeof(IQueryable<>))
+ {
+ IsIQueryable = true;
+
+ //todo - if the generic argument is not usuable as an entitytype should we throw here? IE IQueryable
+ //the built in entitytype will throw is the type is not a class
+ if (model.FindEntityType(methodInfo.ReturnType.GetGenericArguments()[0]) == null)
+ {
+ model.AddEntityType(methodInfo.ReturnType.GetGenericArguments()[0]).SetAnnotation(RelationalAnnotationNames.QueryableFunctionResultType, null);
+ }
+ }
+
MethodInfo = methodInfo;
_model = model;
@@ -310,6 +324,14 @@ public virtual Func, SqlExpression> Translati
set => SetTranslation(value, ConfigurationSource.Explicit);
}
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public virtual bool IsIQueryable { get; }
+
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -345,7 +367,27 @@ private void UpdateTranslationConfigurationSource(ConfigurationSource configurat
public static DbFunction FindDbFunction(
[NotNull] IModel model,
[NotNull] MethodInfo methodInfo)
- => model[BuildAnnotationName(methodInfo)] as DbFunction;
+ {
+ var dbFunction = model[BuildAnnotationName(methodInfo)] as DbFunction;
+
+ if (dbFunction == null
+ && methodInfo.GetParameters().Any(p => p.ParameterType.IsGenericType && p.ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)))
+ {
+ var parameters = methodInfo.GetParameters().Select(p => p.ParameterType.IsGenericType
+ && p.ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)
+ && p.ParameterType.GetGenericArguments()[0].GetGenericTypeDefinition() == typeof(Func<>)
+ ? p.ParameterType.GetGenericArguments()[0].GetGenericArguments()[0]
+ : p.ParameterType).ToArray();
+
+ var nonExpressionMethod = methodInfo.DeclaringType.GetMethod(methodInfo.Name, parameters);
+
+ dbFunction = nonExpressionMethod != null
+ ? model[BuildAnnotationName(nonExpressionMethod)] as DbFunction
+ : null;
+ }
+
+ return dbFunction;
+ }
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
diff --git a/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs b/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs
index fafdb4f054c..4d712d842a4 100644
--- a/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs
+++ b/src/EFCore.Relational/Metadata/RelationalAnnotationNames.cs
@@ -98,5 +98,10 @@ public static class RelationalAnnotationNames
/// The definition of a database view.
///
public const string ViewDefinition = Prefix + "ViewDefinition";
+
+ ///
+ /// The definition of a Queryable Function Result Type.
+ ///
+ public const string QueryableFunctionResultType = Prefix + "QueryableFunctionResultType";
}
}
diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs
index 316be0dac26..18e29b83574 100644
--- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs
+++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs
@@ -488,6 +488,20 @@ public static string DbFunctionInvalidInstanceType([CanBeNull] object function,
GetString("DbFunctionInvalidInstanceType", nameof(function), nameof(type)),
function, type);
+ ///
+ /// Queryable Db Functions used in projections cannot return IQueryable. IQueryable must be converted to a collection type such as List or Array.
+ ///
+ public static string DbFunctionCantProjectIQueryable()
+ => GetString("DbFunctionCantProjectIQueryable");
+
+ ///
+ /// Return type of a queryable function '{functionName}' which is used in a projected collection must define a primary key.
+ ///
+ public static string DbFunctionProjectedCollectionMustHavePK([CanBeNull] string functionName)
+ => string.Format(
+ GetString("DbFunctionProjectedCollectionMustHavePK", nameof(functionName)),
+ functionName);
+
///
/// An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection.
///
diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx
index df2095d5dc5..dfe238ab682 100644
--- a/src/EFCore.Relational/Properties/RelationalStrings.resx
+++ b/src/EFCore.Relational/Properties/RelationalStrings.resx
@@ -435,6 +435,12 @@
The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported.
+
+ Queryable Db Functions used in projections cannot return IQueryable. IQueryable must be converted to a collection type such as List or Array.
+
+
+ Return type of a queryable function '{functionName}' which is used in a projected collection must define a primary key.
+
An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection.
diff --git a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs
index 955e4857bbe..3f1c25718b5 100644
--- a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs
+++ b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs
@@ -141,5 +141,6 @@ SqlFunctionExpression Function(
SelectExpression Select([CanBeNull] SqlExpression projection);
SelectExpression Select([NotNull] IEntityType entityType);
SelectExpression Select([NotNull] IEntityType entityType, [NotNull] string sql, [NotNull] Expression sqlArguments);
+ SelectExpression Select([NotNull] IEntityType entityType, [NotNull] SqlFunctionExpression expression);
}
}
diff --git a/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs
index 69bf5410321..30947b4e134 100644
--- a/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs
@@ -656,6 +656,13 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction
return sqlFunctionExpression.Update(newInstance, newArguments);
}
+ protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression)
+ {
+ Check.NotNull(queryableFunctionExpression, nameof(queryableFunctionExpression));
+
+ return queryableFunctionExpression;
+ }
+
protected override Expression VisitSqlParameter(SqlParameterExpression sqlParameterExpression)
{
Check.NotNull(sqlParameterExpression, nameof(sqlParameterExpression));
diff --git a/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs
new file mode 100644
index 00000000000..25cd5833908
--- /dev/null
+++ b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitor.cs
@@ -0,0 +1,27 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Diagnostics.CodeAnalysis;
+using System.Linq.Expressions;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal
+{
+ public class RelationalNavigationExpandingExpressionVisitor : NavigationExpandingExpressionVisitor
+ {
+ public RelationalNavigationExpandingExpressionVisitor(
+ [NotNull] QueryCompilationContext queryCompilationContext,
+ [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter)
+ : base(queryCompilationContext, evaluatableExpressionFilter)
+ {
+ }
+
+ protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
+ {
+ var dbFunction = QueryCompilationContext.Model.FindDbFunction(methodCallExpression.Method);
+
+ return dbFunction?.IsIQueryable == true
+ ? CreateNavigationExpansionExpression(methodCallExpression, QueryCompilationContext.Model.FindEntityType(dbFunction.MethodInfo.ReturnType.GetGenericArguments()[0]))
+ : base.VisitMethodCall(methodCallExpression);
+ }
+ }
+}
diff --git a/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs
new file mode 100644
index 00000000000..6f2ffd343e9
--- /dev/null
+++ b/src/EFCore.Relational/Query/Internal/RelationalNavigationExpandingExpressionVisitorFactory.cs
@@ -0,0 +1,14 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal
+{
+ public class RelationalNavigationExpandingExpressionVisitorFactory : INavigationExpandingExpressionVisitorFactory
+ {
+ public virtual NavigationExpandingExpressionVisitor Create(
+ QueryCompilationContext queryCompilationContext, IEvaluatableExpressionFilter evaluatableExpressionFilter)
+ {
+ return new RelationalNavigationExpandingExpressionVisitor(queryCompilationContext, evaluatableExpressionFilter);
+ }
+ }
+}
diff --git a/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs
index 694ea4513c1..560ff880095 100644
--- a/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs
@@ -9,6 +9,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
+using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;
@@ -22,6 +23,7 @@ public class RelationalProjectionBindingExpressionVisitor : ExpressionVisitor
private SelectExpression _selectExpression;
private bool _clientEval;
+ private readonly IModel _model;
private readonly IDictionary _projectionMapping
= new Dictionary();
@@ -30,10 +32,12 @@ private readonly IDictionary _projectionMapping
public RelationalProjectionBindingExpressionVisitor(
[NotNull] RelationalQueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
- [NotNull] RelationalSqlTranslatingExpressionVisitor sqlTranslatingExpressionVisitor)
+ [NotNull] RelationalSqlTranslatingExpressionVisitor sqlTranslatingExpressionVisitor,
+ [NotNull] IModel model)
{
_queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor;
_sqlTranslator = sqlTranslatingExpressionVisitor;
+ _model = model;
}
public virtual Expression Translate([NotNull] SelectExpression selectExpression, [NotNull] Expression expression)
@@ -242,6 +246,11 @@ protected override Expression VisitNew(NewExpression newExpression)
return null;
}
+ if (newExpression.Arguments.Any(arg => arg is MethodCallExpression methodCallExp && _model.FindDbFunction(methodCallExp.Method)?.IsIQueryable == true))
+ {
+ throw new InvalidOperationException(RelationalStrings.DbFunctionCantProjectIQueryable());
+ }
+
var newArguments = new Expression[newExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
diff --git a/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs b/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs
index 8ba74997331..a6e026532ab 100644
--- a/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs
+++ b/src/EFCore.Relational/Query/Internal/RelationalQueryTranslationPreprocessorFactory.cs
@@ -21,20 +21,23 @@ public class RelationalQueryTranslationPreprocessorFactory : IQueryTranslationPr
{
private readonly QueryTranslationPreprocessorDependencies _dependencies;
private readonly RelationalQueryTranslationPreprocessorDependencies _relationalDependencies;
+ private readonly INavigationExpandingExpressionVisitorFactory _navigationExpandingExpressionVisitorFactory;
public RelationalQueryTranslationPreprocessorFactory(
[NotNull] QueryTranslationPreprocessorDependencies dependencies,
- [NotNull] RelationalQueryTranslationPreprocessorDependencies relationalDependencies)
+ [NotNull] RelationalQueryTranslationPreprocessorDependencies relationalDependencies,
+ [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory)
{
_dependencies = dependencies;
_relationalDependencies = relationalDependencies;
+ _navigationExpandingExpressionVisitorFactory = navigationExpandingExpressionVisitorFactory;
}
public virtual QueryTranslationPreprocessor Create(QueryCompilationContext queryCompilationContext)
{
Check.NotNull(queryCompilationContext, nameof(queryCompilationContext));
- return new RelationalQueryTranslationPreprocessor(_dependencies, _relationalDependencies, queryCompilationContext);
+ return new RelationalQueryTranslationPreprocessor(_dependencies, _relationalDependencies, queryCompilationContext, _navigationExpandingExpressionVisitorFactory);
}
}
}
diff --git a/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs b/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs
index 0c5e3a5aa42..6d5bec985d1 100644
--- a/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/NullabilityBasedSqlProcessingExpressionVisitor.cs
@@ -371,6 +371,9 @@ protected override Expression VisitProjection(ProjectionExpression projectionExp
VisitInternal(projectionExpression.Expression).ResultExpression);
}
+ protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression)
+ => Check.NotNull(queryableFunctionExpression, nameof(queryableFunctionExpression));
+
protected override Expression VisitRowNumber(RowNumberExpression rowNumberExpression)
{
Check.NotNull(rowNumberExpression, nameof(rowNumberExpression));
diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs
index 5e9f7314ffe..04e9013e0f3 100644
--- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs
+++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs
@@ -270,6 +270,17 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction
return sqlFunctionExpression;
}
+ protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression)
+ {
+ Visit(queryableFunctionExpression.SqlFunctionExpression);
+
+ _relationalCommandBuilder
+ .Append(AliasSeparator)
+ .Append(_sqlGenerationHelper.DelimitIdentifier(queryableFunctionExpression.Alias));
+
+ return queryableFunctionExpression;
+ }
+
protected override Expression VisitColumn(ColumnExpression columnExpression)
{
Check.NotNull(columnExpression, nameof(columnExpression));
diff --git a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs
index 786419d6ab5..2aa9b4f7b42 100644
--- a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs
+++ b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs
@@ -47,11 +47,11 @@ public RelationalEvaluatableExpressionFilter(
///
protected virtual RelationalEvaluatableExpressionFilterDependencies RelationalDependencies { get; }
- ///
- /// Checks whether the given expression can be evaluated.
- ///
- /// The expression.
- /// The model.
+ ///
+ /// Checks whether the given expression can be evaluated.
+ ///
+ /// The expression.
+ /// The model.
/// True if the expression can be evaluated; false otherwise.
public override bool IsEvaluatableExpression(Expression expression, IModel model)
{
@@ -59,12 +59,16 @@ public override bool IsEvaluatableExpression(Expression expression, IModel model
Check.NotNull(model, nameof(model));
if (expression is MethodCallExpression methodCallExpression
- && model.FindDbFunction(methodCallExpression.Method) != null)
+ && model.FindDbFunction(methodCallExpression.Method) is IDbFunction func)
{
- return false;
+ return func?.IsIQueryable ?? true;
}
return base.IsEvaluatableExpression(expression, model);
}
+
+ public override bool IsQueryableFunction(Expression expression, IModel model) =>
+ expression is MethodCallExpression methodCallExpression
+ && model.FindDbFunction(methodCallExpression.Method)?.IsIQueryable == true;
}
}
diff --git a/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs b/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs
index 721660089da..f0f441f5fcf 100644
--- a/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs
+++ b/src/EFCore.Relational/Query/RelationalQueryTranslationPreprocessor.cs
@@ -11,8 +11,9 @@ public class RelationalQueryTranslationPreprocessor : QueryTranslationPreprocess
public RelationalQueryTranslationPreprocessor(
[NotNull] QueryTranslationPreprocessorDependencies dependencies,
[NotNull] RelationalQueryTranslationPreprocessorDependencies relationalDependencies,
- [NotNull] QueryCompilationContext queryCompilationContext)
- : base(dependencies, queryCompilationContext)
+ [NotNull] QueryCompilationContext queryCompilationContext,
+ [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory)
+ : base(dependencies, queryCompilationContext, navigationExpandingExpressionVisitorFactory)
{
Check.NotNull(relationalDependencies, nameof(relationalDependencies));
diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs
index 04e2820db79..f6435f98779 100644
--- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs
@@ -42,7 +42,7 @@ public RelationalQueryableMethodTranslatingExpressionVisitor(
var sqlExpressionFactory = relationalDependencies.SqlExpressionFactory;
_sqlTranslator = relationalDependencies.RelationalSqlTranslatingExpressionVisitorFactory.Create(model, this);
_weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_sqlTranslator, sqlExpressionFactory);
- _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
+ _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator, model);
_model = model;
_sqlExpressionFactory = sqlExpressionFactory;
_subquery = false;
@@ -58,7 +58,7 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor(
_model = parentVisitor._model;
_sqlTranslator = parentVisitor._sqlTranslator;
_weakEntityExpandingExpressionVisitor = parentVisitor._weakEntityExpandingExpressionVisitor;
- _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
+ _projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator, _model);
_sqlExpressionFactory = parentVisitor._sqlExpressionFactory;
_subquery = true;
}
@@ -75,12 +75,29 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return CreateShapedQueryExpression(queryable.ElementType, sql, methodCallExpression.Arguments[2]);
}
+ var dbFunction = this._model.FindDbFunction(methodCallExpression.Method);
+ if (dbFunction != null && dbFunction.IsIQueryable)
+ {
+ return CreateShapedQueryExpression(methodCallExpression);
+ }
+
return base.VisitMethodCall(methodCallExpression);
}
protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVisitor()
=> new RelationalQueryableMethodTranslatingExpressionVisitor(this);
+ protected virtual ShapedQueryExpression CreateShapedQueryExpression([NotNull] MethodCallExpression methodCallExpression)
+ {
+ var sqlFuncExpression = _sqlTranslator.TranslateMethodCall(methodCallExpression) as SqlFunctionExpression;
+
+ var elementType = methodCallExpression.Method.ReturnType.GetGenericArguments()[0];
+ var entityType =_model.FindEntityType(elementType);
+ var queryExpression = _sqlExpressionFactory.Select(entityType, sqlFuncExpression);
+
+ return CreateShapedQueryExpression(entityType, queryExpression);
+ }
+
protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
{
Check.NotNull(elementType, nameof(elementType));
diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
index 1bc21038d94..50a0d0391bf 100644
--- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
@@ -9,6 +9,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
+using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
@@ -183,6 +184,30 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression)
"SUM", new[] { sqlExpression }, inputType, sqlExpression.TypeMapping);
}
+ public virtual Expression TranslateMethodCall([NotNull] MethodCallExpression methodCallExpression)
+ {
+ Check.NotNull(methodCallExpression, nameof(methodCallExpression));
+
+ if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
+ {
+ return null;
+ }
+
+ var arguments = new SqlExpression[methodCallExpression.Arguments.Count];
+ for (var i = 0; i < arguments.Length; i++)
+ {
+ var argument = methodCallExpression.Arguments[i];
+ if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
+ {
+ return null;
+ }
+
+ arguments[i] = sqlArgument;
+ }
+
+ return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments);
+ }
+
private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression node)
@@ -190,7 +215,9 @@ protected override Expression VisitExtension(Expression node)
Check.NotNull(node, nameof(node));
if (node is SqlExpression sqlExpression
- && !(node is SqlFragmentExpression))
+ && !(node is SqlFragmentExpression)
+ && !(node is SqlFunctionExpression sqlFunctionExpression
+ && sqlFunctionExpression.Type.IsQueryableType()))
{
if (sqlExpression.TypeMapping == null)
{
@@ -417,24 +444,7 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method)
}
// MethodCall translators
- if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
- {
- return null;
- }
-
- var arguments = new SqlExpression[methodCallExpression.Arguments.Count];
- for (var i = 0; i < arguments.Length; i++)
- {
- var argument = methodCallExpression.Arguments[i];
- if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
- {
- return null;
- }
-
- arguments[i] = sqlArgument;
- }
-
- return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments);
+ return TranslateMethodCall(methodCallExpression);
}
private static Expression TryRemoveImplicitConvert(Expression expression)
@@ -597,7 +607,7 @@ protected override Expression VisitLambda(Expression node)
{
Check.NotNull(node, nameof(node));
- return null;
+ return node.Body != null ? Visit(node.Body) : null;
}
protected override Expression VisitConstant(ConstantExpression constantExpression)
@@ -693,6 +703,9 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
}
break;
+
+ case ExpressionType.Quote:
+ return operand;
}
return null;
diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs
index 9703701c969..bc1cf9d9c5a 100644
--- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs
+++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs
@@ -670,6 +670,14 @@ public virtual SelectExpression Select(IEntityType entityType, string sql, Expre
return selectExpression;
}
+ public virtual SelectExpression Select(IEntityType entityType, SqlFunctionExpression expression)
+ {
+ var selectExpression = new SelectExpression(entityType, expression);
+ AddConditions(selectExpression, entityType);
+
+ return selectExpression;
+ }
+
private void AddConditions(
SelectExpression selectExpression,
IEntityType entityType,
diff --git a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs
index 36b26110980..20226c66f15 100644
--- a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs
@@ -65,6 +65,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case ProjectionExpression projectionExpression:
return VisitProjection(projectionExpression);
+ case QuerableSqlFunctionExpression queryableFunctionExpression:
+ return VisitQueryableSqlFunctionExpression(queryableFunctionExpression);
+
case RowNumberExpression rowNumberExpression:
return VisitRowNumber(rowNumberExpression);
@@ -117,6 +120,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
protected abstract Expression VisitOrdering([NotNull] OrderingExpression orderingExpression);
protected abstract Expression VisitOuterApply([NotNull] OuterApplyExpression outerApplyExpression);
protected abstract Expression VisitProjection([NotNull] ProjectionExpression projectionExpression);
+ protected abstract Expression VisitQueryableSqlFunctionExpression([NotNull] QuerableSqlFunctionExpression queryableFunctionExpression);
protected abstract Expression VisitRowNumber([NotNull] RowNumberExpression rowNumberExpression);
protected abstract Expression VisitScalarSubquery([NotNull] ScalarSubqueryExpression scalarSubqueryExpression);
protected abstract Expression VisitSelect([NotNull] SelectExpression selectExpression);
diff --git a/src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs
new file mode 100644
index 00000000000..5ae37ce65fe
--- /dev/null
+++ b/src/EFCore.Relational/Query/SqlExpressions/QuerableSqlFunctionExpression.cs
@@ -0,0 +1,45 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using JetBrains.Annotations;
+using Microsoft.EntityFrameworkCore.Utilities;
+
+namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
+{
+ ///
+ /// Represents a SQL Table Valued Fuction in the sql generation tree.
+ ///
+ public class QuerableSqlFunctionExpression : TableExpressionBase
+ {
+ public QuerableSqlFunctionExpression([NotNull] SqlFunctionExpression expression, [CanBeNull] string alias)
+ : base(alias)
+ {
+ Check.NotNull(expression, nameof(expression));
+
+ SqlFunctionExpression = expression;
+ }
+
+ public virtual SqlFunctionExpression SqlFunctionExpression { get; }
+
+ public override void Print(ExpressionPrinter expressionPrinter)
+ {
+ expressionPrinter.Append("(");
+ expressionPrinter.Visit(SqlFunctionExpression);
+ expressionPrinter.AppendLine()
+ .AppendLine($") AS {Alias}");
+ }
+
+ public override bool Equals(object obj)
+ => obj != null
+ && (ReferenceEquals(this, obj)
+ || obj is QuerableSqlFunctionExpression queryableExpression
+ && Equals(queryableExpression));
+
+ private bool Equals(QuerableSqlFunctionExpression queryableExpression)
+ => base.Equals(queryableExpression)
+ && SqlFunctionExpression.Equals(queryableExpression.SqlFunctionExpression);
+
+ public override int GetHashCode() => HashCode.Combine(base.GetHashCode(), SqlFunctionExpression);
+ }
+}
diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
index 70497e09b13..ff5f6addc1a 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
@@ -7,6 +7,7 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
+using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Utilities;
@@ -79,6 +80,13 @@ internal SelectExpression(IEntityType entityType, string sql, Expression argumen
{
}
+ internal SelectExpression(IEntityType entityType, SqlFunctionExpression expression)
+ : this(
+ entityType, new QuerableSqlFunctionExpression(expression,
+ entityType.GetTableName().ToLower().Substring(0, 1)))
+ {
+ }
+
private SelectExpression(IEntityType entityType, TableExpressionBase tableExpression)
: base(null)
{
@@ -927,6 +935,13 @@ public Expression ApplyCollectionJoin(
var parentIdentifier = GetIdentifierAccessor(_identifier);
var outerIdentifier = GetIdentifierAccessor(_identifier.Concat(_childIdentifiers));
innerSelectExpression.ApplyProjection();
+
+ if (innerSelectExpression._identifier.Count == 0 && innerSelectExpression.Tables.FirstOrDefault(
+ t => t is QuerableSqlFunctionExpression expression && expression.SqlFunctionExpression.Arguments.Count != 0) is QuerableSqlFunctionExpression queryableFunctionExpression)
+ {
+ throw new InvalidOperationException(RelationalStrings.DbFunctionProjectedCollectionMustHavePK(queryableFunctionExpression.SqlFunctionExpression.Name));
+ }
+
var selfIdentifier = innerSelectExpression.GetIdentifierAccessor(innerSelectExpression._identifier);
if (collectionIndex == 0)
diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs
index a90cf2b6d58..8feaf62b939 100644
--- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs
+++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs
@@ -341,6 +341,13 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction
return ApplyConversion(newFunction, condition);
}
+ protected override Expression VisitQueryableSqlFunctionExpression(QuerableSqlFunctionExpression queryableFunctionExpression)
+ {
+ Check.NotNull(queryableFunctionExpression, nameof(queryableFunctionExpression));
+
+ return queryableFunctionExpression;
+ }
+
protected override Expression VisitSqlParameter(SqlParameterExpression sqlParameterExpression)
{
Check.NotNull(sqlParameterExpression, nameof(sqlParameterExpression));
diff --git a/src/EFCore/DbContext.cs b/src/EFCore/DbContext.cs
index 722ec554f45..5ec49ed8aed 100644
--- a/src/EFCore/DbContext.cs
+++ b/src/EFCore/DbContext.cs
@@ -6,6 +6,7 @@
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
+using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
@@ -1656,6 +1657,22 @@ public virtual ValueTask FindAsync([CanBeNull] object[] keyVal
///
IServiceProvider IInfrastructure.Instance => InternalServiceProvider;
+ ///
+ /// Creates a query expression, which represents a function call, against the query store.
+ ///
+ /// The result type of the query expression
+ /// The query expression to create.
+ /// An IQueryable representing the query.
+ protected virtual IQueryable CreateQuery([NotNull] Expression>> expression)
+ {
+ //should we add this method as an extension in relational? That would require making DbContextDependencies public.
+ //Is there a 3rd way?
+
+ Check.NotNull(expression, nameof(expression));
+
+ return DbContextDependencies.QueryProvider.CreateQuery(expression.Body);
+ }
+
#region Hidden System.Object members
///
diff --git a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs
index 7f4db6e8fb9..107ad57a62c 100644
--- a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs
+++ b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs
@@ -140,7 +140,9 @@ public static readonly IDictionary CoreServices
},
{ typeof(ISingletonOptions), new ServiceCharacteristics(ServiceLifetime.Singleton, multipleRegistrations: true) },
{ typeof(IConventionSetPlugin), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) },
- { typeof(IResettableService), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) }
+ { typeof(IResettableService), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) },
+ { typeof(INavigationExpandingExpressionVisitorFactory), new ServiceCharacteristics(ServiceLifetime.Singleton) },
+
};
///
@@ -265,6 +267,7 @@ public virtual EntityFrameworkServicesBuilder TryAddCoreServices()
TryAdd();
TryAdd();
TryAdd();
+ TryAdd();
TryAdd(p => p.GetService()?.FindExtension()?.DbContextLogger ?? new NullDbContextLogger());
diff --git a/src/EFCore/Query/EvaluatableExpressionFilter.cs b/src/EFCore/Query/EvaluatableExpressionFilter.cs
index 395610dbdcb..51a2d66f860 100644
--- a/src/EFCore/Query/EvaluatableExpressionFilter.cs
+++ b/src/EFCore/Query/EvaluatableExpressionFilter.cs
@@ -118,5 +118,7 @@ public virtual bool IsEvaluatableExpression(Expression expression, IModel model)
return true;
}
+
+ public virtual bool IsQueryableFunction(Expression expression, IModel model) => false;
}
}
diff --git a/src/EFCore/Query/IEvaluatableExpressionFilter.cs b/src/EFCore/Query/IEvaluatableExpressionFilter.cs
index d461a93e17e..e39a4146e05 100644
--- a/src/EFCore/Query/IEvaluatableExpressionFilter.cs
+++ b/src/EFCore/Query/IEvaluatableExpressionFilter.cs
@@ -27,5 +27,13 @@ public interface IEvaluatableExpressionFilter
/// The model.
/// True if the expression can be evaluated; false otherwise.
bool IsEvaluatableExpression([NotNull] Expression expression, [NotNull] IModel model);
+
+ ///
+ /// Checks whether the given expression is a queryable function
+ ///
+ /// The expression.
+ /// The model.
+ /// True if the expression is a queryable function
+ bool IsQueryableFunction([NotNull] Expression expression, [NotNull] IModel model);
}
}
diff --git a/src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs b/src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs
new file mode 100644
index 00000000000..ce22ad612bd
--- /dev/null
+++ b/src/EFCore/Query/INavigationExpandingExpressionVisitorFactory.cs
@@ -0,0 +1,14 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.EntityFrameworkCore.Query.Internal;
+
+namespace Microsoft.EntityFrameworkCore.Query
+{
+ public interface INavigationExpandingExpressionVisitorFactory
+ {
+ NavigationExpandingExpressionVisitor Create([NotNull] QueryCompilationContext queryCompilationContext,
+ [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter);
+ }
+}
diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
index f7a34f72d00..4fd77759ee5 100644
--- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
@@ -48,6 +48,8 @@ private readonly Dictionary _parameterizedQueryFi
private readonly Parameters _parameters = new Parameters();
+ protected virtual QueryCompilationContext QueryCompilationContext => _queryCompilationContext;
+
public NavigationExpandingExpressionVisitor(
[NotNull] QueryCompilationContext queryCompilationContext,
[NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter)
@@ -1356,7 +1358,7 @@ static bool IsNumericType(Type type)
}
}
- private NavigationExpansionExpression CreateNavigationExpansionExpression(Expression sourceExpression, IEntityType entityType)
+ protected virtual NavigationExpansionExpression CreateNavigationExpansionExpression([NotNull] Expression sourceExpression, [NotNull] IEntityType entityType)
{
var entityReference = new EntityReference(entityType);
PopulateEagerLoadedNavigations(entityReference.IncludePaths);
diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs
new file mode 100644
index 00000000000..39774cf1d1b
--- /dev/null
+++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitorFactory.cs
@@ -0,0 +1,18 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.EntityFrameworkCore.Query.Internal
+{
+ public class NavigationExpandingExpressionVisitorFactory : INavigationExpandingExpressionVisitorFactory
+ {
+ public virtual NavigationExpandingExpressionVisitor Create(
+ QueryCompilationContext queryCompilationContext, IEvaluatableExpressionFilter evaluatableExpressionFilter)
+ {
+ return new NavigationExpandingExpressionVisitor(queryCompilationContext, evaluatableExpressionFilter);
+ }
+ }
+}
diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
index 2b2ecf1c48c..582489c1a2c 100644
--- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
@@ -3,15 +3,14 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
-using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Storage;
-using Microsoft.EntityFrameworkCore.Utilities;
namespace Microsoft.EntityFrameworkCore.Query.Internal
{
@@ -55,12 +54,10 @@ public ParameterExtractingExpressionVisitor(
{
_evaluatableExpressionFindingExpressionVisitor
= new EvaluatableExpressionFindingExpressionVisitor(evaluatableExpressionFilter, model);
-
_parameterValues = parameterValues;
_logger = logger;
_parameterize = parameterize;
_generateContextAccessors = generateContextAccessors;
-
if (_generateContextAccessors)
{
_contextParameterReplacingExpressionVisitor = new ContextParameterReplacingExpressionVisitor(contextType);
@@ -153,8 +150,6 @@ private bool PreserveConvertNode(Expression expression)
///
protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
- Check.NotNull(conditionalExpression, nameof(conditionalExpression));
-
var newTestExpression = TryGetConstantValue(conditionalExpression.Test) ?? Visit(conditionalExpression.Test);
if (newTestExpression is ConstantExpression constantTestExpression
@@ -179,8 +174,6 @@ protected override Expression VisitConditional(ConditionalExpression conditional
///
protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
- Check.NotNull(binaryExpression, nameof(binaryExpression));
-
switch (binaryExpression.NodeType)
{
case ExpressionType.Coalesce:
@@ -251,8 +244,6 @@ private static bool ShortCircuitLogicalExpression(Expression expression, Express
///
protected override Expression VisitConstant(ConstantExpression constantExpression)
{
- Check.NotNull(constantExpression, nameof(constantExpression));
-
if (constantExpression.Value is IDetachableContext detachableContext)
{
var queryProvider = ((IQueryable)constantExpression.Value).Provider;
@@ -522,8 +513,6 @@ public override Expression Visit(Expression expression)
protected override Expression VisitLambda(Expression lambdaExpression)
{
- Check.NotNull(lambdaExpression, nameof(lambdaExpression));
-
var oldInLambda = _inLambda;
_inLambda = true;
@@ -537,8 +526,6 @@ protected override Expression VisitLambda(Expression lambdaExpression)
protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression)
{
- Check.NotNull(memberInitExpression, nameof(memberInitExpression));
-
Visit(memberInitExpression.Bindings, VisitMemberBinding);
// Cannot make parameter for NewExpression if Bindings cannot be evaluated
@@ -552,8 +539,6 @@ protected override Expression VisitMemberInit(MemberInitExpression memberInitExp
protected override Expression VisitListInit(ListInitExpression listInitExpression)
{
- Check.NotNull(listInitExpression, nameof(listInitExpression));
-
Visit(listInitExpression.Initializers, VisitElementInit);
// Cannot make parameter for NewExpression if Initializers cannot be evaluated
@@ -567,7 +552,11 @@ protected override Expression VisitListInit(ListInitExpression listInitExpressio
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
- Check.NotNull(methodCallExpression, nameof(methodCallExpression));
+ if (_evaluatableExpressionFilter.IsQueryableFunction(methodCallExpression, _model)
+ && !_inLambda)
+ {
+ _evaluatable = false;
+ }
Visit(methodCallExpression.Object);
var parameterInfos = methodCallExpression.Method.GetParameters();
@@ -608,18 +597,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
protected override Expression VisitMember(MemberExpression memberExpression)
{
- Check.NotNull(memberExpression, nameof(memberExpression));
-
_containsClosure = memberExpression.Expression != null
|| !(memberExpression.Member is FieldInfo fieldInfo && fieldInfo.IsInitOnly);
-
return base.VisitMember(memberExpression);
}
protected override Expression VisitParameter(ParameterExpression parameterExpression)
{
- Check.NotNull(parameterExpression, nameof(parameterExpression));
-
_evaluatable = _allowedParameters.Contains(parameterExpression);
return base.VisitParameter(parameterExpression);
@@ -627,8 +611,6 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres
protected override Expression VisitConstant(ConstantExpression constantExpression)
{
- Check.NotNull(constantExpression, nameof(constantExpression));
-
_evaluatable = !(constantExpression.Value is IDetachableContext)
&& !(constantExpression.Value is IQueryable);
diff --git a/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs b/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs
index 0ec15437e55..34ff2067b4a 100644
--- a/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs
+++ b/src/EFCore/Query/Internal/QueryTranslationPreprocessorFactory.cs
@@ -20,17 +20,20 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal
public class QueryTranslationPreprocessorFactory : IQueryTranslationPreprocessorFactory
{
private readonly QueryTranslationPreprocessorDependencies _dependencies;
+ private readonly INavigationExpandingExpressionVisitorFactory _navigationExpandingExpressionVisitorFactory;
- public QueryTranslationPreprocessorFactory([NotNull] QueryTranslationPreprocessorDependencies dependencies)
+ public QueryTranslationPreprocessorFactory([NotNull] QueryTranslationPreprocessorDependencies dependencies,
+ [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory)
{
_dependencies = dependencies;
+ _navigationExpandingExpressionVisitorFactory = navigationExpandingExpressionVisitorFactory;
}
public virtual QueryTranslationPreprocessor Create(QueryCompilationContext queryCompilationContext)
{
Check.NotNull(queryCompilationContext, nameof(queryCompilationContext));
- return new QueryTranslationPreprocessor(_dependencies, queryCompilationContext);
+ return new QueryTranslationPreprocessor(_dependencies, queryCompilationContext, _navigationExpandingExpressionVisitorFactory);
}
}
}
diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs
index fc284dc4d24..f827da997f9 100644
--- a/src/EFCore/Query/QueryTranslationPreprocessor.cs
+++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs
@@ -11,16 +11,19 @@ namespace Microsoft.EntityFrameworkCore.Query
public class QueryTranslationPreprocessor
{
private readonly QueryCompilationContext _queryCompilationContext;
+ private readonly INavigationExpandingExpressionVisitorFactory _navigationExpandingExpressionVisitorFactory;
public QueryTranslationPreprocessor(
[NotNull] QueryTranslationPreprocessorDependencies dependencies,
- [NotNull] QueryCompilationContext queryCompilationContext)
+ [NotNull] QueryCompilationContext queryCompilationContext,
+ [NotNull] INavigationExpandingExpressionVisitorFactory navigationExpandingExpressionVisitorFactory)
{
Check.NotNull(dependencies, nameof(dependencies));
Check.NotNull(queryCompilationContext, nameof(queryCompilationContext));
Dependencies = dependencies;
_queryCompilationContext = queryCompilationContext;
+ _navigationExpandingExpressionVisitorFactory = navigationExpandingExpressionVisitorFactory;
}
protected virtual QueryTranslationPreprocessorDependencies Dependencies { get; }
@@ -37,8 +40,8 @@ public virtual Expression Process([NotNull] Expression query)
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query);
query = new SubqueryMemberPushdownExpressionVisitor(_queryCompilationContext.Model).Visit(query);
- query = new NavigationExpandingExpressionVisitor(_queryCompilationContext, Dependencies.EvaluatableExpressionFilter).Expand(
- query);
+ query = _navigationExpandingExpressionVisitorFactory.Create(_queryCompilationContext, Dependencies.EvaluatableExpressionFilter)
+ .Expand(query);
query = new FunctionPreprocessingExpressionVisitor().Visit(query);
return query;
diff --git a/src/EFCore/Query/ShapedQueryExpression.cs b/src/EFCore/Query/ShapedQueryExpression.cs
index c851a881f48..ccca2d2a120 100644
--- a/src/EFCore/Query/ShapedQueryExpression.cs
+++ b/src/EFCore/Query/ShapedQueryExpression.cs
@@ -33,7 +33,7 @@ private ShapedQueryExpression(
public virtual Expression ShaperExpression { get; }
public override Type Type => ResultCardinality == ResultCardinality.Enumerable
- ? typeof(IQueryable<>).MakeGenericType(ShaperExpression.Type)
+ ? typeof(IQueryable<>).MakeGenericType(ShaperExpression.Type)
: ShaperExpression.Type;
public sealed override ExpressionType NodeType => ExpressionType.Extension;
diff --git a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs
index 4e089347708..d0d17d5dad7 100644
--- a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs
+++ b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs
@@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Linq.Expressions;
+using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.TestUtilities;
@@ -28,14 +30,46 @@ public class Customer
public string FirstName { get; set; }
public string LastName { get; set; }
public List Orders { get; set; }
+ public List Addresses { get; set; }
}
public class Order
{
public int Id { get; set; }
public string Name { get; set; }
- public int ItemCount { get; set; }
public DateTime OrderDate { get; set; }
+
+ public int CustomerId { get; set; }
+
+ public Customer Customer { get; set; }
+ public List Items { get; set; }
+ }
+
+ public class LineItem
+ {
+ public int Id { get; set; }
+ public int OrderId { get; set; }
+ public int ProductId { get; set; }
+ public int Quantity { get; set; }
+
+ public Order Order { get; set; }
+ public Product Product { get; set; }
+ }
+
+ public class Product
+ {
+ public int Id { get; set; }
+ public string Name { get; set; }
+ }
+
+ public class Address
+ {
+ public int Id { get; set; }
+ public string Street { get; set; }
+ public string City { get; set; }
+ public string State { get; set; }
+
+ public int CustomerId { get; set; }
public Customer Customer { get; set; }
}
@@ -45,6 +79,8 @@ protected class UDFSqlContext : PoolableDbContext
public DbSet Customers { get; set; }
public DbSet Orders { get; set; }
+ public DbSet Products { get; set; }
+ public DbSet Addresses { get; set; }
#endregion
@@ -103,6 +139,77 @@ public enum ReportingPeriod
[DbFunction(Schema = "dbo")]
public static string IdentityString(string s) => throw new Exception();
+ public int AddValues(int a, int b)
+ {
+ throw new NotImplementedException();
+ }
+
+ public int AddValues(Expression> a, int b)
+ {
+ throw new NotImplementedException();
+ }
+
+ #region Querable Functions
+
+ public class OrderByYear
+ {
+ public int? CustomerId { get; set; }
+ public int? Count { get; set; }
+ public int? Year { get; set; }
+ }
+
+ public class MultProductOrders
+ {
+ //public Order Order { get; set; }
+ public int OrderId { get; set; }
+
+ public Customer Customer { get; set; }
+ public int CustomerId { get; set; }
+
+ public DateTime OrderDate { get; set; }
+ }
+
+ public IQueryable GetCustomerOrderCountByYear(int customerId)
+ {
+ return CreateQuery(() => GetCustomerOrderCountByYear(customerId));
+ }
+
+ public IQueryable GetCustomerOrderCountByYear(Expression> customerId2)
+ {
+ return CreateQuery(() => GetCustomerOrderCountByYear(customerId2));
+ }
+
+ public class TopSellingProduct
+ {
+ public Product Product { get; set; }
+ public int? ProductId { get; set; }
+
+ public int? AmountSold { get; set; }
+ }
+
+ public IQueryable GetTopTwoSellingProducts()
+ {
+ return CreateQuery(() => GetTopTwoSellingProducts());
+ }
+
+ public IQueryable GetTopSellingProductsForCustomer(int customerId)
+ {
+ return CreateQuery(() => GetTopSellingProductsForCustomer(customerId));
+ }
+
+ public IQueryable GetTopTwoSellingProductsCustomTranslation()
+ {
+ return CreateQuery(() => GetTopTwoSellingProductsCustomTranslation());
+ }
+
+ public IQueryable GetOrdersWithMultipleProducts(int customerId)
+ {
+ return CreateQuery(() => GetOrdersWithMultipleProducts(customerId));
+ }
+
+
+ #endregion
+
#endregion
public UDFSqlContext(DbContextOptions options)
@@ -133,6 +240,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
modelBuilder.HasDbFunction(methodInfo)
.HasTranslation(args => SqlFunctionExpression.Create("len", args, methodInfo.ReturnType, null));
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(AddValues), new[] { typeof(int), typeof(int) }));
+
//Instance
modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountInstance)))
.HasName("CustomerOrderCount");
@@ -154,6 +263,23 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
modelBuilder.HasDbFunction(methodInfo2)
.HasTranslation(args => SqlFunctionExpression.Create("len", args, methodInfo2.ReturnType, null));
+
+ modelBuilder.Entity().ToQueryableFunctionResultType().HasKey(mpo => mpo.OrderId);
+ //modelBuilder.Entity().HasOne(mpo => mpo.Order).WithOne().HasForeignKey(o => o.Id);
+
+ modelBuilder.Entity().ToQueryableFunctionResultType().HasNoKey();
+ modelBuilder.Entity().ToQueryableFunctionResultType().HasNoKey();
+
+ //Table
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerOrderCountByYear), new[] { typeof(int) }));
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerOrderCountByYear), new[] { typeof(Expression>) }));
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopTwoSellingProducts)));
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopSellingProductsForCustomer)));
+
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetOrdersWithMultipleProducts)));
+
+ modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopTwoSellingProductsCustomTranslation)))
+ .HasTranslation(args => SqlFunctionExpression.Create("dbo", "GetTopTwoSellingProducts", args, typeof(TopSellingProduct), null));
}
}
@@ -170,68 +296,105 @@ protected override void Seed(DbContext context)
{
context.Database.EnsureCreatedResiliently();
+ var product1 = new Product { Name = "Product1" };
+ var product2 = new Product { Name = "Product2" };
+ var product3 = new Product { Name = "Product3" };
+ var product4 = new Product { Name = "Product4" };
+ var product5 = new Product { Name = "Product5" };
+
var order11 = new Order
{
- Name = "Order11",
- ItemCount = 4,
- OrderDate = new DateTime(2000, 1, 20)
+ Name = "Order11", OrderDate = new DateTime(2000, 1, 20),
+ Items = new List
+ {
+ new LineItem { Quantity = 5, Product = product1},
+ new LineItem { Quantity = 15, Product = product3}
+ }
};
- var order12 = new Order
- {
- Name = "Order12",
- ItemCount = 8,
- OrderDate = new DateTime(2000, 2, 21)
+
+ var order12 = new Order { Name = "Order12", OrderDate = new DateTime(2000, 2, 21),
+ Items = new List
+ {
+ new LineItem { Quantity = 1, Product = product1},
+ new LineItem { Quantity = 6, Product = product2},
+ new LineItem { Quantity = 200, Product = product3}
+ }
};
- var order13 = new Order
- {
- Name = "Order13",
- ItemCount = 15,
- OrderDate = new DateTime(2000, 3, 20)
+
+ var order13 = new Order { Name = "Order13", OrderDate = new DateTime(2001, 3, 20),
+ Items = new List
+ {
+ new LineItem { Quantity = 50, Product = product4},
+ }
};
- var order21 = new Order
- {
- Name = "Order21",
- ItemCount = 16,
- OrderDate = new DateTime(2000, 4, 21)
+
+ var order21 = new Order { Name = "Order21", OrderDate = new DateTime(2000, 4, 21),
+ Items = new List
+ {
+ new LineItem { Quantity = 1, Product = product1},
+ new LineItem { Quantity = 34, Product = product4},
+ new LineItem { Quantity = 100, Product = product5}
+ }
};
- var order22 = new Order
- {
- Name = "Order22",
- ItemCount = 23,
- OrderDate = new DateTime(2000, 5, 20)
+
+ var order22 = new Order { Name = "Order22", OrderDate = new DateTime(2000, 5, 20),
+ Items = new List
+ {
+ new LineItem { Quantity = 34, Product = product3},
+ new LineItem { Quantity = 100, Product = product4}
+ }
};
- var order31 = new Order
- {
- Name = "Order31",
- ItemCount = 42,
- OrderDate = new DateTime(2000, 6, 21)
+
+ var order31 = new Order { Name = "Order31", OrderDate = new DateTime(2001, 6, 21),
+ Items = new List
+ {
+ new LineItem { Quantity = 5, Product = product5}
+ }
};
+ var address11 = new Address { Street = "1600 Pennsylvania Avenue", City = "Washington", State = "DC" };
+ var address12 = new Address { Street = "742 Evergreen Terrace", City = "SpringField", State = "" };
+ var address21 = new Address { Street = "Apartment 5A, 129 West 81st Street", City = "New York", State = "NY" };
+ var address31 = new Address { Street = "425 Grove Street, Apt 20", City = "New York", State = "NY" };
+ var address32 = new Address { Street = "342 GravelPit Terrace", City = "BedRock", State = "" };
+ var address41 = new Address { Street = "4222 Clinton Way", City = "Los Angles", State = "CA" };
+ var address42 = new Address { Street = "1060 West Addison Street", City = "Chicago", State = "IL" };
+ var address43 = new Address { Street = "112 ½ Beacon Street", City = "Boston", State = "MA" };
+
var customer1 = new Customer
{
FirstName = "Customer",
LastName = "One",
- Orders = new List
- {
- order11,
- order12,
- order13
- }
+ Orders = new List { order11, order12, order13 },
+ Addresses = new List { address11, address12 }
};
+
var customer2 = new Customer
{
FirstName = "Customer",
LastName = "Two",
- Orders = new List { order21, order22 }
+ Orders = new List { order21, order22 },
+ Addresses = new List { address21 }
};
+
var customer3 = new Customer
{
FirstName = "Customer",
LastName = "Three",
- Orders = new List { order31 }
+ Orders = new List { order31 },
+ Addresses = new List { address31, address32 }
+ };
+
+ var customer4 = new Customer
+ {
+ FirstName = "Customer",
+ LastName = "Four",
+ Addresses = new List { address41, address42, address43 }
};
- ((UDFSqlContext)context).Customers.AddRange(customer1, customer2, customer3);
+ ((UDFSqlContext)context).Products.AddRange(product1, product2, product3, product4, product5);
+ ((UDFSqlContext)context).Addresses.AddRange(address11, address12, address21, address31, address32, address41, address42, address43);
+ ((UDFSqlContext)context).Customers.AddRange(customer1, customer2, customer3, customer4);
((UDFSqlContext)context).Orders.AddRange(order11, order12, order13, order21, order22, order31);
}
}
@@ -249,7 +412,7 @@ public virtual void Scalar_Function_Extension_Method_Static()
var len = context.Customers.Count(c => UDFSqlContext.IsDateStatic(c.FirstName) == false);
- Assert.Equal(3, len);
+ Assert.Equal(4, len);
}
[ConditionalFact]
@@ -286,7 +449,7 @@ public virtual void Scalar_Function_Constant_Parameter_Static()
var custs = context.Customers.Select(c => UDFSqlContext.CustomerOrderCountStatic(customerId)).ToList();
- Assert.Equal(3, custs.Count);
+ Assert.Equal(4, custs.Count);
}
[ConditionalFact]
@@ -494,8 +657,8 @@ public virtual void Scalar_Nested_Function_Unwind_Client_Eval_Select_Static()
orderby c.Id
select UDFSqlContext.AddOneStatic(c.Id)).ToList();
- Assert.Equal(3, results.Count);
- Assert.True(results.SequenceEqual(Enumerable.Range(2, 3)));
+ Assert.Equal(4, results.Count);
+ Assert.True(results.SequenceEqual(Enumerable.Range(2, 4)));
}
[ConditionalFact]
@@ -679,7 +842,7 @@ public virtual void Scalar_Function_Extension_Method_Instance()
var len = context.Customers.Count(c => context.IsDateInstance(c.FirstName) == false);
- Assert.Equal(3, len);
+ Assert.Equal(4, len);
}
[ConditionalFact]
@@ -714,7 +877,7 @@ public virtual void Scalar_Function_Constant_Parameter_Instance()
var custs = context.Customers.Select(c => context.CustomerOrderCountInstance(customerId)).ToList();
- Assert.Equal(3, custs.Count);
+ Assert.Equal(4, custs.Count);
}
[ConditionalFact]
@@ -921,8 +1084,8 @@ public virtual void Scalar_Nested_Function_Unwind_Client_Eval_Select_Instance()
orderby c.Id
select context.AddOneInstance(c.Id)).ToList();
- Assert.Equal(3, results.Count);
- Assert.True(results.SequenceEqual(Enumerable.Range(2, 3)));
+ Assert.Equal(4, results.Count);
+ Assert.True(results.SequenceEqual(Enumerable.Range(2, 4)));
}
[ConditionalFact]
@@ -1070,9 +1233,706 @@ public virtual void Scalar_Nested_Function_UDF_BCL_Instance()
#endregion
- private void AssertTranslationFailed(Action testCode)
- => Assert.Contains(
- CoreStrings.TranslationFailed("").Substring(21),
- Assert.Throws(testCode).Message);
+ #region QueryableFunction
+
+ [Fact]
+ public virtual void QF_Anonymous_Collection_No_PK_Throws()
+ {
+ using (var context = CreateContext())
+ {
+ var query = from c in context.Customers
+ select new { c.Id, products = context.GetTopSellingProductsForCustomer(c.Id).ToList() };
+
+ Assert.Contains(
+ RelationalStrings.DbFunctionProjectedCollectionMustHavePK("GetTopSellingProductsForCustomer"),
+ Assert.Throws(() => query.ToList()).Message);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Anonymous_Collection_No_IQueryable_In_Projection_Throws()
+ {
+ using (var context = CreateContext())
+ {
+ var query = (from c in context.Customers
+ select new { c.Id, orders = context.GetCustomerOrderCountByYear(c.Id) });
+
+ Assert.Contains(
+ RelationalStrings.DbFunctionCantProjectIQueryable(),
+ Assert.Throws(() => query.ToList()).Message);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Stand_Alone()
+ {
+ using (var context = CreateContext())
+ {
+ var products = (from t in context.GetTopTwoSellingProducts()
+ orderby t.ProductId
+ select t).ToList();
+
+ Assert.Equal(2, products.Count);
+ Assert.Equal(3, products[0].ProductId);
+ Assert.Equal(249, products[0].AmountSold);
+ Assert.Equal(4, products[1].ProductId);
+ Assert.Equal(184, products[1].AmountSold);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Stand_Alone_With_Translation()
+ {
+ using (var context = CreateContext())
+ {
+ var products = (from t in context.GetTopTwoSellingProductsCustomTranslation()
+ orderby t.ProductId
+ select t).ToList();
+
+ Assert.Equal(2, products.Count);
+ Assert.Equal(3, products[0].ProductId);
+ Assert.Equal(249, products[0].AmountSold);
+ Assert.Equal(4, products[1].ProductId);
+ Assert.Equal(184, products[1].AmountSold);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Stand_Alone_Parameter()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.GetCustomerOrderCountByYear(1)
+ orderby c.Count descending
+ select c).ToList();
+
+ Assert.Equal(2, orders.Count);
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(2000, orders[0].Year);
+ Assert.Equal(1, orders[1].Count);
+ Assert.Equal(2001, orders[1].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Stand_Alone_Nested()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from r in context.GetCustomerOrderCountByYear(() => context.AddValues(-2, 3))
+ orderby r.Count descending
+ select r).ToList();
+
+ Assert.Equal(2, orders.Count);
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(2000, orders[0].Year);
+ Assert.Equal(1, orders[1].Count);
+ Assert.Equal(2001, orders[1].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_CrossApply_Correlated_Select_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(c.Id)
+ orderby c.Id, r.Year
+ select new
+ {
+ c.Id,
+ c.LastName,
+ r.Year,
+ r.Count
+ }).ToList();
+
+ Assert.Equal(4, orders.Count);
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(1, orders[1].Count);
+ Assert.Equal(2, orders[2].Count);
+ Assert.Equal(1, orders[3].Count);
+ Assert.Equal(2000, orders[0].Year);
+ Assert.Equal(2001, orders[1].Year);
+ Assert.Equal(2000, orders[2].Year);
+ Assert.Equal(2001, orders[3].Year);
+ Assert.Equal(1, orders[0].Id);
+ Assert.Equal(1, orders[1].Id);
+ Assert.Equal(2, orders[2].Id);
+ Assert.Equal(3, orders[3].Id);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_Direct_In_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var results = (from c in context.Customers
+ select new
+ {
+ c.Id,
+ Prods = context.GetTopTwoSellingProducts().ToList(),
+ }).ToList();
+
+ Assert.Equal(4, results.Count);
+ Assert.Equal(2, results[0].Prods.Count);
+ Assert.Equal(2, results[1].Prods.Count);
+ Assert.Equal(2, results[2].Prods.Count);
+ Assert.Equal(2, results[3].Prods.Count);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_Correlated_Direct_With_Function_Query_Parameter_Correlated_In_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var cust = (from c in context.Customers
+ where c.Id == 1
+ select new
+ {
+ c.Id,
+ Orders = context.GetOrdersWithMultipleProducts(context.AddValues(c.Id, 1)).ToList()
+ }).ToList();
+
+ Assert.Single(cust);
+
+ Assert.Equal(1, cust[0].Id);
+ Assert.Equal(4, cust[0].Orders[0].OrderId);
+ Assert.Equal(5, cust[0].Orders[1].OrderId);
+ Assert.Equal(new DateTime(2000, 4, 21), cust[0].Orders[0].OrderDate);
+ Assert.Equal(new DateTime(2000, 5, 20), cust[0].Orders[1].OrderDate);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_Correlated_Subquery_In_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var results = (from c in context.Customers
+ select new
+ {
+ c.Id,
+ OrderCountYear = context.GetOrdersWithMultipleProducts(c.Id).Where(o => o.OrderDate.Day == 21).ToList()
+ }).ToList();
+
+ Assert.Equal(4, results.Count);
+ Assert.Equal(1, results[0].Id);
+ Assert.Equal(2, results[1].Id);
+ Assert.Equal(3, results[2].Id);
+ Assert.Equal(4, results[3].Id);
+ Assert.Single(results[0].OrderCountYear);
+ Assert.Single(results[1].OrderCountYear);
+ Assert.Empty(results[2].OrderCountYear);
+ Assert.Empty(results[3].OrderCountYear);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_Correlated_Subquery_In_Anonymous_Nested_With_QF()
+ {
+ using (var context = CreateContext())
+ {
+ var results = (from o in context.Orders
+ join osub in (from c in context.Customers
+ from a in context.GetOrdersWithMultipleProducts(c.Id)
+ select a.OrderId
+ ) on o.Id equals osub
+ select new { o.CustomerId, o.OrderDate }).ToList();
+
+ Assert.Equal(4, results.Count);
+
+ Assert.Equal(1, results[0].CustomerId);
+ Assert.Equal(new DateTime(2000, 1, 20), results[0].OrderDate);
+
+ Assert.Equal(1, results[1].CustomerId);
+ Assert.Equal(new DateTime(2000, 2, 21), results[1].OrderDate);
+
+ Assert.Equal(2, results[2].CustomerId);
+ Assert.Equal(new DateTime(2000, 4, 21), results[2].OrderDate);
+
+ Assert.Equal(2, results[3].CustomerId);
+ Assert.Equal(new DateTime(2000, 5, 20), results[3].OrderDate);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_Correlated_Subquery_In_Anonymous_Nested()
+ {
+ using (var context = CreateContext())
+ {
+ var results = (from c in context.Customers
+ select new
+ {
+ c.Id,
+ OrderCountYear = context.GetOrdersWithMultipleProducts(c.Id).Where(o => o.OrderDate.Day == 21).Select(o => new
+ {
+ OrderCountYearNested = context.GetOrdersWithMultipleProducts(o.CustomerId).ToList(),
+ Prods = context.GetTopTwoSellingProducts().ToList(),
+ }).ToList()
+ }).ToList();
+
+ Assert.Equal(4, results.Count);
+
+ Assert.Single(results[0].OrderCountYear);
+ Assert.Equal(2, results[0].OrderCountYear[0].Prods.Count);
+ Assert.Equal(2, results[0].OrderCountYear[0].OrderCountYearNested.Count);
+
+ Assert.Single(results[1].OrderCountYear);
+ Assert.Equal(2, results[1].OrderCountYear[0].Prods.Count);
+ Assert.Equal(2, results[1].OrderCountYear[0].OrderCountYearNested.Count);
+
+ Assert.Empty(results[2].OrderCountYear);
+
+ Assert.Empty(results[3].OrderCountYear);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_Correlated_Subquery_In_Anonymous_MultipleCollections()
+ {
+ using (var context = CreateContext())
+ {
+ var results = (from c in context.Customers
+ select new
+ {
+ c.Id,
+ Prods = context.GetTopTwoSellingProducts().Where(p => p.AmountSold == 249).Select(p => p.ProductId).ToList(),
+ Addresses = c.Addresses.Where(a => a.State == "NY").ToList()
+ }).ToList();
+
+ Assert.Equal(4, results.Count);
+ Assert.Equal(3, results[0].Prods[0]);
+ Assert.Equal(3, results[1].Prods[0]);
+ Assert.Equal(3, results[2].Prods[0]);
+ Assert.Equal(3, results[3].Prods[0]);
+
+ Assert.Empty(results[0].Addresses);
+ Assert.Equal("Apartment 5A, 129 West 81st Street", results[1].Addresses[0].Street);
+ Assert.Equal("425 Grove Street, Apt 20", results[2].Addresses[0].Street);
+ Assert.Empty(results[3].Addresses);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_NonCorrelated_Subquery_In_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var results = (from c in context.Customers
+ select new
+ {
+ c.Id,
+ Prods = context.GetTopTwoSellingProducts().Where(p => p.AmountSold == 249).Select(p => p.ProductId).ToList(),
+ }).ToList();
+
+ Assert.Equal(4, results.Count);
+ Assert.Equal(3, results[0].Prods[0]);
+ Assert.Equal(3, results[1].Prods[0]);
+ Assert.Equal(3, results[2].Prods[0]);
+ Assert.Equal(3, results[3].Prods[0]);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter()
+ {
+ using (var context = CreateContext())
+ {
+ var amount = 27;
+
+ var results = (from c in context.Customers
+ select new
+ {
+ c.Id,
+ Prods = context.GetTopTwoSellingProducts().Where(p => p.AmountSold == amount).Select(p => p.ProductId).ToList(),
+ }).ToList();
+
+ Assert.Equal(4, results.Count);
+ Assert.Single(results[0].Prods);
+ Assert.Single(results[1].Prods);
+ Assert.Single(results[2].Prods);
+ Assert.Single(results[3].Prods);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Correlated_Select_In_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var cust = (from c in context.Customers
+ orderby c.Id
+ select new
+ {
+ c.Id,
+ c.LastName,
+ Orders = context.GetOrdersWithMultipleProducts(c.Id).ToList()
+ }).ToList();
+
+ Assert.Equal(4, cust.Count);
+
+ Assert.Equal(1, cust[0].Id);
+ Assert.Equal(2, cust[0].Orders.Count);
+ Assert.Equal(1, cust[0].Orders[0].OrderId);
+ Assert.Equal(2, cust[0].Orders[1].OrderId);
+ Assert.Equal(new DateTime(2000, 1, 20), cust[0].Orders[0].OrderDate);
+ Assert.Equal(new DateTime(2000, 2, 21), cust[0].Orders[1].OrderDate);
+
+ Assert.Equal(2, cust[1].Id);
+ Assert.Equal(2, cust[1].Orders.Count);
+ Assert.Equal(4, cust[1].Orders[0].OrderId);
+ Assert.Equal(5, cust[1].Orders[1].OrderId);
+ Assert.Equal(new DateTime(2000, 4, 21), cust[1].Orders[0].OrderDate);
+ Assert.Equal(new DateTime(2000, 5, 20), cust[1].Orders[1].OrderDate);
+
+ Assert.Equal(3, cust[2].Id);
+ Assert.Empty(cust[2].Orders);
+
+ Assert.Equal(4, cust[3].Id);
+ Assert.Empty(cust[3].Orders);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_CrossApply_Correlated_Select_Result()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(c.Id)
+ orderby r.Count descending, r.Year descending
+ select r).ToList();
+
+ Assert.Equal(4, orders.Count);
+
+ Assert.Equal(4, orders.Count);
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(2, orders[1].Count);
+ Assert.Equal(1, orders[2].Count);
+ Assert.Equal(1, orders[3].Count);
+ Assert.Equal(2000, orders[0].Year);
+ Assert.Equal(2000, orders[1].Year);
+ Assert.Equal(2001, orders[2].Year);
+ Assert.Equal(2001, orders[3].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_CrossJoin_Not_Correlated()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(2)
+ where c.Id == 2
+ orderby r.Count
+ select new
+ {
+ c.Id,
+ c.LastName,
+ r.Year,
+ r.Count
+ }).ToList();
+
+ Assert.Single(orders);
+
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(2000, orders[0].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_CrossJoin_Parameter()
+ {
+ using (var context = CreateContext())
+ {
+ var custId = 2;
+
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(custId)
+ where c.Id == custId
+ orderby r.Count
+ select new
+ {
+ c.Id,
+ c.LastName,
+ r.Year,
+ r.Count
+ }).ToList();
+
+ Assert.Single(orders);
+
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(2000, orders[0].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Join()
+ {
+ using (var context = CreateContext())
+ {
+ var products = (from p in context.Products
+ join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId
+ select new
+ {
+ p.Id,
+ p.Name,
+ r.AmountSold
+ }).OrderBy(p => p.Id).ToList();
+
+ Assert.Equal(2, products.Count);
+ Assert.Equal(3, products[0].Id);
+ Assert.Equal("Product3", products[0].Name);
+ Assert.Equal(249, products[0].AmountSold);
+ Assert.Equal(4, products[1].Id);
+ Assert.Equal("Product4", products[1].Name);
+ Assert.Equal(184, products[1].AmountSold);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_LeftJoin_Select_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var products = (from p in context.Products
+ join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId into joinTable
+ from j in joinTable.DefaultIfEmpty()
+ orderby p.Id descending
+ select new
+ {
+ p.Id,
+ p.Name,
+ j.AmountSold
+ }).ToList();
+
+ Assert.Equal(5, products.Count);
+ Assert.Equal(5, products[0].Id);
+ Assert.Equal("Product5", products[0].Name);
+ Assert.Null(products[0].AmountSold);
+
+ Assert.Equal(4, products[1].Id);
+ Assert.Equal("Product4", products[1].Name);
+ Assert.Equal(184, products[1].AmountSold);
+
+ Assert.Equal(3, products[2].Id);
+ Assert.Equal("Product3", products[2].Name);
+ Assert.Equal(249, products[2].AmountSold);
+
+ Assert.Equal(2, products[3].Id);
+ Assert.Equal("Product2", products[3].Name);
+ Assert.Null(products[3].AmountSold);
+
+ Assert.Equal(1, products[4].Id);
+ Assert.Equal("Product1", products[4].Name);
+ Assert.Null(products[4].AmountSold);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_LeftJoin_Select_Result()
+ {
+ using (var context = CreateContext())
+ {
+ var products = (from p in context.Products
+ join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId into joinTable
+ from j in joinTable.DefaultIfEmpty()
+ orderby p.Id descending
+ select j).ToList();
+
+ Assert.Equal(5, products.Count);
+ Assert.Null(products[0]);
+ Assert.Equal(4, products[1].ProductId);
+ Assert.Equal(184, products[1].AmountSold);
+ Assert.Equal(3, products[2].ProductId);
+ Assert.Equal(249, products[2].AmountSold);
+ Assert.Null(products[3]);
+ Assert.Null(products[4]);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_OuterApply_Correlated_Select_TVF()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty()
+ orderby c.Id, r.Year
+ select r).ToList();
+
+ Assert.Equal(5, orders.Count);
+
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(1, orders[1].Count);
+ Assert.Equal(2, orders[2].Count);
+ Assert.Equal(1, orders[3].Count);
+ Assert.Null(orders[4]);
+ Assert.Equal(2000, orders[0].Year);
+ Assert.Equal(2001, orders[1].Year);
+ Assert.Equal(2000, orders[2].Year);
+ Assert.Equal(2001, orders[3].Year);
+ Assert.Null(orders[4]);
+ Assert.Equal(1, orders[0].CustomerId);
+ Assert.Equal(1, orders[1].CustomerId);
+ Assert.Equal(2, orders[2].CustomerId);
+ Assert.Equal(3, orders[3].CustomerId);
+ Assert.Null(orders[4]);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_OuterApply_Correlated_Select_DbSet()
+ {
+ using (var context = CreateContext())
+ {
+ var custs = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty()
+ orderby c.Id, r.Year
+ select c).ToList();
+
+ Assert.Equal(5, custs.Count);
+
+ Assert.Equal(1, custs[0].Id);
+ Assert.Equal(1, custs[1].Id);
+ Assert.Equal(2, custs[2].Id);
+ Assert.Equal(3, custs[3].Id);
+ Assert.Equal(4, custs[4].Id);
+ Assert.Equal("One", custs[0].LastName);
+ Assert.Equal("One", custs[1].LastName);
+ Assert.Equal("Two", custs[2].LastName);
+ Assert.Equal("Three", custs[3].LastName);
+ Assert.Equal("Four", custs[4].LastName);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_OuterApply_Correlated_Select_Anonymous()
+ {
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty()
+ orderby c.Id, r.Year
+ select new
+ {
+ c.Id,
+ c.LastName,
+ r.Year,
+ r.Count
+ }).ToList();
+
+ Assert.Equal(5, orders.Count);
+
+ Assert.Equal(1, orders[0].Id);
+ Assert.Equal(1, orders[1].Id);
+ Assert.Equal(2, orders[2].Id);
+ Assert.Equal(3, orders[3].Id);
+ Assert.Equal(4, orders[4].Id);
+ Assert.Equal("One", orders[0].LastName);
+ Assert.Equal("One", orders[1].LastName);
+ Assert.Equal("Two", orders[2].LastName);
+ Assert.Equal("Three", orders[3].LastName);
+ Assert.Equal("Four", orders[4].LastName);
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(1, orders[1].Count);
+ Assert.Equal(2, orders[2].Count);
+ Assert.Equal(1, orders[3].Count);
+ Assert.Null(orders[4].Count);
+ Assert.Equal(2000, orders[0].Year);
+ Assert.Equal(2001, orders[1].Year);
+ Assert.Equal(2000, orders[2].Year);
+ Assert.Equal(2001, orders[3].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Nested()
+ {
+ using (var context = CreateContext())
+ {
+ var custId = 2;
+
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(context.AddValues(1, 1))
+ where c.Id == custId
+ orderby r.Year
+ select new
+ {
+ c.Id,
+ c.LastName,
+ r.Year,
+ r.Count
+ }).ToList();
+
+ Assert.Single(orders);
+
+ Assert.Equal(2, orders[0].Count);
+ Assert.Equal(2000, orders[0].Year);
+ }
+ }
+
+
+ [Fact]
+ public virtual void QF_Correlated_Nested_Func_Call()
+ {
+ var custId = 2;
+
+ using (var context = CreateContext())
+ {
+ var orders = (from c in context.Customers
+ from r in context.GetCustomerOrderCountByYear(context.AddValues(c.Id, 1))
+ where c.Id == custId
+ select new
+ {
+ c.Id,
+ r.Count,
+ r.Year
+ }).ToList();
+
+ Assert.Single(orders);
+
+ Assert.Equal(1, orders[0].Count);
+ Assert.Equal(2001, orders[0].Year);
+ }
+ }
+
+ [Fact]
+ public virtual void QF_Correlated_Func_Call_With_Navigation()
+ {
+ using (var context = CreateContext())
+ {
+ var cust = (from c in context.Customers
+ orderby c.Id
+ select new
+ {
+ c.Id,
+ Orders = context.GetOrdersWithMultipleProducts(c.Id).Select(mpo => new
+ {
+ //how to I setup the PK/FK combo properly for this? Is it even possible?
+ //OrderName = mpo.Order.Name,
+ CustomerName = mpo.Customer.LastName
+ }).ToList()
+ }).ToList();
+
+ Assert.Equal(4, cust.Count);
+ Assert.Equal(2, cust[0].Orders.Count);
+ Assert.Equal("One", cust[0].Orders[0].CustomerName);
+ Assert.Equal(2, cust[1].Orders.Count);
+ Assert.Equal("Two", cust[1].Orders[0].CustomerName);
+ }
+ }
+
+ #endregion
+
+
+ private void AssertTranslationFailed(Action testCode)
+ => Assert.Contains(
+ CoreStrings.TranslationFailed("").Substring(21),
+ Assert.Throws(testCode).Message);
}
}
diff --git a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
index ed44eec80c2..1ab59196ab5 100644
--- a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
+++ b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
@@ -20,6 +21,12 @@ namespace Microsoft.EntityFrameworkCore.Metadata
{
public class DbFunctionMetadataTests
{
+ public class Foo
+ {
+ public int I { get; set; }
+ public int J { get; set; }
+ }
+
public class MyNonDbContext
{
public int NonStatic()
@@ -160,6 +167,21 @@ public static int DuplicateNameTest()
[DbFunction]
public override int VirtualBase() => throw new Exception();
+
+ [DbFunction]
+ public IQueryable QueryableNoParams() => throw new Exception();
+
+ [DbFunction]
+ public IQueryable QueryableSingleParam(int i) => throw new Exception();
+
+ public IQueryable QueryableSingleParam(Expression> i) => throw new Exception();
+
+ [DbFunction]
+ public IQueryable QueryableMultiParam(int i, double j) => throw new Exception();
+
+ public IQueryable QueryableMultiParam(Expression> i, double j) => throw new Exception();
+
+ public IQueryable QueryableMultiParam(Expression> i, Expression> j) => throw new Exception();
}
public static MethodInfo MethodAmi = typeof(TestMethods).GetRuntimeMethod(
@@ -173,6 +195,8 @@ public static int DuplicateNameTest()
public static MethodInfo MethodHmi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodH));
+ public static MethodInfo MethodJmi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodJ));
+
public class TestMethods
{
public static int Foo => 1;
@@ -211,6 +235,11 @@ public static int MethodI()
{
throw new Exception();
}
+
+ public static IQueryable MethodJ()
+ {
+ throw new Exception();
+ }
}
public static class OuterA
@@ -640,6 +669,37 @@ public void DbFunction_Annotation_FullName()
Assert.NotEqual(funcA.Metadata.Name, funcB.Metadata.Name);
}
+ [ConditionalFact]
+ public void Find_Queryable_Single_Expression_Overload()
+ {
+ var modelBuilder = GetModelBuilder();
+
+ var funcA = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableSingleParam), new Type[] { typeof(int) }));
+ var funcB = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableSingleParam), new Type[] { typeof(Expression>) }));
+
+ Assert.Equal("QueryableSingleParam", funcA.Metadata.Name);
+ Assert.Equal("QueryableSingleParam", funcB.Metadata.Name);
+ Assert.Equal(funcA.Metadata, funcB.Metadata);
+ }
+
+ [ConditionalFact]
+ public void Find_Queryable_Multiple_Expression_Overload()
+ {
+ var modelBuilder = GetModelBuilder();
+
+ var funcA = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableMultiParam), new Type[] { typeof(int), typeof(double) }));
+ var funcB = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableMultiParam), new Type[] { typeof(Expression>), typeof(double) }));
+ var funcC = modelBuilder.HasDbFunction(typeof(MyDerivedContext).GetMethod(nameof(MyDerivedContext.QueryableMultiParam), new Type[] { typeof(Expression>), typeof(Expression>) }));
+
+ Assert.Equal("QueryableMultiParam", funcA.Metadata.Name);
+ Assert.Equal("QueryableMultiParam", funcB.Metadata.Name);
+ Assert.Equal("QueryableMultiParam", funcC.Metadata.Name);
+
+ Assert.Equal(funcA.Metadata, funcB.Metadata);
+ Assert.Equal(funcA.Metadata, funcC.Metadata);
+ Assert.Equal(funcB.Metadata, funcC.Metadata);
+ }
+
private ModelBuilder GetModelBuilder(DbContext dbContext = null)
{
var conventionSet = new ConventionSet();
diff --git a/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs b/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs
index 89b06e8e5bf..cbcd44c75c4 100644
--- a/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs
+++ b/test/EFCore.Relational.Tests/Migrations/Internal/MigrationsModelDifferTest.cs
@@ -25,6 +25,15 @@ private class TestQueryType
public string Something { get; set; }
}
+ [ConditionalFact]
+ public void Model_differ_does_not_detect_queryable_function_result_type()
+ {
+ Execute(
+ _ => { },
+ modelBuilder => modelBuilder.Entity().ToQueryableFunctionResultType(),
+ result => Assert.Equal(0, result.Count));
+ }
+
[ConditionalFact]
public void Model_differ_does_not_detect_views()
{
diff --git a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs
index 63bfed90b3d..80838cc765c 100644
--- a/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs
+++ b/test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs
@@ -1996,19 +1996,19 @@ public virtual Task OrderBy_scalar_primitive(bool async)
public virtual Task SelectMany_mixed(bool async)
{
return AssertTranslationFailed(
- () => AssertQuery(
- async,
- ss => from e1 in ss.Set().OrderBy(e => e.EmployeeID).Take(2)
- from s in new[] { "a", "b" }
- from c in ss.Set().OrderBy(c => c.CustomerID).Take(2)
- select new
- {
- e1,
- s,
- c
- },
- e => (e.e1.EmployeeID, e.c.CustomerID),
- entryCount: 4));
+ () => AssertQuery(
+ async,
+ ss => from e1 in ss.Set().OrderBy(e => e.EmployeeID).Take(2)
+ from s in new[] { "a", "b" }
+ from c in ss.Set().OrderBy(c => c.CustomerID).Take(2)
+ select new
+ {
+ e1,
+ s,
+ c
+ },
+ e => (e.e1.EmployeeID, e.c.CustomerID),
+ entryCount: 4));
}
[ConditionalTheory]
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs
index 1533e0164e0..e8bf90bd222 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs
@@ -215,7 +215,7 @@ public override void Nullable_navigation_property_access_preserves_schema_for_sq
AssertSql(
@"SELECT TOP(1) [dbo].[IdentityString]([c].[FirstName])
FROM [Orders] AS [o]
-LEFT JOIN [Customers] AS [c] ON [o].[CustomerId] = [c].[Id]
+INNER JOIN [Customers] AS [c] ON [o].[CustomerId] = [c].[Id]
ORDER BY [o].[Id]");
}
@@ -444,6 +444,322 @@ FROM [Customers] AS [c]
#endregion
+ #region Queryable Function Tests
+
+ public override void QF_Stand_Alone()
+ {
+ base.QF_Stand_Alone();
+
+ AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId]
+FROM [dbo].[GetTopTwoSellingProducts]() AS [t]
+ORDER BY [t].[ProductId]");
+ }
+
+ public override void QF_Stand_Alone_With_Translation()
+ {
+ base.QF_Stand_Alone_With_Translation();
+
+ AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId]
+FROM [dbo].[GetTopTwoSellingProducts]() AS [t]
+ORDER BY [t].[ProductId]");
+ }
+
+ public override void QF_Stand_Alone_Parameter()
+ {
+ base.QF_Stand_Alone_Parameter();
+
+ AssertSql(@"@__customerId_0='1'
+
+SELECT [o].[Count], [o].[CustomerId], [o].[Year]
+FROM [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o]
+ORDER BY [o].[Count] DESC");
+ }
+
+ public override void QF_Stand_Alone_Nested()
+ {
+ base.QF_Stand_Alone_Nested();
+
+ AssertSql(@"SELECT [o].[Count], [o].[CustomerId], [o].[Year]
+FROM [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues](-2, 3)) AS [o]
+ORDER BY [o].[Count] DESC");
+ }
+
+ public override void QF_CrossApply_Correlated_Select_Anonymous()
+ {
+ base.QF_CrossApply_Correlated_Select_Anonymous();
+
+ AssertSql(@"SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count]
+FROM [Customers] AS [c]
+CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o]
+ORDER BY [c].[Id], [o].[Year]");
+ }
+
+
+ public override void QF_Select_Direct_In_Anonymous()
+ {
+ base.QF_Select_Direct_In_Anonymous();
+
+ AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId]
+FROM [dbo].[GetTopTwoSellingProducts]() AS [t]",
+
+@"SELECT [c].[Id]
+FROM [Customers] AS [c]");
+ }
+
+ public override void QF_Select_Correlated_Direct_With_Function_Query_Parameter_Correlated_In_Anonymous()
+ {
+ base.QF_Select_Correlated_Direct_With_Function_Query_Parameter_Correlated_In_Anonymous();
+
+ AssertSql(@"SELECT [c].[Id], [m].[OrderId], [m].[CustomerId], [m].[OrderDate]
+FROM [Customers] AS [c]
+OUTER APPLY [dbo].[GetOrdersWithMultipleProducts]([dbo].[AddValues]([c].[Id], 1)) AS [m]
+WHERE [c].[Id] = 1
+ORDER BY [c].[Id], [m].[OrderId]");
+ }
+
+ public override void QF_Select_Correlated_Subquery_In_Anonymous()
+ {
+ base.QF_Select_Correlated_Subquery_In_Anonymous();
+
+ AssertSql(@"SELECT [c].[Id], [t].[OrderId], [t].[CustomerId], [t].[OrderDate]
+FROM [Customers] AS [c]
+OUTER APPLY (
+ SELECT [m].[OrderId], [m].[CustomerId], [m].[OrderDate]
+ FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m]
+ WHERE DATEPART(day, [m].[OrderDate]) = 21
+) AS [t]
+ORDER BY [c].[Id], [t].[OrderId]");
+ }
+
+ public override void QF_Select_Correlated_Subquery_In_Anonymous_Nested_With_QF()
+ {
+ base.QF_Select_Correlated_Subquery_In_Anonymous_Nested_With_QF();
+
+ AssertSql(@"SELECT [o].[CustomerId], [o].[OrderDate]
+FROM [Orders] AS [o]
+INNER JOIN (
+ SELECT [c].[Id], [c].[FirstName], [c].[LastName], [m].[OrderId], [m].[CustomerId], [m].[OrderDate]
+ FROM [Customers] AS [c]
+ CROSS APPLY [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m]
+) AS [t] ON [o].[Id] = [t].[OrderId]");
+ }
+
+ public override void QF_Select_Correlated_Subquery_In_Anonymous_Nested()
+ {
+ base.QF_Select_Correlated_Subquery_In_Anonymous_Nested();
+
+ AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId]
+FROM [dbo].[GetTopTwoSellingProducts]() AS [t]",
+
+ @"SELECT [c].[Id], [t].[OrderId], [t].[OrderId0], [t].[CustomerId], [t].[OrderDate]
+FROM [Customers] AS [c]
+OUTER APPLY (
+ SELECT [m].[OrderId], [m0].[OrderId] AS [OrderId0], [m0].[CustomerId], [m0].[OrderDate]
+ FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m]
+ OUTER APPLY [dbo].[GetOrdersWithMultipleProducts]([m].[CustomerId]) AS [m0]
+ WHERE DATEPART(day, [m].[OrderDate]) = 21
+) AS [t]
+ORDER BY [c].[Id], [t].[OrderId], [t].[OrderId0]");
+ }
+
+ public override void QF_Select_Correlated_Subquery_In_Anonymous_MultipleCollections()
+ {
+ base.QF_Select_Correlated_Subquery_In_Anonymous_MultipleCollections();
+
+ AssertSql(@"SELECT [c].[Id], [t0].[ProductId], [t1].[Id], [t1].[City], [t1].[CustomerId], [t1].[State], [t1].[Street]
+FROM [Customers] AS [c]
+OUTER APPLY (
+ SELECT [t].[ProductId]
+ FROM [dbo].[GetTopTwoSellingProducts]() AS [t]
+ WHERE [t].[AmountSold] = 249
+) AS [t0]
+LEFT JOIN (
+ SELECT [a].[Id], [a].[City], [a].[CustomerId], [a].[State], [a].[Street]
+ FROM [Addresses] AS [a]
+ WHERE [a].[State] = N'NY'
+) AS [t1] ON [c].[Id] = [t1].[CustomerId]
+ORDER BY [c].[Id], [t1].[Id]");
+ }
+
+ public override void QF_Select_NonCorrelated_Subquery_In_Anonymous()
+ {
+ base.QF_Select_NonCorrelated_Subquery_In_Anonymous();
+
+ AssertSql(@"SELECT [c].[Id], [t0].[ProductId]
+FROM [Customers] AS [c]
+OUTER APPLY (
+ SELECT [t].[ProductId]
+ FROM [dbo].[GetTopTwoSellingProducts]() AS [t]
+ WHERE [t].[AmountSold] = 249
+) AS [t0]
+ORDER BY [c].[Id]");
+ }
+
+ public override void QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter()
+ {
+ base.QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter();
+
+ AssertSql(
+ @"@__amount_0='27' (Nullable = true)
+
+SELECT [c].[Id], [t0].[ProductId]
+FROM [Customers] AS [c]
+OUTER APPLY (
+ SELECT [t].[ProductId]
+ FROM [dbo].[GetTopTwoSellingProducts]() AS [t]
+ WHERE [t].[AmountSold] = @__amount_0
+) AS [t0]
+ORDER BY [c].[Id]");
+ }
+
+ public override void QF_Correlated_Select_In_Anonymous()
+ {
+ base.QF_Correlated_Select_In_Anonymous();
+
+ AssertSql(@"SELECT [c].[Id], [c].[LastName], [m].[OrderId], [m].[CustomerId], [m].[OrderDate]
+FROM [Customers] AS [c]
+OUTER APPLY [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m]
+ORDER BY [c].[Id], [m].[OrderId]");
+ }
+
+ public override void QF_CrossApply_Correlated_Select_Result()
+ {
+ base.QF_CrossApply_Correlated_Select_Result();
+
+ AssertSql(@"SELECT [o].[Count], [o].[CustomerId], [o].[Year]
+FROM [Customers] AS [c]
+CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o]
+ORDER BY [o].[Count] DESC, [o].[Year] DESC");
+ }
+
+ public override void QF_CrossJoin_Not_Correlated()
+ {
+ base.QF_CrossJoin_Not_Correlated();
+
+ AssertSql(@"@__customerId_0='2'
+
+SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count]
+FROM [Customers] AS [c]
+CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o]
+WHERE [c].[Id] = 2
+ORDER BY [o].[Count]");
+ }
+
+ public override void QF_CrossJoin_Parameter()
+ {
+ base.QF_CrossJoin_Parameter();
+
+ AssertSql(@"@__customerId_0='2'
+@__custId_1='2'
+
+SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count]
+FROM [Customers] AS [c]
+CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o]
+WHERE [c].[Id] = @__custId_1
+ORDER BY [o].[Count]");
+ }
+
+ public override void QF_Join()
+ {
+ base.QF_Join();
+
+ AssertSql(@"SELECT [p].[Id], [p].[Name], [t].[AmountSold]
+FROM [Products] AS [p]
+INNER JOIN [dbo].[GetTopTwoSellingProducts]() AS [t] ON [p].[Id] = [t].[ProductId]
+ORDER BY [p].[Id]");
+ }
+
+ public override void QF_LeftJoin_Select_Anonymous()
+ {
+ base.QF_LeftJoin_Select_Anonymous();
+
+ AssertSql(@"SELECT [p].[Id], [p].[Name], [t].[AmountSold]
+FROM [Products] AS [p]
+LEFT JOIN [dbo].[GetTopTwoSellingProducts]() AS [t] ON [p].[Id] = [t].[ProductId]
+ORDER BY [p].[Id] DESC");
+ }
+
+ public override void QF_LeftJoin_Select_Result()
+ {
+ base.QF_LeftJoin_Select_Result();
+
+ AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId]
+FROM [Products] AS [p]
+LEFT JOIN [dbo].[GetTopTwoSellingProducts]() AS [t] ON [p].[Id] = [t].[ProductId]
+ORDER BY [p].[Id] DESC");
+ }
+
+ public override void QF_OuterApply_Correlated_Select_TVF()
+ {
+ base.QF_OuterApply_Correlated_Select_TVF();
+
+ AssertSql(@"SELECT [o].[Count], [o].[CustomerId], [o].[Year]
+FROM [Customers] AS [c]
+OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o]
+ORDER BY [c].[Id], [o].[Year]");
+ }
+
+ public override void QF_OuterApply_Correlated_Select_DbSet()
+ {
+ base.QF_OuterApply_Correlated_Select_DbSet();
+
+ AssertSql(@"SELECT [c].[Id], [c].[FirstName], [c].[LastName]
+FROM [Customers] AS [c]
+OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o]
+ORDER BY [c].[Id], [o].[Year]");
+ }
+
+ public override void QF_OuterApply_Correlated_Select_Anonymous()
+ {
+ base.QF_OuterApply_Correlated_Select_Anonymous();
+
+ AssertSql(@"SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count]
+FROM [Customers] AS [c]
+OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [o]
+ORDER BY [c].[Id], [o].[Year]");
+ }
+
+ public override void QF_Nested()
+ {
+ base.QF_Nested();
+
+ AssertSql(@"@__custId_1='2'
+
+SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count]
+FROM [Customers] AS [c]
+CROSS JOIN [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues](1, 1)) AS [o]
+WHERE [c].[Id] = @__custId_1
+ORDER BY [o].[Year]");
+ }
+
+ public override void QF_Correlated_Nested_Func_Call()
+ {
+ base.QF_Correlated_Nested_Func_Call();
+
+ AssertSql(@"@__custId_1='2'
+
+SELECT [c].[Id], [o].[Count], [o].[Year]
+FROM [Customers] AS [c]
+CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues]([c].[Id], 1)) AS [o]
+WHERE [c].[Id] = @__custId_1");
+ }
+
+ public override void QF_Correlated_Func_Call_With_Navigation()
+ {
+ base.QF_Correlated_Func_Call_With_Navigation();
+
+ AssertSql(@"SELECT [c].[Id], [t].[LastName], [t].[OrderId], [t].[Id]
+FROM [Customers] AS [c]
+OUTER APPLY (
+ SELECT [c0].[LastName], [m].[OrderId], [c0].[Id]
+ FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m]
+ INNER JOIN [Customers] AS [c0] ON [m].[CustomerId] = [c0].[Id]
+) AS [t]
+ORDER BY [c].[Id], [t].[OrderId], [t].[Id]");
+ }
+
+ #endregion
+
public class SqlServer : UdfFixtureBase
{
protected override string StoreName { get; } = "UDFDbFunctionSqlServerTests";
@@ -516,6 +832,98 @@ returns nvarchar(max)
return @customerName;
end");
+ context.Database.ExecuteSqlRaw(
+ @"create function [dbo].GetCustomerOrderCountByYear(@customerId int)
+ returns @reports table
+ (
+ CustomerId int not null,
+ Count int not null,
+ Year int not null
+ )
+ as
+ begin
+
+ insert into @reports
+ select @customerId, count(id), year(orderDate)
+ from orders
+ where customerId = @customerId
+ group by customerId, year(orderDate)
+ order by year(orderDate)
+
+ return
+ end");
+
+ context.Database.ExecuteSqlRaw(
+ @"create function [dbo].GetTopTwoSellingProducts()
+ returns @products table
+ (
+ ProductId int not null,
+ AmountSold int
+ )
+ as
+ begin
+
+ insert into @products
+ select top 2 ProductID, sum(Quantity) as totalSold
+ from lineItem
+ group by ProductID
+ order by totalSold desc
+ return
+ end");
+
+ context.Database.ExecuteSqlRaw(
+ @"create function [dbo].GetTopSellingProductsForCustomer(@customerId int)
+ returns @products table
+ (
+ ProductId int not null,
+ AmountSold int
+ )
+ as
+ begin
+
+ insert into @products
+ select ProductID, sum(Quantity) as totalSold
+ from lineItem li
+ join orders o on o.id = li.orderId
+ where o.customerId = @customerId
+ group by ProductID
+
+ return
+ end");
+
+ context.Database.ExecuteSqlRaw(
+ @"create function [dbo].GetOrdersWithMultipleProducts(@customerId int)
+ returns @orders table
+ (
+ OrderId int not null,
+ CustomerId int not null,
+ OrderDate dateTime2
+ )
+ as
+ begin
+
+ insert into @orders
+ select o.id, @customerId, OrderDate
+ from orders o
+ join lineItem li on o.id = li.orderId
+ where o.customerId = @customerId
+ group by o.id, OrderDate
+ having count(productId) > 1
+ return
+ end");
+
+ context.Database.ExecuteSqlRaw(
+ @"create function [dbo].[AddValues] (@a int, @b int)
+ returns int
+ as
+ begin
+ return @a + @b;
+ end");
+
+
+ context.Database.ExecuteSqlRaw(
+ @"create view [dbo].[vOrderQuery] as select * from orders");
+
context.SaveChanges();
}
}