Skip to content

Commit

Permalink
Query: Use function mapping when querying for DbSet (#21507)
Browse files Browse the repository at this point in the history
Resolves #20051
  • Loading branch information
smitpatel committed Jul 3, 2020
1 parent 20d587c commit dff6a1a
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Metadata/Internal/StoreFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public StoreFunction([NotNull] DbFunction dbFunction, [NotNull] RelationalModel
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual SortedDictionary<string, DbFunction> DbFunctions { get; }
public virtual SortedDictionary<string, DbFunction> DbFunctions { get; }

/// <inheritdoc />
public virtual bool IsBuiltIn { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
}

var entityType = tableValuedFunctionQueryRootExpression.EntityType;
var alias = (entityType.GetViewOrTableMappings().SingleOrDefault()?.Table.Name
?? entityType.ShortName()).Substring(0, 1).ToLower();
var alias = entityType.ShortName().Substring(0, 1).ToLower();

var translation = new TableValuedFunctionExpression(alias, function.Schema, function.Name, arguments);
var queryExpression = _sqlExpressionFactory.Select(entityType, translation);
Expand Down
37 changes: 28 additions & 9 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,23 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
if ((entityType.BaseType == null && !entityType.GetDirectlyDerivedTypes().Any())
|| entityType.GetDiscriminatorProperty() != null)
{
// Key-less entities or TPH
var table = entityType.GetViewOrTableMappings().Single().Table;
var tableExpression = new TableExpression(table);
ITableBase table;
TableExpressionBase tableExpression;
if (entityType.GetFunctionMappings().SingleOrDefault(e => e.IsDefaultFunctionMapping) is IFunctionMapping functionMapping)
{
var storeFunction = functionMapping.Table;
var alias = entityType.ShortName().Substring(0, 1).ToLower();

table = storeFunction;
tableExpression = new TableValuedFunctionExpression(
alias, storeFunction.Schema, storeFunction.Name, Array.Empty<SqlExpression>());
}
else
{
table = entityType.GetViewOrTableMappings().Single().Table;
tableExpression = new TableExpression(entityType.GetViewOrTableMappings().Single().Table);
}

_tables.Add(tableExpression);

var propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
Expand Down Expand Up @@ -218,13 +232,18 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
}

static ColumnExpression GetColumn(
IProperty property, IEntityType currentEntityType, ITableBase table, TableExpression tableExpression, bool nullable)
IProperty property, IEntityType currentEntityType, ITableBase table, TableExpressionBase tableExpression, bool nullable)
{
var column = table is ITable
? (IColumnBase)property.GetTableColumnMappings().Single(cm => cm.TableMapping.Table == table
&& cm.TableMapping.EntityType == currentEntityType).Column
: property.GetViewColumnMappings().Single(cm => cm.TableMapping.Table == table
&& cm.TableMapping.EntityType == currentEntityType).Column;
var columnMappings = table switch
{
IStoreFunction _ => property.GetFunctionColumnMappings().Cast<IColumnMappingBase>(),
IView _ => property.GetViewColumnMappings().Cast<IColumnMappingBase>(),
_ => property.GetTableColumnMappings().Cast<IColumnMappingBase>()
};

var column = columnMappings
.Single(cm => cm.TableMapping.Table == table && cm.TableMapping.EntityType == currentEntityType)
.Column;

return new ColumnExpression(property, column, tableExpression, nullable);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,38 @@ public class Address
public Customer Customer { get; set; }
}

public class OrderByYear
{
public int? CustomerId { get; set; }
public int? Count { get; set; }
public int? Year { get; set; }
}

public class MultProductOrders
{
public int OrderId { get; set; }

public Customer Customer { get; set; }
public int CustomerId { get; set; }

public DateTime OrderDate { get; set; }
}

public class TopSellingProduct
{
public Product Product { get; set; }
public int? ProductId { get; set; }

public int? AmountSold { get; set; }
}

public class CustomerData
{
public int Id { get; set; }
public string FirstName { get; set; }
public string LastName { get; set; }
}

protected class UDFSqlContext : PoolableDbContext
{
#region DbSets
Expand Down Expand Up @@ -152,36 +184,11 @@ public int AddValues(Expression<Func<int>> a, int b)

#region Queryable Functions

public class OrderByYear
{
public int? CustomerId { get; set; }
public int? Count { get; set; }
public int? Year { get; set; }
}

public class MultProductOrders
{
public int OrderId { get; set; }

public Customer Customer { get; set; }
public int CustomerId { get; set; }

public DateTime OrderDate { get; set; }
}

public IQueryable<OrderByYear> GetCustomerOrderCountByYear(int customerId)
{
return FromExpression(() => GetCustomerOrderCountByYear(customerId));
}

public class TopSellingProduct
{
public Product Product { get; set; }
public int? ProductId { get; set; }

public int? AmountSold { get; set; }
}

public IQueryable<TopSellingProduct> GetTopTwoSellingProducts()
{
return FromExpression(() => GetTopTwoSellingProducts());
Expand All @@ -197,6 +204,11 @@ public IQueryable<MultProductOrders> GetOrdersWithMultipleProducts(int customerI
return FromExpression(() => GetOrdersWithMultipleProducts(customerId));
}

public IQueryable<CustomerData> GetCustomerData(int customerId)
{
return FromExpression(() => GetCustomerData(customerId));
}

#endregion

#endregion
Expand Down Expand Up @@ -251,11 +263,12 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerOrderCountByYear), new[] { typeof(int) }));
modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopTwoSellingProducts)));
modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopSellingProductsForCustomer)));

modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetOrdersWithMultipleProducts)));
modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerData)));

modelBuilder.Entity<OrderByYear>().HasNoKey();
modelBuilder.Entity<TopSellingProduct>().HasNoKey();
modelBuilder.Entity<TopSellingProduct>().HasNoKey().ToFunction("GetTopTwoSellingProducts");
modelBuilder.Entity<CustomerData>().ToView("Customers");
}
}

Expand Down Expand Up @@ -1905,6 +1918,36 @@ orderby c.Id
}
}

[ConditionalFact]
public virtual void DbSet_mapped_to_function()
{
using (var context = CreateContext())
{
var products = (from t in context.Set<TopSellingProduct>()
orderby t.ProductId
select t).ToList();

Assert.Equal(2, products.Count);
Assert.Equal(3, products[0].ProductId);
Assert.Equal(249, products[0].AmountSold);
Assert.Equal(4, products[1].ProductId);
Assert.Equal(184, products[1].AmountSold);
}
}

[ConditionalFact]
public virtual void TVF_backing_entity_type_mapped_to_view()
{
using (var context = CreateContext())
{
var customers = (from t in context.Set<CustomerData>()
orderby t.FirstName
select t).ToList();

Assert.Equal(4, customers.Count);
}
}

#endregion

private void AssertTranslationFailed(Action testCode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,26 @@ FROM [dbo].[GetOrdersWithMultipleProducts]([c].[Id]) AS [m]
ORDER BY [c].[Id], [t].[OrderId], [t].[Id]");
}

public override void DbSet_mapped_to_function()
{
base.DbSet_mapped_to_function();

AssertSql(
@"SELECT [t].[AmountSold], [t].[ProductId]
FROM [dbo].[GetTopTwoSellingProducts]() AS [t]
ORDER BY [t].[ProductId]");
}

public override void TVF_backing_entity_type_mapped_to_view()
{
base.TVF_backing_entity_type_mapped_to_view();

AssertSql(
@"SELECT [c].[Id], [c].[FirstName], [c].[LastName]
FROM [Customers] AS [c]
ORDER BY [c].[FirstName]");
}

#endregion

public class SqlServer : UdfFixtureBase
Expand Down

0 comments on commit dff6a1a

Please sign in to comment.