Skip to content

Commit

Permalink
Check constraints for discriminators
Browse files Browse the repository at this point in the history
Closes efcore#2
  • Loading branch information
roji committed Sep 13, 2020
1 parent 811cf9b commit d331ba6
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 15 deletions.
54 changes: 54 additions & 0 deletions EFCore.CheckConstraints.Test/DiscriminatorCheckConstraintTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using EFCore.CheckConstraints.Internal;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace EFCore.CheckConstraints.Test
{
public class DiscriminatorCheckConstraintTest
{
[Fact]
public void Generate_check_constraint_with_all_enum_names()
{
var builder = CreateBuilder();
builder.Entity<Parent>();
builder.Entity<Child>();

var model = builder.FinalizeModel();

var checkConstraint = Assert.Single(model.FindEntityType(typeof(Parent)).GetCheckConstraints());
Assert.NotNull(checkConstraint);
Assert.Equal("CK_Parent_Discriminator_Constraint", checkConstraint.Name);
Assert.Equal("[Discriminator] IN (N'Child', N'Parent')", checkConstraint.Sql);
}

class Parent
{
public int Id { get; set; }
public string Discriminator { get; set; }
}

class Child : Parent
{
public string ChildProperty { get; set; }
}

private ModelBuilder CreateBuilder()
{
var conventionSet = SqlServerTestHelpers.Instance.CreateContextServices()
.GetRequiredService<IConventionSetBuilder>()
.CreateConventionSet();

conventionSet.ModelFinalizingConventions.Add(
new DiscriminatorCheckConstraintConvention(
new SqlServerSqlGenerationHelper(
new RelationalSqlGenerationHelperDependencies())));

return new ModelBuilder(conventionSet);
}
}
}
7 changes: 6 additions & 1 deletion EFCore.CheckConstraints.Test/EnumCheckConstraintTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using EFCore.CheckConstraints.Internal;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
Expand Down Expand Up @@ -144,7 +146,10 @@ private ModelBuilder CreateBuilder()
.GetRequiredService<IConventionSetBuilder>()
.CreateConventionSet();

conventionSet.ModelFinalizingConventions.Add(new EnumCheckConstraintConvention());
conventionSet.ModelFinalizingConventions.Add(
new EnumCheckConstraintConvention(
new SqlServerSqlGenerationHelper(
new RelationalSqlGenerationHelperDependencies())));

return new ModelBuilder(conventionSet);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,35 @@
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Storage;

namespace EFCore.CheckConstraints.Internal
{
public class CheckConstraintsConventionSetPlugin : IConventionSetPlugin
{
readonly IDbContextOptions _options;
public CheckConstraintsConventionSetPlugin([NotNull] IDbContextOptions options) => _options = options;
readonly ISqlGenerationHelper _sqlGenerationHelper;

public CheckConstraintsConventionSetPlugin(
[NotNull] IDbContextOptions options,
ISqlGenerationHelper sqlGenerationHelper)
{
_options = options;
_sqlGenerationHelper = sqlGenerationHelper;
}

public ConventionSet ModifyConventions(ConventionSet conventionSet)
{
var extension = _options.FindExtension<CheckConstraintsOptionsExtension>();

if (extension.AreEnumCheckConstraintsEnabled)
{
conventionSet.ModelFinalizingConventions.Add(new EnumCheckConstraintConvention());
conventionSet.ModelFinalizingConventions.Add(new EnumCheckConstraintConvention(_sqlGenerationHelper));
}

if (extension.AreDiscriminatorCheckConstraintsEnabled)
{
conventionSet.ModelFinalizingConventions.Add(new DiscriminatorCheckConstraintConvention(_sqlGenerationHelper));
}

return conventionSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,22 @@ public class CheckConstraintsOptionsExtension : IDbContextOptionsExtension
{
private DbContextOptionsExtensionInfo _info;
private bool _enumCheckConstraintsEnabled;
private bool _discriminatorCheckConstraintsEnabled;

public CheckConstraintsOptionsExtension() {}

protected CheckConstraintsOptionsExtension([NotNull] CheckConstraintsOptionsExtension copyFrom)
=> _enumCheckConstraintsEnabled = copyFrom._enumCheckConstraintsEnabled;
{
_enumCheckConstraintsEnabled = copyFrom._enumCheckConstraintsEnabled;
_discriminatorCheckConstraintsEnabled = copyFrom._discriminatorCheckConstraintsEnabled;
}

public virtual DbContextOptionsExtensionInfo Info => _info ??= new ExtensionInfo(this);

protected virtual CheckConstraintsOptionsExtension Clone() => new CheckConstraintsOptionsExtension(this);

public virtual bool AreEnumCheckConstraintsEnabled => _enumCheckConstraintsEnabled;
public virtual bool AreDiscriminatorCheckConstraintsEnabled => _discriminatorCheckConstraintsEnabled;

public virtual CheckConstraintsOptionsExtension WithEnumCheckConstraintsEnabled(bool enumCheckConstraintsEnabled)
{
Expand All @@ -30,6 +35,14 @@ public virtual CheckConstraintsOptionsExtension WithEnumCheckConstraintsEnabled(
return clone;
}

public virtual CheckConstraintsOptionsExtension WithDiscriminatorCheckConstraintsEnabled(
bool discriminatorCheckConstraintsEnabled)
{
var clone = Clone();
clone._discriminatorCheckConstraintsEnabled = discriminatorCheckConstraintsEnabled;
return clone;
}

public void Validate(IDbContextOptions options) {}

public void ApplyServices(IServiceCollection services)
Expand All @@ -52,16 +65,24 @@ public override string LogFragment
{
if (_logFragment == null)
{
var builder = new StringBuilder("using check constraints");
var builder = new StringBuilder("using check constraints (");
var isFirst = true;

if (Extension._enumCheckConstraintsEnabled)
if (Extension.AreEnumCheckConstraintsEnabled)
{
builder
.Append(" (")
.Append("enums")
.Append(")");
builder.Append("enums");
isFirst = false;
}

if (Extension.AreDiscriminatorCheckConstraintsEnabled)
{
if (!isFirst)
builder.Append(", ");
builder.Append("discriminators");
}

builder.Append(')');

_logFragment = builder.ToString();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System.Linq;
using System.Text;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Storage;

namespace EFCore.CheckConstraints.Internal
{
/// <summary>
/// A convention that creates check constraints ensuring that (complete) discriminator columns only have
/// expected values.
/// </summary>
public class DiscriminatorCheckConstraintConvention : IModelFinalizingConvention
{
readonly ISqlGenerationHelper _sqlGenerationHelper;

public DiscriminatorCheckConstraintConvention(ISqlGenerationHelper sqlGenerationHelper)
=> _sqlGenerationHelper = sqlGenerationHelper;

/// <inheritdoc />
public virtual void ProcessModelFinalizing(IConventionModelBuilder modelBuilder, IConventionContext<IConventionModelBuilder> context)
{
var sql = new StringBuilder();

foreach (var (rootEntityType, discriminatorValues) in modelBuilder.Metadata
.GetEntityTypes()
.GroupBy(e => e.GetRootType())
.Where(g => g.Key.GetDiscriminatorProperty() != null && g.Key.GetIsDiscriminatorMappingComplete())
.Select(g => (g.Key, g.Select(e => e.GetDiscriminatorValue()))))
{
var discriminatorProperty = rootEntityType.GetDiscriminatorProperty();
sql.Clear();

sql.Append(_sqlGenerationHelper.DelimitIdentifier(discriminatorProperty.GetColumnName()));
sql.Append(" IN (");
foreach (var discriminatorValue in discriminatorValues)
{
var value = ((RelationalTypeMapping)discriminatorProperty.FindTypeMapping())
.GenerateSqlLiteral(discriminatorValue);
sql.Append($"{value}, ");
}

sql.Remove(sql.Length - 2, 2);
sql.Append(")");

var constraintName = $"CK_{rootEntityType.GetTableName()}_Discriminator_Constraint";
rootEntityType.AddCheckConstraint(constraintName, sql.ToString());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
namespace EFCore.CheckConstraints.Internal
{
/// <summary>
/// A convention that creates check constraint for Enum column in a model.
/// A convention that creates check constraints for enum columns.
/// </summary>
public class EnumCheckConstraintConvention : IModelFinalizingConvention
{
readonly ISqlGenerationHelper _sqlGenerationHelper;

public EnumCheckConstraintConvention(ISqlGenerationHelper sqlGenerationHelper)
=> _sqlGenerationHelper = sqlGenerationHelper;

/// <inheritdoc />
public virtual void ProcessModelFinalizing(IConventionModelBuilder modelBuilder, IConventionContext<IConventionModelBuilder> context)
{
Expand All @@ -36,9 +41,8 @@ public virtual void ProcessModelFinalizing(IConventionModelBuilder modelBuilder,

sql.Clear();

sql.Append("[");
sql.Append(property.GetColumnName());
sql.Append("] IN ("); ;
sql.Append(_sqlGenerationHelper.DelimitIdentifier(property.GetColumnName()));
sql.Append(" IN (");
foreach (var item in enumValues)
{
var value = ((RelationalTypeMapping)typeMapping).GenerateSqlLiteral(item);
Expand Down
17 changes: 16 additions & 1 deletion EFCore.CheckConstraints/CheckConstraintsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,24 @@ public static DbContextOptionsBuilder UseEnumCheckConstraints(
return optionsBuilder;
}

public static DbContextOptionsBuilder UseDiscriminatorCheckConstraints(
[NotNull] this DbContextOptionsBuilder optionsBuilder , CultureInfo culture = null)
{
Check.NotNull(optionsBuilder, nameof(optionsBuilder));

var extension = (optionsBuilder.Options.FindExtension<CheckConstraintsOptionsExtension>() ?? new CheckConstraintsOptionsExtension())
.WithDiscriminatorCheckConstraintsEnabled(true);

((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension);

return optionsBuilder;
}

public static DbContextOptionsBuilder UseAllCheckConstraints(
[NotNull] this DbContextOptionsBuilder optionsBuilder, CultureInfo culture = null)
=> optionsBuilder.UseEnumCheckConstraints();
=> optionsBuilder
.UseEnumCheckConstraints()
.UseDiscriminatorCheckConstraints();

public static DbContextOptionsBuilder<TContext> UseEnumCheckConstraints<TContext>([NotNull] this DbContextOptionsBuilder<TContext> optionsBuilder , CultureInfo culture = null)
where TContext : DbContext
Expand Down

0 comments on commit d331ba6

Please sign in to comment.