From 0f71d5f19c1a7fd255c7a3b3678ab1291c650029 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Fri, 31 Jul 2020 12:16:55 -0700 Subject: [PATCH] Query: Correctly specify types to match constraints in materialization functions Resolves #21803 --- ....CustomShaperCompilingExpressionVisitor.cs | 8 +- ...sitor.ShaperProcessingExpressionVisitor.cs | 48 +++++---- .../Query/QueryBugsInMemoryTest.cs | 55 +++++++++++ .../Query/QueryBugsTest.cs | 98 ++++++++++++++++++- 4 files changed, 185 insertions(+), 24 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs index 66f82e2ec02..1bab63556f8 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs @@ -196,17 +196,17 @@ protected override Expression VisitExtension(Expression extensionExpression) if (extensionExpression is CollectionShaperExpression collectionShaperExpression) { + var navigation = collectionShaperExpression.Navigation; + var collectionAccessor = navigation?.GetCollectionAccessor(); + var collectionType = collectionAccessor?.CollectionType ?? collectionShaperExpression.Type; var elementType = collectionShaperExpression.ElementType; - var collectionType = collectionShaperExpression.Type; return Expression.Call( _materializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType), QueryCompilationContext.QueryContextParameter, collectionShaperExpression.Projection, Expression.Constant(((LambdaExpression)Visit(collectionShaperExpression.InnerShaper)).Compile()), - Expression.Constant( - collectionShaperExpression.Navigation?.GetCollectionAccessor(), - typeof(IClrCollectionAccessor))); + Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor))); } if (extensionExpression is SingleResultShaperExpression singleResultShaperExpression) diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.ShaperProcessingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.ShaperProcessingExpressionVisitor.cs index 022ce202439..4815bc168de 100644 --- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.ShaperProcessingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.ShaperProcessingExpressionVisitor.cs @@ -661,10 +661,11 @@ protected override Expression VisitExtension(Expression extensionExpression) _readerColumns) .ProcessShaper(relationalCollectionShaperExpression.InnerShaper, out _, out _); - var collectionType = relationalCollectionShaperExpression.Type; - var elementType = collectionType.TryGetSequenceType(); - var relatedElementType = innerShaper.ReturnType; var navigation = relationalCollectionShaperExpression.Navigation; + var collectionAccessor = navigation?.GetCollectionAccessor(); + var collectionType = collectionAccessor?.CollectionType ?? relationalCollectionShaperExpression.Type; + var elementType = relationalCollectionShaperExpression.ElementType; + var relatedElementType = innerShaper.ReturnType; _inline = true; @@ -700,7 +701,7 @@ protected override Expression VisitExtension(Expression extensionExpression) _resultCoordinatorParameter, Expression.Constant(parentIdentifierLambda.Compile()), Expression.Constant(outerIdentifierLambda.Compile()), - Expression.Constant(navigation?.GetCollectionAccessor(), typeof(IClrCollectionAccessor))))); + Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor))))); _valuesArrayInitializers.Add(collectionParameter); accessor = Expression.Convert( @@ -735,15 +736,16 @@ protected override Expression VisitExtension(Expression extensionExpression) if (!_variableShaperMapping.TryGetValue(key, out var accessor)) { var innerProcessor = new ShaperProcessingExpressionVisitor(_parentVisitor, _resultCoordinatorParameter, - _executionStrategyParameter, relationalSplitCollectionShaperExpression.SelectExpression, _tags); + _executionStrategyParameter, relationalSplitCollectionShaperExpression.SelectExpression, _tags); var innerShaper = innerProcessor.ProcessShaper(relationalSplitCollectionShaperExpression.InnerShaper, out var relationalCommandCache, out var relatedDataLoaders); - var collectionType = relationalSplitCollectionShaperExpression.Type; - var elementType = collectionType.TryGetSequenceType(); - var relatedElementType = innerShaper.ReturnType; var navigation = relationalSplitCollectionShaperExpression.Navigation; + var collectionAccessor = navigation?.GetCollectionAccessor(); + var collectionType = collectionAccessor?.CollectionType ?? relationalSplitCollectionShaperExpression.Type; + var elementType = relationalSplitCollectionShaperExpression.ElementType; + var relatedElementType = innerShaper.ReturnType; _inline = true; @@ -777,7 +779,7 @@ protected override Expression VisitExtension(Expression extensionExpression) _dataReaderParameter, _resultCoordinatorParameter, Expression.Constant(parentIdentifierLambda.Compile()), - Expression.Constant(navigation?.GetCollectionAccessor(), typeof(IClrCollectionAccessor))))); + Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor))))); _valuesArrayInitializers.Add(collectionParameter); accessor = Expression.Convert( @@ -1053,7 +1055,9 @@ private static void IncludeReference INavigationBase inverseNavigation, Action fixup, bool trackingQuery) - where TIncludingEntity : TEntity + where TEntity : class + where TIncludingEntity : class, TEntity + where TIncludedEntity : class { if (entity is TIncludingEntity includingEntity) { @@ -1093,7 +1097,8 @@ private static void InitializeIncludeCollection( INavigationBase navigation, IClrCollectionAccessor clrCollectionAccessor, bool trackingQuery) - where TNavigationEntity : TParent + where TParent : class + where TNavigationEntity : class, TParent { object collection = null; if (entity is TNavigationEntity) @@ -1133,6 +1138,8 @@ private static void PopulateIncludeCollection INavigationBase inverseNavigation, Action fixup, bool trackingQuery) + where TIncludingEntity : class + where TIncludedEntity : class { var collectionMaterializationContext = resultCoordinator.Collections[collectionId]; if (collectionMaterializationContext.Parent is TIncludingEntity entity) @@ -1242,7 +1249,8 @@ private static void InitializeSplitIncludeCollection INavigationBase navigation, IClrCollectionAccessor clrCollectionAccessor, bool trackingQuery) - where TNavigationEntity : TParent + where TParent : class + where TNavigationEntity : class, TParent { object collection = null; if (entity is TNavigationEntity) @@ -1279,6 +1287,8 @@ private static void PopulateSplitIncludeCollection fixup, bool trackingQuery) + where TIncludingEntity : class + where TIncludedEntity : class { if (resultCoordinator.DataReaders.Count <= collectionId || resultCoordinator.DataReaders[collectionId] == null) @@ -1357,6 +1367,8 @@ private static async Task PopulateSplitIncludeCollectionAsync fixup, bool trackingQuery) + where TIncludingEntity : class + where TIncludedEntity : class { if (resultCoordinator.DataReaders.Count <= collectionId || resultCoordinator.DataReaders[collectionId] == null) @@ -1435,7 +1447,7 @@ private static TCollection InitializeCollection( Func parentIdentifier, Func outerIdentifier, IClrCollectionAccessor clrCollectionAccessor) - where TCollection : class, IEnumerable + where TCollection : class, ICollection { var collection = clrCollectionAccessor?.Create() ?? new List(); @@ -1566,7 +1578,7 @@ private static TCollection InitializeSplitCollection( SplitQueryResultCoordinator resultCoordinator, Func parentIdentifier, IClrCollectionAccessor clrCollectionAccessor) - where TCollection : class, IEnumerable + where TCollection : class, ICollection { var collection = clrCollectionAccessor?.Create() ?? new List(); var parentKey = parentIdentifier(queryContext, parentDataReader); @@ -1587,8 +1599,8 @@ private static void PopulateSplitCollection identifierValueComparers, Func innerShaper, Action relatedDataLoaders) - where TRelatedEntity : TElement - where TCollection : class, ICollection + where TRelatedEntity : TElement + where TCollection : class, ICollection { if (resultCoordinator.DataReaders.Count <= collectionId || resultCoordinator.DataReaders[collectionId] == null) @@ -1659,8 +1671,8 @@ private static async Task PopulateSplitCollectionAsync identifierValueComparers, Func innerShaper, Func relatedDataLoaders) - where TRelatedEntity : TElement - where TCollection : class, ICollection + where TRelatedEntity : TElement + where TCollection : class, ICollection { if (resultCoordinator.DataReaders.Count <= collectionId || resultCoordinator.DataReaders[collectionId] == null) diff --git a/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs index dc56b21ad49..174ce6d25a5 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/QueryBugsInMemoryTest.cs @@ -883,6 +883,61 @@ protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) #endregion + #region Issue21803 + + [ConditionalFact] + public virtual void Select_enumerable_navigation_backed_by_collection() + { + using (CreateScratch(Seed21803, "21803")) + { + using var context = new MyContext21803(); + + var query = context.Set().Select(appEntity => appEntity.OtherEntities); + + query.ToList(); + } + } + + private static void Seed21803(MyContext21803 context) + { + var appEntity = new AppEntity21803(); + context.AddRange( + new OtherEntity21803 { AppEntity = appEntity }, + new OtherEntity21803 { AppEntity = appEntity }, + new OtherEntity21803 { AppEntity = appEntity }, + new OtherEntity21803 { AppEntity = appEntity }); + + context.SaveChanges(); + } + + public class AppEntity21803 + { + private readonly List _otherEntities = new List(); + + public int Id { get; private set; } + public IEnumerable OtherEntities => _otherEntities; + } + + public class OtherEntity21803 + { + public int Id { get; private set; } + public AppEntity21803 AppEntity { get; set; } + } + + private class MyContext21803 : DbContext + { + public DbSet Entities { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder + .UseInternalServiceProvider(InMemoryFixture.DefaultServiceProvider) + .UseInMemoryDatabase("21803"); + } + } + + #endregion + #region SharedHelper private static InMemoryTestStore CreateScratch(Action seed, string databaseName) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index a53abb3d39a..0a816edd399 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -8103,8 +8103,7 @@ private SqlServerTestStore CreateDatabase19206() #endregion - - #region Issue 18510 + #region Issue18510 [ConditionalFact] public virtual void Invoke_inside_query_filter_gets_correctly_evaluated_during_translation() @@ -8198,6 +8197,101 @@ private SqlServerTestStore CreateDatabase18510() #endregion + #region Issue21803 + + [ConditionalTheory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public virtual async Task Select_enumerable_navigation_backed_by_collection(bool async, bool split) + { + using (CreateDatabase21803()) + { + using var context = new MyContext21803(_options); + + var query = context.Set().Select(appEntity => appEntity.OtherEntities); + + if (split) + { + query = query.AsSplitQuery(); + } + + if (async) + { + await query.ToListAsync(); + } + else + { + query.ToList(); + } + + if (split) + { + AssertSql( + @"SELECT [e].[Id] +FROM [Entities] AS [e] +ORDER BY [e].[Id]", + // + @"SELECT [o].[Id], [o].[AppEntityId], [e].[Id] +FROM [Entities] AS [e] +INNER JOIN [OtherEntity21803] AS [o] ON [e].[Id] = [o].[AppEntityId] +ORDER BY [e].[Id]"); + } + else + { + AssertSql( + @"SELECT [e].[Id], [o].[Id], [o].[AppEntityId] +FROM [Entities] AS [e] +LEFT JOIN [OtherEntity21803] AS [o] ON [e].[Id] = [o].[AppEntityId] +ORDER BY [e].[Id], [o].[Id]"); + } + } + } + + public class AppEntity21803 + { + private readonly List _otherEntities = new List(); + + public int Id { get; private set; } + public IEnumerable OtherEntities => _otherEntities; + } + + public class OtherEntity21803 + { + public int Id { get; private set; } + public AppEntity21803 AppEntity { get; set; } + } + + private class MyContext21803 : DbContext + { + public DbSet Entities { get; set; } + + public MyContext21803(DbContextOptions options) + : base(options) + { + } + } + + private SqlServerTestStore CreateDatabase21803() + => CreateTestStore( + () => new MyContext21803(_options), + context => + { + var appEntity = new AppEntity21803(); + context.AddRange( + new OtherEntity21803 { AppEntity = appEntity }, + new OtherEntity21803 { AppEntity = appEntity }, + new OtherEntity21803 { AppEntity = appEntity }, + new OtherEntity21803 { AppEntity = appEntity }); + + context.SaveChanges(); + + ClearLog(); + }); + + #endregion + private DbContextOptions _options; private SqlServerTestStore CreateTestStore(