diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs index e565a2f3855..2070f1b9fbc 100644 --- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs @@ -533,10 +533,11 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent if (concreteEntityTypes.Count == 1) { var concreteEntityType = concreteEntityTypes[0]; - if (concreteEntityType.GetDiscriminatorProperty() != null) + var discriminatorProperty = concreteEntityType.GetDiscriminatorProperty(); + if (discriminatorProperty != null) { var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember())) - .BindProperty(concreteEntityType.GetDiscriminatorProperty(), clientEval: false); + .BindProperty(discriminatorProperty, clientEval: false); selectExpression.ApplyPredicate( Equal((SqlExpression)discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue()))); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 31a74ec69eb..16af3723528 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -965,8 +965,9 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp { var discriminatorProperty = entityType.GetDiscriminatorProperty(); var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType); + var valueComparer = discriminatorProperty.GetKeyValueComparer(); - var equals = Expression.Equal( + var equals = valueComparer.ExtractEqualsBody( boundProperty, Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); @@ -974,7 +975,7 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp { equals = Expression.OrElse( equals, - Expression.Equal( + valueComparer.ExtractEqualsBody( boundProperty, Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index d4b1bfe5c5d..caa220282ec 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -105,7 +105,10 @@ public InMemoryQueryExpression([NotNull] IEntityType entityType) foreach (var derivedEntityType in entityType.GetDerivedTypes()) { var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive() - .Select(e => Equal(readExpressionMap[discriminatorProperty], Constant(e.GetDiscriminatorValue()))) + .Select( + e => discriminatorProperty.GetKeyValueComparer().ExtractEqualsBody( + readExpressionMap[discriminatorProperty], + Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType))) .Aggregate((l, r) => OrElse(l, r)); foreach (var property in derivedEntityType.GetDeclaredProperties()) diff --git a/src/EFCore/ChangeTracking/ValueComparer.cs b/src/EFCore/ChangeTracking/ValueComparer.cs index c522ba4a061..8f16c996f56 100644 --- a/src/EFCore/ChangeTracking/ValueComparer.cs +++ b/src/EFCore/ChangeTracking/ValueComparer.cs @@ -3,6 +3,7 @@ using System; using System.Collections; +using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -27,7 +28,7 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking /// reference. /// /// - public abstract class ValueComparer : IEqualityComparer + public abstract class ValueComparer : IEqualityComparer, IEqualityComparer { private protected static readonly MethodInfo _doubleEqualsMethodInfo = typeof(double).GetRuntimeMethod(nameof(double.Equals), new[] { typeof(double) }); diff --git a/src/EFCore/Infrastructure/ModelValidator.cs b/src/EFCore/Infrastructure/ModelValidator.cs index 0b6f589827f..3c898a19d94 100644 --- a/src/EFCore/Infrastructure/ModelValidator.cs +++ b/src/EFCore/Infrastructure/ModelValidator.cs @@ -638,19 +638,21 @@ protected virtual void ValidateInheritanceMapping( /// The entity type to validate. protected virtual void ValidateDiscriminatorValues([NotNull] IEntityType rootEntityType) { - var discriminatorValues = new Dictionary(); var derivedTypes = rootEntityType.GetDerivedTypesInclusive().ToList(); if (derivedTypes.Count == 1) { return; } - if (rootEntityType.GetDiscriminatorProperty() == null) + var discriminatorProperty = rootEntityType.GetDiscriminatorProperty(); + if (discriminatorProperty == null) { throw new InvalidOperationException( CoreStrings.NoDiscriminatorProperty(rootEntityType.DisplayName())); } + var discriminatorValues = new Dictionary(discriminatorProperty.GetKeyValueComparer()); + foreach (var derivedType in derivedTypes) { if (derivedType.ClrType?.IsInstantiable() != true) diff --git a/src/EFCore/Query/EntityShaperExpression.cs b/src/EFCore/Query/EntityShaperExpression.cs index b2158ebd4bf..280db9e2661 100644 --- a/src/EFCore/Query/EntityShaperExpression.cs +++ b/src/EFCore/Query/EntityShaperExpression.cs @@ -7,9 +7,11 @@ using System.Linq.Expressions; using System.Reflection; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; @@ -123,7 +125,26 @@ protected virtual LambdaExpression GenerateMaterializationCondition([NotNull] IE Convert(discriminatorValueVariable, typeof(object)))), Constant(null, typeof(IEntityType))); - expressions.Add(Switch(discriminatorValueVariable, exception, switchCases)); + + var discriminatorComparer = discriminatorProperty.GetKeyValueComparer(); + if (discriminatorComparer.IsDefault()) + { + expressions.Add(Switch(discriminatorValueVariable, exception, switchCases)); + } + else + { + var staticComparer = typeof(StaticDiscriminatorComparer<,,>).MakeGenericType( + discriminatorProperty.DeclaringEntityType.ClrType, + discriminatorProperty.ClrType, + discriminatorProperty.GetTypeMapping().Converter.ProviderClrType); + + var comparerField = staticComparer.GetField(nameof(StaticDiscriminatorComparer.Comparer)); + comparerField.SetValue(null, discriminatorComparer); + + var equalsMethod = staticComparer.GetMethod(nameof(StaticDiscriminatorComparer.DiscriminatorEquals)); + expressions.Add(Switch(discriminatorValueVariable, exception, equalsMethod, switchCases)); + } + body = Block(new[] { discriminatorValueVariable }, expressions); } else diff --git a/src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs b/src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs new file mode 100644 index 00000000000..21c87e3b17c --- /dev/null +++ b/src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs @@ -0,0 +1,35 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.ChangeTracking; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + // ReSharper disable twice UnusedTypeParameter + public static class StaticDiscriminatorComparer + { + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public static ValueComparer Comparer; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public static bool DiscriminatorEquals([CanBeNull] TModel x, [CanBeNull] TModel y) + => Comparer.Equals(x, y); + } +} diff --git a/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBase.cs b/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBase.cs index 69c9399f743..abbffc620e6 100644 --- a/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBase.cs +++ b/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBase.cs @@ -185,8 +185,25 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con .HasForeignKey(e => e.BackId) .OnDelete(DeleteBehavior.SetNull); - modelBuilder.Entity(); - modelBuilder.Entity(); + modelBuilder.Entity( + b => + { + b.HasDiscriminator(e => e.Disc) + .HasValue(new MyDiscriminator(1)) + .HasValue(new MyDiscriminator(2)) + .HasValue(new MyDiscriminator(3)); + + b.Property(e => e.Disc) + .HasConversion( + v => v.Value, + v => new MyDiscriminator(v), + new ValueComparer( + (l, r) => l.Value == r.Value, + v => v.Value.GetHashCode(), + v => new MyDiscriminator(v.Value))) + .Metadata + .SetAfterSaveBehavior(PropertySaveBehavior.Save); + }); modelBuilder.Entity() .HasOne(e => e.Single) @@ -379,10 +396,6 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con modelBuilder.Entity() .HasIndex(e => e.BarCode) .IsUnique(); - - modelBuilder.Entity() - .Property("Discriminator") - .Metadata.SetAfterSaveBehavior(PropertySaveBehavior.Save); } protected virtual object CreateFullGraph() @@ -1692,6 +1705,7 @@ protected class OptionalSingle2 : NotifyingEntity { private int _id; private int? _backId; + private MyDiscriminator _disc; private OptionalSingle1 _back; public int Id @@ -1706,6 +1720,12 @@ public int? BackId set => SetWithNotify(value, ref _backId); } + public MyDiscriminator Disc + { + get => _disc; + set => SetWithNotify(value, ref _disc); + } + public OptionalSingle1 Back { get => _back; @@ -1722,6 +1742,20 @@ public override int GetHashCode() => _id; } + protected class MyDiscriminator + { + public MyDiscriminator(int value) + => Value = value; + + public int Value { get; } + + public override bool Equals(object obj) + => throw new InvalidOperationException(); + + public override int GetHashCode() + => throw new InvalidOperationException(); + } + protected class OptionalSingle2Derived : OptionalSingle2 { public override bool Equals(object obj) diff --git a/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBaseMiscellaneous.cs b/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBaseMiscellaneous.cs index 1cd603e26b6..2af5199e30e 100644 --- a/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBaseMiscellaneous.cs +++ b/test/EFCore.Specification.Tests/GraphUpdates/GraphUpdatesTestBaseMiscellaneous.cs @@ -45,23 +45,23 @@ public virtual void Mutating_discriminator_value_can_be_configured_to_allow_muta context => { var instance = context.Set().First(); - var propertyEntry = context.Entry(instance).Property("Discriminator"); + var propertyEntry = context.Entry(instance).Property(e => e.Disc); id = instance.Id; Assert.IsType(instance); - Assert.Equal(nameof(OptionalSingle2Derived), propertyEntry.CurrentValue); + Assert.Equal(2, propertyEntry.CurrentValue.Value); - propertyEntry.CurrentValue = nameof(OptionalSingle2); + propertyEntry.CurrentValue = new MyDiscriminator(1); context.SaveChanges(); }, context => { var instance = context.Set().First(e => e.Id == id); - var propertyEntry = context.Entry(instance).Property("Discriminator"); + var propertyEntry = context.Entry(instance).Property(e => e.Disc); Assert.IsType(instance); - Assert.Equal(nameof(OptionalSingle2), propertyEntry.CurrentValue); + Assert.Equal(1, propertyEntry.CurrentValue.Value); }); }