Skip to content

Commit

Permalink
Allow discriminator properties to value converter and comparer (#22425)
Browse files Browse the repository at this point in the history
Fixes #19650
  • Loading branch information
ajcvickers committed Sep 8, 2020
1 parent d55c772 commit 3f54048
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 20 deletions.
5 changes: 3 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,11 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
if (concreteEntityTypes.Count == 1)
{
var concreteEntityType = concreteEntityTypes[0];
if (concreteEntityType.GetDiscriminatorProperty() != null)
var discriminatorProperty = concreteEntityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.BindProperty(concreteEntityType.GetDiscriminatorProperty(), clientEval: false);
.BindProperty(discriminatorProperty, clientEval: false);

selectExpression.ApplyPredicate(
Equal((SqlExpression)discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -965,16 +965,17 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
{
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var equals = Expression.Equal(
var equals = valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ public InMemoryQueryExpression([NotNull] IEntityType entityType)
foreach (var derivedEntityType in entityType.GetDerivedTypes())
{
var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive()
.Select(e => Equal(readExpressionMap[discriminatorProperty], Constant(e.GetDiscriminatorValue())))
.Select(
e => discriminatorProperty.GetKeyValueComparer().ExtractEqualsBody(
readExpressionMap[discriminatorProperty],
Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType)))
.Aggregate((l, r) => OrElse(l, r));

foreach (var property in derivedEntityType.GetDeclaredProperties())
Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -27,7 +28,7 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking
/// reference.
/// </para>
/// </summary>
public abstract class ValueComparer : IEqualityComparer
public abstract class ValueComparer : IEqualityComparer, IEqualityComparer<object>
{
private protected static readonly MethodInfo _doubleEqualsMethodInfo
= typeof(double).GetRuntimeMethod(nameof(double.Equals), new[] { typeof(double) });
Expand Down
6 changes: 4 additions & 2 deletions src/EFCore/Infrastructure/ModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -638,19 +638,21 @@ protected virtual void ValidateInheritanceMapping(
/// <param name="rootEntityType"> The entity type to validate. </param>
protected virtual void ValidateDiscriminatorValues([NotNull] IEntityType rootEntityType)
{
var discriminatorValues = new Dictionary<object, IEntityType>();
var derivedTypes = rootEntityType.GetDerivedTypesInclusive().ToList();
if (derivedTypes.Count == 1)
{
return;
}

if (rootEntityType.GetDiscriminatorProperty() == null)
var discriminatorProperty = rootEntityType.GetDiscriminatorProperty();
if (discriminatorProperty == null)
{
throw new InvalidOperationException(
CoreStrings.NoDiscriminatorProperty(rootEntityType.DisplayName()));
}

var discriminatorValues = new Dictionary<object, IEntityType>(discriminatorProperty.GetKeyValueComparer());

foreach (var derivedType in derivedTypes)
{
if (derivedType.ClrType?.IsInstantiable() != true)
Expand Down
23 changes: 22 additions & 1 deletion src/EFCore/Query/EntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

Expand Down Expand Up @@ -123,7 +125,26 @@ protected virtual LambdaExpression GenerateMaterializationCondition([NotNull] IE
Convert(discriminatorValueVariable, typeof(object)))),
Constant(null, typeof(IEntityType)));

expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));

var discriminatorComparer = discriminatorProperty.GetKeyValueComparer();
if (discriminatorComparer.IsDefault())
{
expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));
}
else
{
var staticComparer = typeof(StaticDiscriminatorComparer<,,>).MakeGenericType(
discriminatorProperty.DeclaringEntityType.ClrType,
discriminatorProperty.ClrType,
discriminatorProperty.GetTypeMapping().Converter.ProviderClrType);

var comparerField = staticComparer.GetField(nameof(StaticDiscriminatorComparer<int, int, int>.Comparer));
comparerField.SetValue(null, discriminatorComparer);

var equalsMethod = staticComparer.GetMethod(nameof(StaticDiscriminatorComparer<int, int, int>.DiscriminatorEquals));
expressions.Add(Switch(discriminatorValueVariable, exception, equalsMethod, switchCases));
}

body = Block(new[] { discriminatorValueVariable }, expressions);
}
else
Expand Down
35 changes: 35 additions & 0 deletions src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// ReSharper disable twice UnusedTypeParameter
public static class StaticDiscriminatorComparer<TEntity, TModel, TProvider>
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static ValueComparer<TModel> Comparer;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static bool DiscriminatorEquals([CanBeNull] TModel x, [CanBeNull] TModel y)
=> Comparer.Equals(x, y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,25 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
.HasForeignKey<OptionalSingle2>(e => e.BackId)
.OnDelete(DeleteBehavior.SetNull);

modelBuilder.Entity<OptionalSingle2Derived>();
modelBuilder.Entity<OptionalSingle2MoreDerived>();
modelBuilder.Entity<OptionalSingle2>(
b =>
{
b.HasDiscriminator(e => e.Disc)
.HasValue<OptionalSingle2>(new MyDiscriminator(1))
.HasValue<OptionalSingle2Derived>(new MyDiscriminator(2))
.HasValue<OptionalSingle2MoreDerived>(new MyDiscriminator(3));
b.Property(e => e.Disc)
.HasConversion(
v => v.Value,
v => new MyDiscriminator(v),
new ValueComparer<MyDiscriminator>(
(l, r) => l.Value == r.Value,
v => v.Value.GetHashCode(),
v => new MyDiscriminator(v.Value)))
.Metadata
.SetAfterSaveBehavior(PropertySaveBehavior.Save);
});

modelBuilder.Entity<RequiredNonPkSingle1>()
.HasOne(e => e.Single)
Expand Down Expand Up @@ -379,10 +396,6 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
modelBuilder.Entity<Produce>()
.HasIndex(e => e.BarCode)
.IsUnique();

modelBuilder.Entity<OptionalSingle2Derived>()
.Property<string>("Discriminator")
.Metadata.SetAfterSaveBehavior(PropertySaveBehavior.Save);
}

protected virtual object CreateFullGraph()
Expand Down Expand Up @@ -1692,6 +1705,7 @@ protected class OptionalSingle2 : NotifyingEntity
{
private int _id;
private int? _backId;
private MyDiscriminator _disc;
private OptionalSingle1 _back;

public int Id
Expand All @@ -1706,6 +1720,12 @@ public int? BackId
set => SetWithNotify(value, ref _backId);
}

public MyDiscriminator Disc
{
get => _disc;
set => SetWithNotify(value, ref _disc);
}

public OptionalSingle1 Back
{
get => _back;
Expand All @@ -1722,6 +1742,20 @@ public override int GetHashCode()
=> _id;
}

protected class MyDiscriminator
{
public MyDiscriminator(int value)
=> Value = value;

public int Value { get; }

public override bool Equals(object obj)
=> throw new InvalidOperationException();

public override int GetHashCode()
=> throw new InvalidOperationException();
}

protected class OptionalSingle2Derived : OptionalSingle2
{
public override bool Equals(object obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ public virtual void Mutating_discriminator_value_can_be_configured_to_allow_muta
context =>
{
var instance = context.Set<OptionalSingle2Derived>().First();
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
id = instance.Id;
Assert.IsType<OptionalSingle2Derived>(instance);
Assert.Equal(nameof(OptionalSingle2Derived), propertyEntry.CurrentValue);
Assert.Equal(2, propertyEntry.CurrentValue.Value);
propertyEntry.CurrentValue = nameof(OptionalSingle2);
propertyEntry.CurrentValue = new MyDiscriminator(1);
context.SaveChanges();
},
context =>
{
var instance = context.Set<OptionalSingle2>().First(e => e.Id == id);
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
Assert.IsType<OptionalSingle2>(instance);
Assert.Equal(nameof(OptionalSingle2), propertyEntry.CurrentValue);
Assert.Equal(1, propertyEntry.CurrentValue.Value);
});
}

Expand Down

0 comments on commit 3f54048

Please sign in to comment.