Skip to content

Commit

Permalink
Query: Implement TPT for relational layer
Browse files Browse the repository at this point in the history
Resolves #2266
  • Loading branch information
smitpatel committed Jul 1, 2020
1 parent 4795eee commit 03cf99d
Show file tree
Hide file tree
Showing 20 changed files with 2,100 additions and 606 deletions.
109 changes: 50 additions & 59 deletions src/EFCore.Relational/Query/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,37 @@ namespace Microsoft.EntityFrameworkCore.Query
/// </summary>
public class EntityProjectionExpression : Expression
{
private readonly IDictionary<IProperty, ColumnExpression> _propertyExpressionsCache
= new Dictionary<IProperty, ColumnExpression>();

private readonly IDictionary<INavigation, EntityShaperExpression> _navigationExpressionsCache
private readonly IDictionary<IProperty, ColumnExpression> _propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
private readonly IDictionary<INavigation, EntityShaperExpression> _navigationExpressions
= new Dictionary<INavigation, EntityShaperExpression>();

private readonly TableExpressionBase _innerTable;
private readonly bool _nullable;

/// <summary>
/// Creates a new instance of the <see cref="EntityProjectionExpression" /> class.
/// </summary>
/// <param name="entityType"> The entity type to shape. </param>
/// <param name="innerTable"> The table from which entity columns are being projected out. </param>
/// <param name="nullable"> A bool value indicating whether this entity instance can be null. </param>
public EntityProjectionExpression([NotNull] IEntityType entityType, [NotNull] TableExpressionBase innerTable, bool nullable)
{
Check.NotNull(entityType, nameof(entityType));
Check.NotNull(innerTable, nameof(innerTable));

EntityType = entityType;
_innerTable = innerTable;
_nullable = nullable;
}

/// <summary>
/// Creates a new instance of the <see cref="EntityProjectionExpression" /> class.
/// </summary>
/// <param name="entityType"> The entity type to shape. </param>
/// <param name="propertyExpressions"> A dictionary of column expressions corresponding to properties of the entity type. </param>
public EntityProjectionExpression([NotNull] IEntityType entityType, [NotNull] IDictionary<IProperty, ColumnExpression> propertyExpressions)
/// <param name="discriminatorExpressions"> A dictionary of <see cref="SqlExpression"/> to discriminator each entity type in hierarchy. </param>
public EntityProjectionExpression(
[NotNull] IEntityType entityType,
[NotNull] IDictionary<IProperty, ColumnExpression> propertyExpressions,
[CanBeNull] IReadOnlyDictionary<IEntityType, SqlExpression> discriminatorExpressions = null)
{
Check.NotNull(entityType, nameof(entityType));
Check.NotNull(propertyExpressions, nameof(propertyExpressions));

EntityType = entityType;
_propertyExpressionsCache = propertyExpressions;
_propertyExpressions = propertyExpressions;
DiscriminatorExpressions = discriminatorExpressions;
}

/// <summary>
/// The entity type being projected out.
/// </summary>
public virtual IEntityType EntityType { get; }
/// <summary>
/// Dictionary of discriminator expressions.
/// </summary>
public virtual IReadOnlyDictionary<IEntityType, SqlExpression> DiscriminatorExpressions { get; }
/// <inheritdoc />
public sealed override ExpressionType NodeType => ExpressionType.Extension;
/// <inheritdoc />
Expand All @@ -76,27 +64,31 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
{
Check.NotNull(visitor, nameof(visitor));

if (_innerTable != null)
{
var table = (TableExpressionBase)visitor.Visit(_innerTable);

return table != _innerTable
? new EntityProjectionExpression(EntityType, table, _nullable)
: this;
}

var changed = false;
var newCache = new Dictionary<IProperty, ColumnExpression>();
foreach (var expression in _propertyExpressionsCache)
foreach (var expression in _propertyExpressions)
{
var newExpression = (ColumnExpression)visitor.Visit(expression.Value);
changed |= newExpression != expression.Value;

newCache[expression.Key] = newExpression;
}

Dictionary<IEntityType, SqlExpression> newDiscriminators = null;
if (DiscriminatorExpressions != null)
{
newDiscriminators = new Dictionary<IEntityType, SqlExpression>();
foreach (var expression in DiscriminatorExpressions)
{
var newExpression = (SqlExpression)visitor.Visit(expression.Value);
changed |= newExpression != expression.Value;

newDiscriminators[expression.Key] = newExpression;
}
}

return changed
? new EntityProjectionExpression(EntityType, newCache)
? new EntityProjectionExpression(EntityType, newCache, newDiscriminators)
: this;
}

Expand All @@ -106,18 +98,13 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
/// <returns> A new entity projection expression which can project nullable entity. </returns>
public virtual EntityProjectionExpression MakeNullable()
{
if (_innerTable != null)
{
return new EntityProjectionExpression(EntityType, _innerTable, nullable: true);
}

var newCache = new Dictionary<IProperty, ColumnExpression>();
foreach (var expression in _propertyExpressionsCache)
foreach (var expression in _propertyExpressions)
{
newCache[expression.Key] = expression.Value.MakeNullable();
}

return new EntityProjectionExpression(EntityType, newCache);
return new EntityProjectionExpression(EntityType, newCache, DiscriminatorExpressions);
}

/// <summary>
Expand All @@ -129,13 +116,8 @@ public virtual EntityProjectionExpression UpdateEntityType([NotNull] IEntityType
{
Check.NotNull(derivedType, nameof(derivedType));

if (_innerTable != null)
{
return new EntityProjectionExpression(derivedType, _innerTable, _nullable);
}

var propertyExpressionCache = new Dictionary<IProperty, ColumnExpression>();
foreach (var kvp in _propertyExpressionsCache)
foreach (var kvp in _propertyExpressions)
{
var property = kvp.Key;
if (derivedType.IsAssignableFrom(property.DeclaringEntityType)
Expand All @@ -145,7 +127,22 @@ public virtual EntityProjectionExpression UpdateEntityType([NotNull] IEntityType
}
}

return new EntityProjectionExpression(derivedType, propertyExpressionCache);
Dictionary<IEntityType, SqlExpression> discriminatorExpressions = null;
if (DiscriminatorExpressions != null)
{
discriminatorExpressions = new Dictionary<IEntityType, SqlExpression>();
foreach (var kvp in DiscriminatorExpressions)
{
var entityType = kvp.Key;
if (derivedType.IsAssignableFrom(entityType)
|| entityType.IsAssignableFrom(derivedType))
{
discriminatorExpressions[entityType] = kvp.Value;
}
}
}

return new EntityProjectionExpression(derivedType, propertyExpressionCache, discriminatorExpressions);
}

/// <summary>
Expand All @@ -168,13 +165,7 @@ public virtual ColumnExpression BindProperty([NotNull] IProperty property)
property.Name));
}

if (!_propertyExpressionsCache.TryGetValue(property, out var expression))
{
expression = new ColumnExpression(property, _innerTable, _nullable);
_propertyExpressionsCache[property] = expression;
}

return expression;
return _propertyExpressions[property];
}

/// <summary>
Expand All @@ -198,7 +189,7 @@ public virtual void AddNavigationBinding([NotNull] INavigation navigation, [NotN
navigation.Name));
}

_navigationExpressionsCache[navigation] = entityShaper;
_navigationExpressions[navigation] = entityShaper;
}

/// <summary>
Expand All @@ -222,7 +213,7 @@ public virtual EntityShaperExpression BindNavigation([NotNull] INavigation navig
navigation.Name));
}

return _navigationExpressionsCache.TryGetValue(navigation, out var expression)
return _navigationExpressions.TryGetValue(navigation, out var expression)
? expression
: null;
}
Expand Down
37 changes: 35 additions & 2 deletions src/EFCore.Relational/Query/RelationalEntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query
Expand All @@ -23,6 +25,14 @@ namespace Microsoft.EntityFrameworkCore.Query
/// </summary>
public class RelationalEntityShaperExpression : EntityShaperExpression
{
private static readonly MethodInfo _createUnableToIdentifyConcreteTypeException
= typeof(RelationalEntityShaperExpression).GetTypeInfo()
.GetDeclaredMethod(nameof(CreateUnableToIdentifyConcreteTypeException));

[UsedImplicitly]
private static Exception CreateUnableToIdentifyConcreteTypeException()
=> new InvalidOperationException("");

/// <summary>
/// Creates a new instance of the <see cref="RelationalEntityShaperExpression" /> class.
/// </summary>
Expand Down Expand Up @@ -55,11 +65,34 @@ protected override LambdaExpression GenerateMaterializationCondition(IEntityType
{
Check.NotNull(entityType, nameof(EntityType));

var baseCondition = base.GenerateMaterializationCondition(entityType, nullable);
LambdaExpression baseCondition;
if (entityType.GetDiscriminatorProperty() == null
&& entityType.GetDirectlyDerivedTypes().Any())
{
// TPT
var valueBufferParameter = Parameter(typeof(ValueBuffer));
var body = entityType.IsAbstract()
? Block(Throw(Call(_createUnableToIdentifyConcreteTypeException)), Constant(null, typeof(IEntityType)))
: (Expression)Constant(entityType, typeof(IEntityType)); // Default type
var concreteEntityTypes = entityType.GetDerivedTypes().Where(dt => !dt.IsAbstract()).ToArray();
for (var i = 0; i < concreteEntityTypes.Length; i++)
{
body = Condition(
valueBufferParameter.CreateValueBufferReadValueExpression(typeof(bool), i, property: null),
Constant(concreteEntityTypes[i], typeof(IEntityType)),
body);
}

baseCondition = Lambda(body, valueBufferParameter);
}
else
{
baseCondition = base.GenerateMaterializationCondition(entityType, nullable);
}

if (entityType.FindPrimaryKey() != null)
{
var linkingFks = entityType.GetViewOrTableMappings().SingleOrDefault()?.Table.GetRowInternalForeignKeys(entityType);
var linkingFks = entityType.GetViewOrTableMappings().FirstOrDefault()?.Table.GetRowInternalForeignKeys(entityType);
if (linkingFks != null
&& linkingFks.Any())
{
Expand Down
Loading

0 comments on commit 03cf99d

Please sign in to comment.