diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs index 3866035b7cc..2bf83954bd6 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs @@ -187,7 +187,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var method = methodCallExpression.Method; var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; - if (genericMethod == EntityMaterializerSource.TryReadValueMethod) + if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) { var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value; Expression innerExpression; diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index edaee096156..b1a31ce5550 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -17,6 +17,7 @@ using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { @@ -399,7 +400,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp Check.NotNull(methodCallExpression, nameof(methodCallExpression)); if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod) + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) { return methodCallExpression; } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index 6bed3a3287b..f4efd3fedc4 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -8,12 +8,14 @@ using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.InMemory.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { @@ -324,7 +326,7 @@ public virtual void ApplyDefaultIfEmpty() private static IPropertyBase InferPropertyFromInner(Expression expression) => expression is MethodCallExpression methodCallExpression && methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod ? (IPropertyBase)((ConstantExpression)methodCallExpression.Arguments[2]).Value : null; @@ -440,24 +442,15 @@ private Expression GetGroupingKey(Expression key, List groupingExpre default: var index = groupingExpressions.Count; groupingExpressions.Add(key); - return CreateReadValueExpression( - groupingKeyAccessExpression, + return groupingKeyAccessExpression.CreateValueBufferReadValueExpression( key.Type, index, InferPropertyFromInner(key)); } } - private static Expression CreateReadValueExpression( - Expression valueBufferParameter, Type type, int index, IPropertyBase property) - => Call( - EntityMaterializerSource.TryReadValueMethod.MakeGenericMethod(type), - valueBufferParameter, - Constant(index), - Constant(property, typeof(IPropertyBase))); - private Expression CreateReadValueExpression(Type type, int index, IPropertyBase property) - => CreateReadValueExpression(_valueBufferParameter, type, index, property); + => _valueBufferParameter.CreateValueBufferReadValueExpression(type, index, property); public virtual void AddInnerJoin( [NotNull] InMemoryQueryExpression innerQueryExpression, @@ -943,11 +936,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp Check.NotNull(methodCallExpression, nameof(methodCallExpression)); if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod && !methodCallExpression.Type.IsNullableType()) { return Call( - EntityMaterializerSource.TryReadValueMethod.MakeGenericMethod(methodCallExpression.Type.MakeNullable()), + ExpressionExtensions.ValueBufferTryReadValueMethod.MakeGenericMethod(methodCallExpression.Type.MakeNullable()), methodCallExpression.Arguments); } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs index 6934d61f31d..677f64cbe07 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs @@ -9,6 +9,7 @@ using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { @@ -61,7 +62,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp Check.NotNull(methodCallExpression, nameof(methodCallExpression)); if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod) + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) { var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value; var (indexMap, valueBuffer) = @@ -88,11 +89,10 @@ protected override Expression VisitExtension(Expression extensionExpression) var projectionIndex = (int)GetProjectionIndex(queryExpression, projectionBindingExpression); var valueBuffer = queryExpression.CurrentParameter; - return Expression.Call( - EntityMaterializerSource.TryReadValueMethod.MakeGenericMethod(projectionBindingExpression.Type), - valueBuffer, - Expression.Constant(projectionIndex), - Expression.Constant(InferPropertyFromInner(queryExpression.Projection[projectionIndex]), typeof(IPropertyBase))); + return valueBuffer.CreateValueBufferReadValueExpression( + projectionBindingExpression.Type, + projectionIndex, + InferPropertyFromInner(queryExpression.Projection[projectionIndex])); } return base.VisitExtension(extensionExpression); @@ -102,7 +102,7 @@ private IPropertyBase InferPropertyFromInner(Expression expression) { if (expression is MethodCallExpression methodCallExpression && methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod) + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) { return (IPropertyBase)((ConstantExpression)methodCallExpression.Arguments[2]).Value; } diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs index f8735eba83f..d159e00482c 100644 --- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.RelationalProjectionBindingRemovingExpressionVisitor.cs @@ -92,7 +92,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp Check.NotNull(methodCallExpression, nameof(methodCallExpression)); if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == EntityMaterializerSource.TryReadValueMethod) + && methodCallExpression.Method.GetGenericMethodDefinition() == Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) { var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value; var propertyProjectionMap = methodCallExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression diff --git a/src/EFCore/Infrastructure/ExpressionExtensions.cs b/src/EFCore/Infrastructure/ExpressionExtensions.cs index 48d5a8b970e..e7f6c80df1c 100644 --- a/src/EFCore/Infrastructure/ExpressionExtensions.cs +++ b/src/EFCore/Infrastructure/ExpressionExtensions.cs @@ -6,11 +6,13 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.CompilerServices; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Infrastructure @@ -225,5 +227,48 @@ var propertyPaths return propertyPaths; } + + /// + /// + /// Creates an tree representing reading a value from a + /// + /// + /// This method is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + /// The expression that exposes the . + /// The type to read. + /// The index in the buffer to read from. + /// The IPropertyBase being read if any. + /// An expression to read the value. + public static Expression CreateValueBufferReadValueExpression( + [NotNull] this Expression valueBuffer, + [NotNull] Type type, + int index, + [CanBeNull] IPropertyBase property) + => Expression.Call( + ValueBufferTryReadValueMethod.MakeGenericMethod(type), + valueBuffer, + Expression.Constant(index), + Expression.Constant(property, typeof(IPropertyBase))); + + /// + /// + /// MethodInfo which is used to generate an tree representing reading a value from a + /// + /// + /// This method is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + public static readonly MethodInfo ValueBufferTryReadValueMethod + = typeof(ExpressionExtensions).GetTypeInfo() + .GetDeclaredMethod(nameof(ValueBufferTryReadValue)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static TValue ValueBufferTryReadValue( + in ValueBuffer valueBuffer, int index, IPropertyBase property) + => valueBuffer[index] is TValue value ? value : default; } } diff --git a/src/EFCore/Metadata/PropertyParameterBinding.cs b/src/EFCore/Metadata/PropertyParameterBinding.cs index 3d70dde51ad..73d02ca15fb 100644 --- a/src/EFCore/Metadata/PropertyParameterBinding.cs +++ b/src/EFCore/Metadata/PropertyParameterBinding.cs @@ -3,6 +3,7 @@ using System.Linq.Expressions; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; @@ -33,11 +34,8 @@ public override Expression BindToParameter(ParameterBindingInfo bindingInfo) { var property = ConsumedProperties[0]; - return Expression.Call( - EntityMaterializerSource.TryReadValueMethod.MakeGenericMethod(property.ClrType), - Expression.Call(bindingInfo.MaterializationContextExpression, MaterializationContext.GetValueBufferMethod), - Expression.Constant(bindingInfo.GetValueBufferIndex(property)), - Expression.Constant(property, typeof(IPropertyBase))); + return Expression.Call(bindingInfo.MaterializationContextExpression, MaterializationContext.GetValueBufferMethod) + .CreateValueBufferReadValueExpression(property.ClrType, bindingInfo.GetValueBufferIndex(property), property); } } } diff --git a/src/EFCore/Query/EntityMaterializerSource.cs b/src/EFCore/Query/EntityMaterializerSource.cs index 84b25be969c..fd6b3896a42 100644 --- a/src/EFCore/Query/EntityMaterializerSource.cs +++ b/src/EFCore/Query/EntityMaterializerSource.cs @@ -34,26 +34,6 @@ public EntityMaterializerSource([NotNull] EntityMaterializerSourceDependencies d { } - public virtual Expression CreateReadValueExpression( - Expression valueBufferExpression, - Type type, - int index, - IPropertyBase property) - => Expression.Call( - TryReadValueMethod.MakeGenericMethod(type), - valueBufferExpression, - Expression.Constant(index), - Expression.Constant(property, typeof(IPropertyBase))); - - public static readonly MethodInfo TryReadValueMethod - = typeof(EntityMaterializerSource).GetTypeInfo() - .GetDeclaredMethod(nameof(TryReadValue)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static TValue TryReadValue( - in ValueBuffer valueBuffer, int index, IPropertyBase property) - => valueBuffer[index] is TValue value ? value : default; - public virtual Expression CreateMaterializeExpression( IEntityType entityType, string entityInstanceName, @@ -127,8 +107,7 @@ var blockExpressions var readValueExpression = property is IServiceProperty serviceProperty ? serviceProperty.GetParameterBinding().BindToParameter(bindingInfo) - : CreateReadValueExpression( - valueBufferExpression, + : valueBufferExpression.CreateValueBufferReadValueExpression( memberInfo.GetMemberType(), property.GetIndex(), property); diff --git a/src/EFCore/Query/EntityShaperExpression.cs b/src/EFCore/Query/EntityShaperExpression.cs index c68aa5b45b6..d0e34bedd61 100644 --- a/src/EFCore/Query/EntityShaperExpression.cs +++ b/src/EFCore/Query/EntityShaperExpression.cs @@ -2,31 +2,122 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; +using System.Reflection; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Query { public class EntityShaperExpression : Expression, IPrintableExpression { + private static readonly MethodInfo _createUnableToDiscriminateException + = typeof(EntityShaperExpression).GetTypeInfo() + .GetDeclaredMethod(nameof(CreateUnableToDiscriminateException)); + + [UsedImplicitly] + private static Exception CreateUnableToDiscriminateException(IEntityType entityType, object discriminator) + => new InvalidOperationException(CoreStrings.UnableToDiscriminate(entityType.DisplayName(), discriminator?.ToString())); + public EntityShaperExpression( [NotNull] IEntityType entityType, [NotNull] Expression valueBufferExpression, bool nullable) + : this(entityType, valueBufferExpression, nullable, null) + { + } + + protected EntityShaperExpression( + [NotNull] IEntityType entityType, + [NotNull] Expression valueBufferExpression, + bool nullable, + [CanBeNull] LambdaExpression discriminatorCondition) { Check.NotNull(entityType, nameof(entityType)); Check.NotNull(valueBufferExpression, nameof(valueBufferExpression)); + if (discriminatorCondition == null) + { + discriminatorCondition = GenerateDiscriminatorCondition(entityType, nullable); + } + else if (discriminatorCondition.Parameters.Count != 1 + || discriminatorCondition.Parameters[0].Type != typeof(ValueBuffer) + || discriminatorCondition.ReturnType != typeof(IEntityType)) + { + throw new InvalidOperationException( + "Discriminator condition must be lambda expression of type Func."); + } + EntityType = entityType; ValueBufferExpression = valueBufferExpression; IsNullable = nullable; + DiscriminatorCondition = discriminatorCondition; + } + + private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType, bool nullable) + { + var valueBufferParameter = Parameter(typeof(ValueBuffer)); + Expression body; + var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToArray(); + var discriminatorProperty = entityType.GetDiscriminatorProperty(); + if (discriminatorProperty != null) + { + var discriminatorValueVariable = Variable(discriminatorProperty.ClrType, "discriminator"); + var expressions = new List + { + Assign( + discriminatorValueVariable, + valueBufferParameter.CreateValueBufferReadValueExpression( + discriminatorProperty.ClrType, discriminatorProperty.GetIndex(), discriminatorProperty)) + }; + + var switchCases = new SwitchCase[concreteEntityTypes.Length]; + for (var i = 0; i < concreteEntityTypes.Length; i++) + { + var discriminatorValue = Constant(concreteEntityTypes[i].GetDiscriminatorValue(), discriminatorProperty.ClrType); + switchCases[i] = SwitchCase(Constant(concreteEntityTypes[i], typeof(IEntityType)), discriminatorValue); + } + + var exception = Block( + Throw(Call( + _createUnableToDiscriminateException, Constant(entityType), Convert(discriminatorValueVariable, typeof(object)))), + Constant(null, typeof(IEntityType))); + + expressions.Add(Switch(discriminatorValueVariable, exception, switchCases)); + body = Block(new[] { discriminatorValueVariable }, expressions); + } + else + { + body = Constant(concreteEntityTypes.Length == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType)); + } + + if (entityType.FindPrimaryKey() == null + && nullable) + { + body = Condition( + entityType.GetProperties() + .Select(p => NotEqual( + valueBufferParameter.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p), + Constant(null))) + .Aggregate((a, b) => OrElse(a, b)), + body, + Default(typeof(IEntityType))); + } + + return Lambda(body, valueBufferParameter); } public virtual IEntityType EntityType { get; } public virtual Expression ValueBufferExpression { get; } public virtual bool IsNullable { get; } + public virtual LambdaExpression DiscriminatorCondition { get; } protected override Expression VisitChildren(ExpressionVisitor visitor) { @@ -48,6 +139,7 @@ public virtual EntityShaperExpression WithEntityType([NotNull] IEntityType entit public virtual EntityShaperExpression MarkAsNullable() => !IsNullable + // Marking nullable requires recomputation of Discriminator condition ? new EntityShaperExpression(EntityType, ValueBufferExpression, true) : this; @@ -56,7 +148,7 @@ public virtual EntityShaperExpression Update([NotNull] Expression valueBufferExp Check.NotNull(valueBufferExpression, nameof(valueBufferExpression)); return valueBufferExpression != ValueBufferExpression - ? new EntityShaperExpression(EntityType, valueBufferExpression, IsNullable) + ? new EntityShaperExpression(EntityType, valueBufferExpression, IsNullable, DiscriminatorCondition) : this; } diff --git a/src/EFCore/Query/IEntityMaterializerSource.cs b/src/EFCore/Query/IEntityMaterializerSource.cs index 3a2a52eb4cd..1dd57bd9fb0 100644 --- a/src/EFCore/Query/IEntityMaterializerSource.cs +++ b/src/EFCore/Query/IEntityMaterializerSource.cs @@ -27,26 +27,6 @@ namespace Microsoft.EntityFrameworkCore.Query /// public interface IEntityMaterializerSource { - /// - /// - /// Creates an tree representing reading a value from a - /// - /// - /// This method is typically used by database providers (and other extensions). It is generally - /// not used in application code. - /// - /// - /// The expression that exposes the . - /// The type to read. - /// The index in the buffer to read from. - /// The IPropertyBase being read if any. - /// An expression to read the value. - Expression CreateReadValueExpression( - [NotNull] Expression valueBuffer, - [NotNull] Type type, - int index, - [CanBeNull] IPropertyBase property); - /// /// /// Creates an tree representing creating an entity instance. diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index d55ffd4df5d..2a6e633f5f7 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -11,6 +11,7 @@ using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; @@ -372,8 +373,7 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres typeof(object), primaryKey.Properties .Select( - p => _entityMaterializerSource.CreateReadValueExpression( - valueBufferExpression, + p => valueBufferExpression.CreateValueBufferReadValueExpression( typeof(object), p.GetIndex(), p))), @@ -396,50 +396,27 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres Expression.MakeMemberAccess(entryVariable, _entityMemberInfo), entityType.ClrType))), MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, entryVariable)))); } else { if (primaryKey != null) { - expressions.Add(Expression.IfThen( primaryKey.Properties.Select( p => Expression.NotEqual( - _entityMaterializerSource.CreateReadValueExpression( - valueBufferExpression, typeof(object), p.GetIndex(), p), + valueBufferExpression.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p), Expression.Constant(null))) .Aggregate((a, b) => Expression.AndAlso(a, b)), MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); } else { - if (entityShaperExpression.IsNullable) - { - expressions.Add( - Expression.IfThen( - entityType.GetProperties() - .Select( - p => - Expression.NotEqual( - _entityMaterializerSource.CreateReadValueExpression( - valueBufferExpression, - typeof(object), - p.GetIndex(), - p), - Expression.Constant(null))) - .Aggregate((a, b) => Expression.OrElse(a, b)), - MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); - } - else - { - expressions.Add( - MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)); - } + expressions.Add( + MaterializeEntity( + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)); } } @@ -448,12 +425,14 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres } private Expression MaterializeEntity( - IEntityType entityType, + EntityShaperExpression entityShaperExpression, ParameterExpression materializationContextVariable, ParameterExpression concreteEntityTypeVariable, ParameterExpression instanceVariable, ParameterExpression entryVariable) { + var entityType = entityShaperExpression.EntityType; + var expressions = new List(); var variables = new List(); @@ -470,50 +449,34 @@ private Expression MaterializeEntity( Expression materializationExpression; var valueBufferExpression = Expression.Call(materializationContextVariable, MaterializationContext.GetValueBufferMethod); var expressionContext = (returnType, materializationContextVariable, concreteEntityTypeVariable, shadowValuesVariable); - var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); - var firstEntityType = concreteEntityTypes[0]; - if (concreteEntityTypes.Count == 1) + expressions.Add( + Expression.Assign(concreteEntityTypeVariable, + ReplacingExpressionVisitor.Replace( + entityShaperExpression.DiscriminatorCondition.Parameters[0], + valueBufferExpression, + entityShaperExpression.DiscriminatorCondition.Body))); + + var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToArray(); + var discriminatorProperty = entityType.GetDiscriminatorProperty(); + if (discriminatorProperty == null + && concreteEntityTypes.Length > 1) { - materializationExpression = CreateFullMaterializeExpression(firstEntityType, expressionContext); + concreteEntityTypes = new [] { entityType }; } - else - { - var discriminatorProperty = firstEntityType.GetDiscriminatorProperty(); - var discriminatorValueVariable = Expression.Variable( - discriminatorProperty.ClrType, "discriminator" + _currentEntityIndex); - variables.Add(discriminatorValueVariable); - - expressions.Add( - Expression.Assign( - discriminatorValueVariable, - _entityMaterializerSource.CreateReadValueExpression( - valueBufferExpression, - discriminatorProperty.ClrType, - discriminatorProperty.GetIndex(), - discriminatorProperty))); - - materializationExpression = Expression.Block( - Expression.Throw( - Expression.Call( - _createUnableToDiscriminateException, - Expression.Constant(entityType), - Expression.Convert(discriminatorValueVariable, typeof(object)))), - Expression.Constant(null, returnType)); - foreach (var concreteEntityType in concreteEntityTypes) - { - var discriminatorValue - = Expression.Constant( - concreteEntityType.GetDiscriminatorValue(), - discriminatorProperty.ClrType); - - materializationExpression = Expression.Condition( - Expression.Equal(discriminatorValueVariable, discriminatorValue), - CreateFullMaterializeExpression(concreteEntityType, expressionContext), - materializationExpression); - } + var switchCases = new SwitchCase[concreteEntityTypes.Length]; + for (var i = 0; i < concreteEntityTypes.Length; i++) + { + switchCases[i] = Expression.SwitchCase( + CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext), + Expression.Constant(concreteEntityTypes[i], typeof(IEntityType))); } + materializationExpression = Expression.Switch( + concreteEntityTypeVariable, + Expression.Constant(null, returnType), + switchCases); + expressions.Add(Expression.Assign(instanceVariable, materializationExpression)); if (_trackQueryResults @@ -554,12 +517,7 @@ private BlockExpression CreateFullMaterializeExpression( concreteEntityTypeVariable, shadowValuesVariable) = materializeExpressionContext; - var blockExpressions = new List(3) - { - Expression.Assign( - concreteEntityTypeVariable, - Expression.Constant(concreteEntityType)) - }; + var blockExpressions = new List(2); var materializer = _entityMaterializerSource .CreateMaterializeExpression(concreteEntityType, "instance", materializationContextVariable); @@ -578,11 +536,7 @@ private BlockExpression CreateFullMaterializeExpression( Expression.NewArrayInit( typeof(object), shadowProperties.Select( - p => _entityMaterializerSource.CreateReadValueExpression( - valueBufferExpression, - typeof(object), - p.GetIndex(), - p)))))); + p => valueBufferExpression.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p)))))); } materializer = materializer.Type == returnType