Skip to content

Commit

Permalink
Use the correct shared foreign key for navigation fixup (#21834)
Browse files Browse the repository at this point in the history
Fixes #20719
  • Loading branch information
AndriySvyryd committed Jul 29, 2020
1 parent e57ec85 commit a0342cc
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 45 deletions.
15 changes: 14 additions & 1 deletion src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -883,13 +883,26 @@ protected virtual void ValidateSharedForeignKeysCompatibility(
StoreObjectIdentifier storeObject,
[NotNull] IDiagnosticsLogger<DbLoggerCategory.Model.Validation> logger)
{
if (storeObject.StoreObjectType != StoreObjectType.Table)
{
return;
}

var foreignKeyMappings = new Dictionary<string, IForeignKey>();

foreach (var foreignKey in mappedTypes.SelectMany(et => et.GetDeclaredForeignKeys()))
{
var principalTable = foreignKey.PrincipalEntityType.GetTableName();
var principalSchema = foreignKey.PrincipalEntityType.GetSchema();

if (principalTable == null)
{
continue;
}

var foreignKeyName = foreignKey.GetConstraintName(
storeObject,
StoreObjectIdentifier.Table(foreignKey.PrincipalEntityType.GetTableName(), foreignKey.PrincipalEntityType.GetSchema()));
StoreObjectIdentifier.Table(principalTable, principalSchema));
if (!foreignKeyMappings.TryGetValue(foreignKeyName, out var duplicateForeignKey))
{
foreignKeyMappings[foreignKeyName] = foreignKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,15 +365,17 @@ private void TryUniquifyForeignKeyNames(
{
foreach (var foreignKey in entityType.GetDeclaredForeignKeys())
{
if (foreignKey.DeclaringEntityType.GetTableName() == foreignKey.PrincipalEntityType.GetTableName()
&& foreignKey.DeclaringEntityType.GetSchema() == foreignKey.PrincipalEntityType.GetSchema())
var principalTable = foreignKey.PrincipalEntityType.GetTableName();
var principalSchema = foreignKey.PrincipalEntityType.GetSchema();
if (principalTable == null
|| (foreignKey.DeclaringEntityType.GetTableName() == principalTable
&& foreignKey.DeclaringEntityType.GetSchema() == principalSchema))
{
continue;
}

var foreignKeyName = foreignKey.GetConstraintName(storeObject,
StoreObjectIdentifier.Table(foreignKey.PrincipalEntityType.GetTableName(),
foreignKey.PrincipalEntityType.GetSchema()));
StoreObjectIdentifier.Table(principalTable, principalSchema));
if (!foreignKeys.TryGetValue(foreignKeyName, out var otherForeignKey))
{
foreignKeys[foreignKeyName] = foreignKey;
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/ChangeTracking/Internal/StateManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ public virtual IEnumerable<IUpdateEntry> GetDependents(
IUpdateEntry principalEntry, IForeignKey foreignKey)
{
var dependentIdentityMap = FindIdentityMap(foreignKey.DeclaringEntityType.FindPrimaryKey());
return dependentIdentityMap != null
return dependentIdentityMap != null && foreignKey.PrincipalEntityType.IsAssignableFrom(principalEntry.EntityType)
? dependentIdentityMap.GetDependentsMap(foreignKey).GetDependents(principalEntry)
: Enumerable.Empty<IUpdateEntry>();
}
Expand Down
1 change: 0 additions & 1 deletion src/EFCore/Metadata/Internal/ClrCollectionAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ public virtual object GetOrCreate(object entity, bool forMaterialization)
private ICollection<TElement> GetOrCreateCollection(object instance, bool forMaterialization)
{
var collection = GetCollection(instance);

if (collection == null)
{
var setCollection = forMaterialization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public TestSqlLoggerFactory(Func<string, bool> shouldLogCategory)
public IReadOnlyList<string> Parameters => ((TestSqlLogger)Logger).Parameters;
public string Sql => string.Join(_eol + _eol, SqlStatements);

public void AssertBaseline(string[] expected)
public void AssertBaseline(string[] expected, bool assertOrder = true)
{
if (_proceduralQueryGeneration)
{
Expand All @@ -44,12 +44,25 @@ public void AssertBaseline(string[] expected)

try
{
for (var i = 0; i < expected.Length; i++)
if (assertOrder)
{
Assert.Equal(expected[i], SqlStatements[i], ignoreLineEndingDifferences: true);
}
for (var i = 0; i < expected.Length; i++)
{
Assert.Equal(expected[i], SqlStatements[i], ignoreLineEndingDifferences: true);
}

Assert.Empty(SqlStatements.Skip(expected.Length));
Assert.Empty(SqlStatements.Skip(expected.Length));
}
else
{
foreach (var expectedFragment in expected)
{
var normalizedExpectedFragment = expectedFragment.Replace("\r", string.Empty).Replace("\n", _eol);
Assert.Contains(
normalizedExpectedFragment,
SqlStatements);
}
}
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;

namespace Microsoft.EntityFrameworkCore.TestModels.UpdatesModel
{
public class Category
{
public int Id { get; set; }
public int? PrincipalId { get; set; }
public string Name { get; set; }
public ICollection<ProductCategory> ProductCategories { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;

namespace Microsoft.EntityFrameworkCore.TestModels.UpdatesModel
{
public class Product
public class Product : ProductBase
{
public Guid Id { get; set; }
public int? DependentId { get; set; }
public string Name { get; set; }

[ConcurrencyCheck]
public decimal Price { get; set; }

public ICollection<ProductCategory> ProductCategories { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;

namespace Microsoft.EntityFrameworkCore.TestModels.UpdatesModel
{
public abstract class ProductBase
{
public Guid Id { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;

namespace Microsoft.EntityFrameworkCore.TestModels.UpdatesModel
{
public class ProductCategory
{
public int CategoryId { get; set; }
public Guid ProductId { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;

namespace Microsoft.EntityFrameworkCore.TestModels.UpdatesModel
{
public class ProductWithBytes
public class ProductWithBytes : ProductBase
{
public Guid Id { get; set; }
public string Name { get; set; }

[ConcurrencyCheck]
public byte[] Bytes { get; set; }

public ICollection<ProductCategory> ProductCategories { get; set; }
}
}
17 changes: 10 additions & 7 deletions test/EFCore.Specification.Tests/UpdatesFixtureBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@ public abstract class UpdatesFixtureBase : SharedStoreFixtureBase<UpdatesContext

protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context)
{
modelBuilder.Entity<Product>().HasMany(e => e.ProductCategories).WithOne()
.HasForeignKey(e => e.ProductId);
modelBuilder.Entity<ProductWithBytes>().HasMany(e => e.ProductCategories).WithOne()
.HasForeignKey(e => e.ProductId);

modelBuilder.Entity<ProductCategory>()
.HasKey(p => new { p.CategoryId, p.ProductId });

modelBuilder.Entity<Product>().HasOne<Category>().WithMany()
.HasForeignKey(e => e.DependentId)
.HasPrincipalKey(e => e.PrincipalId);

modelBuilder.Entity<Product>()
.Property(e => e.Id)
.ValueGeneratedNever();

modelBuilder.Entity<Category>()
.Property(e => e.Id)
.ValueGeneratedNever();

modelBuilder.Entity<ProductWithBytes>()
.Property(e => e.Id)
.ValueGeneratedNever();
modelBuilder.Entity<Category>().HasMany(e => e.ProductCategories).WithOne()
.HasForeignKey(e => e.CategoryId);

modelBuilder.Entity<AFewBytes>()
.Property(e => e.Id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
base.OnModelCreating(modelBuilder, context);

modelBuilder.Entity<ProductBase>()
.Property(p => p.Id).HasDefaultValueSql("NEWID()");

modelBuilder.Entity<Product>()
.Property(p => p.Price).HasColumnType("decimal(18,2)");

Expand Down
85 changes: 65 additions & 20 deletions test/EFCore.SqlServer.FunctionalTests/UpdatesSqlServerTest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;
using Microsoft.EntityFrameworkCore.TestModels.UpdatesModel;
using Xunit;
Expand All @@ -18,36 +19,77 @@ public UpdatesSqlServerTest(UpdatesSqlServerFixture fixture, ITestOutputHelper t
Fixture.TestSqlLoggerFactory.Clear();
}

public override void Save_replaced_principal()
[ConditionalFact]
public virtual void Save_with_shared_foreign_key()
{
base.Save_replaced_principal();
ExecuteWithStrategyInTransaction(
context =>
{
context.AddRange(
new ProductWithBytes
{
ProductCategories = new List<ProductCategory>
{
new ProductCategory { CategoryId = 77 }
}
},
new Category { Id = 77, PrincipalId = 777 });
AssertSql(
@"SELECT TOP(2) [c].[Id], [c].[Name], [c].[PrincipalId]
FROM [Categories] AS [c]",
//
@"@__category_PrincipalId_0='778' (Nullable = true)
context.SaveChanges();
},
context =>
{
var product = context.Set<ProductBase>()
.Include(p => ((ProductWithBytes)p).ProductCategories)
.Include(p => ((Product)p).ProductCategories)
.OfType<ProductWithBytes>()
.Single();
var productCategory = product.ProductCategories.Single();
Assert.Equal(productCategory.CategoryId, context.Set<ProductCategory>().Single().CategoryId);
Assert.Equal(productCategory.CategoryId, context.Set<Category>().Single(c => c.PrincipalId == 777).Id);
});

SELECT [p].[Id], [p].[DependentId], [p].[Name], [p].[Price]
FROM [Products] AS [p]
WHERE [p].[DependentId] = @__category_PrincipalId_0",
AssertContainsSql(
@"@p0='77'
@p1=NULL (Size = 4000)
@p2='777'
SET NOCOUNT ON;
INSERT INTO [Categories] ([Id], [Name], [PrincipalId])
VALUES (@p0, @p1, @p2);",
//
@"@p0='ProductWithBytes' (Nullable = false) (Size = 4000)
@p1=NULL (Size = 8000) (DbType = Binary)
@p2=NULL (Size = 4000)
SET NOCOUNT ON;
DECLARE @inserted0 TABLE ([Id] uniqueidentifier, [_Position] [int]);
MERGE [ProductBase] USING (
VALUES (@p0, @p1, @p2, 0)) AS i ([Discriminator], [Bytes], [ProductWithBytes_Name], _Position) ON 1=0
WHEN NOT MATCHED THEN
INSERT ([Discriminator], [Bytes], [ProductWithBytes_Name])
VALUES (i.[Discriminator], i.[Bytes], i.[ProductWithBytes_Name])
OUTPUT INSERTED.[Id], i._Position
INTO @inserted0;
SELECT [t].[Id] FROM [ProductBase] t
INNER JOIN @inserted0 i ON ([t].[Id] = [i].[Id])
ORDER BY [i].[_Position];");

}

public override void Save_replaced_principal()
{
base.Save_replaced_principal();

AssertContainsSql(
@"@p1='78'
@p0='New Category' (Size = 4000)
SET NOCOUNT ON;
UPDATE [Categories] SET [Name] = @p0
WHERE [Id] = @p1;
SELECT @@ROWCOUNT;",
//
@"SELECT TOP(2) [c].[Id], [c].[Name], [c].[PrincipalId]
FROM [Categories] AS [c]",
//
@"@__category_PrincipalId_0='778' (Nullable = true)
SELECT [p].[Id], [p].[DependentId], [p].[Name], [p].[Price]
FROM [Products] AS [p]
WHERE [p].[DependentId] = @__category_PrincipalId_0");
SELECT @@ROWCOUNT;");
}

public override void Identifiers_are_generated_correctly()
Expand Down Expand Up @@ -94,5 +136,8 @@ public override void Identifiers_are_generated_correctly()

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

protected void AssertContainsSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected, assertOrder: false);
}
}

0 comments on commit a0342cc

Please sign in to comment.