Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow discriminator properties to have value converter and comparer #22425

Merged
merged 1 commit into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,11 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
if (concreteEntityTypes.Count == 1)
{
var concreteEntityType = concreteEntityTypes[0];
if (concreteEntityType.GetDiscriminatorProperty() != null)
var discriminatorProperty = concreteEntityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.BindProperty(concreteEntityType.GetDiscriminatorProperty(), clientEval: false);
.BindProperty(discriminatorProperty, clientEval: false);

selectExpression.ApplyPredicate(
Equal((SqlExpression)discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -965,16 +965,17 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
{
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var equals = Expression.Equal(
var equals = valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ public InMemoryQueryExpression([NotNull] IEntityType entityType)
foreach (var derivedEntityType in entityType.GetDerivedTypes())
{
var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive()
.Select(e => Equal(readExpressionMap[discriminatorProperty], Constant(e.GetDiscriminatorValue())))
.Select(
e => discriminatorProperty.GetKeyValueComparer().ExtractEqualsBody(
readExpressionMap[discriminatorProperty],
Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType)))
.Aggregate((l, r) => OrElse(l, r));

foreach (var property in derivedEntityType.GetDeclaredProperties())
Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -27,7 +28,7 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking
/// reference.
/// </para>
/// </summary>
public abstract class ValueComparer : IEqualityComparer
public abstract class ValueComparer : IEqualityComparer, IEqualityComparer<object>
{
private protected static readonly MethodInfo _doubleEqualsMethodInfo
= typeof(double).GetRuntimeMethod(nameof(double.Equals), new[] { typeof(double) });
Expand Down
6 changes: 4 additions & 2 deletions src/EFCore/Infrastructure/ModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -638,19 +638,21 @@ protected virtual void ValidateInheritanceMapping(
/// <param name="rootEntityType"> The entity type to validate. </param>
protected virtual void ValidateDiscriminatorValues([NotNull] IEntityType rootEntityType)
{
var discriminatorValues = new Dictionary<object, IEntityType>();
var derivedTypes = rootEntityType.GetDerivedTypesInclusive().ToList();
if (derivedTypes.Count == 1)
{
return;
}

if (rootEntityType.GetDiscriminatorProperty() == null)
var discriminatorProperty = rootEntityType.GetDiscriminatorProperty();
if (discriminatorProperty == null)
{
throw new InvalidOperationException(
CoreStrings.NoDiscriminatorProperty(rootEntityType.DisplayName()));
}

var discriminatorValues = new Dictionary<object, IEntityType>(discriminatorProperty.GetKeyValueComparer());

foreach (var derivedType in derivedTypes)
{
if (derivedType.ClrType?.IsInstantiable() != true)
Expand Down
23 changes: 22 additions & 1 deletion src/EFCore/Query/EntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

Expand Down Expand Up @@ -123,7 +125,26 @@ protected virtual LambdaExpression GenerateMaterializationCondition([NotNull] IE
Convert(discriminatorValueVariable, typeof(object)))),
Constant(null, typeof(IEntityType)));

expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));

var discriminatorComparer = discriminatorProperty.GetKeyValueComparer();
if (discriminatorComparer.IsDefault())
{
expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));
}
else
{
var staticComparer = typeof(StaticDiscriminatorComparer<,,>).MakeGenericType(
discriminatorProperty.DeclaringEntityType.ClrType,
discriminatorProperty.ClrType,
discriminatorProperty.GetTypeMapping().Converter.ProviderClrType);

var comparerField = staticComparer.GetField(nameof(StaticDiscriminatorComparer<int, int, int>.Comparer));
comparerField.SetValue(null, discriminatorComparer);

var equalsMethod = staticComparer.GetMethod(nameof(StaticDiscriminatorComparer<int, int, int>.DiscriminatorEquals));
expressions.Add(Switch(discriminatorValueVariable, exception, equalsMethod, switchCases));
}

body = Block(new[] { discriminatorValueVariable }, expressions);
}
else
Expand Down
35 changes: 35 additions & 0 deletions src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// ReSharper disable twice UnusedTypeParameter
public static class StaticDiscriminatorComparer<TEntity, TModel, TProvider>
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static ValueComparer<TModel> Comparer;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static bool DiscriminatorEquals([CanBeNull] TModel x, [CanBeNull] TModel y)
=> Comparer.Equals(x, y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,25 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
.HasForeignKey<OptionalSingle2>(e => e.BackId)
.OnDelete(DeleteBehavior.SetNull);

modelBuilder.Entity<OptionalSingle2Derived>();
modelBuilder.Entity<OptionalSingle2MoreDerived>();
modelBuilder.Entity<OptionalSingle2>(
b =>
{
b.HasDiscriminator(e => e.Disc)
.HasValue<OptionalSingle2>(new MyDiscriminator(1))
.HasValue<OptionalSingle2Derived>(new MyDiscriminator(2))
.HasValue<OptionalSingle2MoreDerived>(new MyDiscriminator(3));

b.Property(e => e.Disc)
.HasConversion(
v => v.Value,
v => new MyDiscriminator(v),
new ValueComparer<MyDiscriminator>(
(l, r) => l.Value == r.Value,
v => v.Value.GetHashCode(),
v => new MyDiscriminator(v.Value)))
.Metadata
.SetAfterSaveBehavior(PropertySaveBehavior.Save);
});

modelBuilder.Entity<RequiredNonPkSingle1>()
.HasOne(e => e.Single)
Expand Down Expand Up @@ -379,10 +396,6 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
modelBuilder.Entity<Produce>()
.HasIndex(e => e.BarCode)
.IsUnique();

modelBuilder.Entity<OptionalSingle2Derived>()
.Property<string>("Discriminator")
.Metadata.SetAfterSaveBehavior(PropertySaveBehavior.Save);
}

protected virtual object CreateFullGraph()
Expand Down Expand Up @@ -1692,6 +1705,7 @@ protected class OptionalSingle2 : NotifyingEntity
{
private int _id;
private int? _backId;
private MyDiscriminator _disc;
private OptionalSingle1 _back;

public int Id
Expand All @@ -1706,6 +1720,12 @@ public int? BackId
set => SetWithNotify(value, ref _backId);
}

public MyDiscriminator Disc
{
get => _disc;
set => SetWithNotify(value, ref _disc);
}

public OptionalSingle1 Back
{
get => _back;
Expand All @@ -1722,6 +1742,20 @@ public override int GetHashCode()
=> _id;
}

protected class MyDiscriminator
{
public MyDiscriminator(int value)
=> Value = value;

public int Value { get; }

public override bool Equals(object obj)
=> throw new InvalidOperationException();

public override int GetHashCode()
=> throw new InvalidOperationException();
}

protected class OptionalSingle2Derived : OptionalSingle2
{
public override bool Equals(object obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ public virtual void Mutating_discriminator_value_can_be_configured_to_allow_muta
context =>
{
var instance = context.Set<OptionalSingle2Derived>().First();
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
id = instance.Id;

Assert.IsType<OptionalSingle2Derived>(instance);
Assert.Equal(nameof(OptionalSingle2Derived), propertyEntry.CurrentValue);
Assert.Equal(2, propertyEntry.CurrentValue.Value);

propertyEntry.CurrentValue = nameof(OptionalSingle2);
propertyEntry.CurrentValue = new MyDiscriminator(1);

context.SaveChanges();
},
context =>
{
var instance = context.Set<OptionalSingle2>().First(e => e.Id == id);
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);

Assert.IsType<OptionalSingle2>(instance);
Assert.Equal(nameof(OptionalSingle2), propertyEntry.CurrentValue);
Assert.Equal(1, propertyEntry.CurrentValue.Value);
});
}

Expand Down