Skip to content

Commit

Permalink
Allow discriminator properties to value converter and comparer
Browse files Browse the repository at this point in the history
Fixes #19650
  • Loading branch information
ajcvickers committed Sep 7, 2020
1 parent 26c8e1c commit 4195b9d
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 20 deletions.
5 changes: 3 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,11 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
if (concreteEntityTypes.Count == 1)
{
var concreteEntityType = concreteEntityTypes[0];
if (concreteEntityType.GetDiscriminatorProperty() != null)
var discriminatorProperty = concreteEntityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.BindProperty(concreteEntityType.GetDiscriminatorProperty(), clientEval: false);
.BindProperty(discriminatorProperty, clientEval: false);

selectExpression.ApplyPredicate(
Equal((SqlExpression)discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -965,16 +965,17 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
{
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var equals = Expression.Equal(
var equals = valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ public InMemoryQueryExpression([NotNull] IEntityType entityType)
foreach (var derivedEntityType in entityType.GetDerivedTypes())
{
var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive()
.Select(e => Equal(readExpressionMap[discriminatorProperty], Constant(e.GetDiscriminatorValue())))
.Select(
e => discriminatorProperty.GetKeyValueComparer().ExtractEqualsBody(
readExpressionMap[discriminatorProperty],
Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType)))
.Aggregate((l, r) => OrElse(l, r));

foreach (var property in derivedEntityType.GetDeclaredProperties())
Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -27,7 +28,7 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking
/// reference.
/// </para>
/// </summary>
public abstract class ValueComparer : IEqualityComparer
public abstract class ValueComparer : IEqualityComparer, IEqualityComparer<object>
{
private protected static readonly MethodInfo _doubleEqualsMethodInfo
= typeof(double).GetRuntimeMethod(nameof(double.Equals), new[] { typeof(double) });
Expand Down
6 changes: 4 additions & 2 deletions src/EFCore/Infrastructure/ModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -638,19 +638,21 @@ protected virtual void ValidateInheritanceMapping(
/// <param name="rootEntityType"> The entity type to validate. </param>
protected virtual void ValidateDiscriminatorValues([NotNull] IEntityType rootEntityType)
{
var discriminatorValues = new Dictionary<object, IEntityType>();
var derivedTypes = rootEntityType.GetDerivedTypesInclusive().ToList();
if (derivedTypes.Count == 1)
{
return;
}

if (rootEntityType.GetDiscriminatorProperty() == null)
var discriminatorProperty = rootEntityType.GetDiscriminatorProperty();
if (discriminatorProperty == null)
{
throw new InvalidOperationException(
CoreStrings.NoDiscriminatorProperty(rootEntityType.DisplayName()));
}

var discriminatorValues = new Dictionary<object, IEntityType>(discriminatorProperty.GetKeyValueComparer());

foreach (var derivedType in derivedTypes)
{
if (derivedType.ClrType?.IsInstantiable() != true)
Expand Down
23 changes: 22 additions & 1 deletion src/EFCore/Query/EntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

Expand Down Expand Up @@ -123,7 +125,26 @@ protected virtual LambdaExpression GenerateMaterializationCondition([NotNull] IE
Convert(discriminatorValueVariable, typeof(object)))),
Constant(null, typeof(IEntityType)));

expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));

var discriminatorComparer = discriminatorProperty.GetKeyValueComparer();
if (discriminatorComparer.IsDefault())
{
expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));
}
else
{
var staticComparer = typeof(StaticDiscriminatorComparer<,,>).MakeGenericType(
discriminatorProperty.DeclaringEntityType.ClrType,
discriminatorProperty.ClrType,
discriminatorProperty.GetTypeMapping().Converter.ProviderClrType);

var comparerField = staticComparer.GetField(nameof(StaticDiscriminatorComparer<int, int, int>.Comparer));
comparerField.SetValue(null, discriminatorComparer);

This comment has been minimized.

Copy link
@AndriySvyryd

AndriySvyryd Sep 8, 2020

Member

This won't work. You are setting it during compilation, so the next query using a different comparer will override it.

You need to create an expression that does this before the switch.

This comment has been minimized.

Copy link
@ajcvickers

ajcvickers Sep 8, 2020

Author Member

It's a generic type, so the only time this would clash if if you have the same entity type with two different comparers for the same model and provider type. We could try to error out in the case, although it's difficult to detect. Otherwise we will need to not generate a switch statement for the case where a customer comparer is being used because switch requires a static method for this.

This comment has been minimized.

Copy link
@AndriySvyryd

AndriySvyryd Sep 8, 2020

Member

I would much rather use if statements and not have mutable global state

This comment has been minimized.

Copy link
@ajcvickers

ajcvickers Sep 8, 2020

Author Member

Andriy keeping me honest since...forever! ;-)


var equalsMethod = staticComparer.GetMethod(nameof(StaticDiscriminatorComparer<int, int, int>.DiscriminatorEquals));
expressions.Add(Switch(discriminatorValueVariable, exception, equalsMethod, switchCases));
}

body = Block(new[] { discriminatorValueVariable }, expressions);
}
else
Expand Down
35 changes: 35 additions & 0 deletions src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// ReSharper disable twice UnusedTypeParameter
public static class StaticDiscriminatorComparer<TEntity, TModel, TProvider>
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static ValueComparer<TModel> Comparer;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static bool DiscriminatorEquals([CanBeNull] TModel x, [CanBeNull] TModel y)
=> Comparer.Equals(x, y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,25 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
.HasForeignKey<OptionalSingle2>(e => e.BackId)
.OnDelete(DeleteBehavior.SetNull);

modelBuilder.Entity<OptionalSingle2Derived>();
modelBuilder.Entity<OptionalSingle2MoreDerived>();
modelBuilder.Entity<OptionalSingle2>(
b =>
{
b.HasDiscriminator(e => e.Disc)
.HasValue<OptionalSingle2>(new MyDiscriminator(1))
.HasValue<OptionalSingle2Derived>(new MyDiscriminator(2))
.HasValue<OptionalSingle2MoreDerived>(new MyDiscriminator(3));
b.Property(e => e.Disc)
.HasConversion(
v => v.Value,
v => new MyDiscriminator(v),
new ValueComparer<MyDiscriminator>(
(l, r) => l.Value == r.Value,
v => v.Value.GetHashCode(),
v => new MyDiscriminator(v.Value)))
.Metadata
.SetAfterSaveBehavior(PropertySaveBehavior.Save);
});

modelBuilder.Entity<RequiredNonPkSingle1>()
.HasOne(e => e.Single)
Expand Down Expand Up @@ -379,10 +396,6 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
modelBuilder.Entity<Produce>()
.HasIndex(e => e.BarCode)
.IsUnique();

modelBuilder.Entity<OptionalSingle2Derived>()
.Property<string>("Discriminator")
.Metadata.SetAfterSaveBehavior(PropertySaveBehavior.Save);
}

protected virtual object CreateFullGraph()
Expand Down Expand Up @@ -1692,6 +1705,7 @@ protected class OptionalSingle2 : NotifyingEntity
{
private int _id;
private int? _backId;
private MyDiscriminator _disc;
private OptionalSingle1 _back;

public int Id
Expand All @@ -1706,6 +1720,12 @@ public int? BackId
set => SetWithNotify(value, ref _backId);
}

public MyDiscriminator Disc
{
get => _disc;
set => SetWithNotify(value, ref _disc);
}

public OptionalSingle1 Back
{
get => _back;
Expand All @@ -1722,6 +1742,20 @@ public override int GetHashCode()
=> _id;
}

protected class MyDiscriminator
{
public MyDiscriminator(int value)
=> Value = value;

public int Value { get; }

public override bool Equals(object obj)
=> throw new InvalidOperationException();

public override int GetHashCode()
=> throw new InvalidOperationException();
}

protected class OptionalSingle2Derived : OptionalSingle2
{
public override bool Equals(object obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ public virtual void Mutating_discriminator_value_can_be_configured_to_allow_muta
context =>
{
var instance = context.Set<OptionalSingle2Derived>().First();
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
id = instance.Id;
Assert.IsType<OptionalSingle2Derived>(instance);
Assert.Equal(nameof(OptionalSingle2Derived), propertyEntry.CurrentValue);
Assert.Equal(2, propertyEntry.CurrentValue.Value);
propertyEntry.CurrentValue = nameof(OptionalSingle2);
propertyEntry.CurrentValue = new MyDiscriminator(1);
context.SaveChanges();
},
context =>
{
var instance = context.Set<OptionalSingle2>().First(e => e.Id == id);
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
Assert.IsType<OptionalSingle2>(instance);
Assert.Equal(nameof(OptionalSingle2), propertyEntry.CurrentValue);
Assert.Equal(1, propertyEntry.CurrentValue.Value);
});
}

Expand Down

0 comments on commit 4195b9d

Please sign in to comment.