Skip to content

Commit

Permalink
Query: Correctly specify types to match constraints in materializatio…
Browse files Browse the repository at this point in the history
…n functions

Resolves #21803
  • Loading branch information
smitpatel committed Jul 31, 2020
1 parent b3f9380 commit cbb4d35
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ private static TCollection MaterializeCollection<TElement, TCollection>(
IEnumerable<ValueBuffer> innerValueBuffers,
Func<QueryContext, ValueBuffer, TElement> innerShaper,
IClrCollectionAccessor clrCollectionAccessor)
where TElement : class
where TCollection : class, ICollection<TElement>
{
var collection = (TCollection)(clrCollectionAccessor?.Create() ?? new List<TElement>());
Expand Down Expand Up @@ -196,17 +197,17 @@ protected override Expression VisitExtension(Expression extensionExpression)

if (extensionExpression is CollectionShaperExpression collectionShaperExpression)
{
var navigation = collectionShaperExpression.Navigation;
var collectionAccessor = navigation?.GetCollectionAccessor();
var collectionType = collectionAccessor?.CollectionType ?? collectionShaperExpression.Type;
var elementType = collectionShaperExpression.ElementType;
var collectionType = collectionShaperExpression.Type;

return Expression.Call(
_materializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType),
QueryCompilationContext.QueryContextParameter,
collectionShaperExpression.Projection,
Expression.Constant(((LambdaExpression)Visit(collectionShaperExpression.InnerShaper)).Compile()),
Expression.Constant(
collectionShaperExpression.Navigation?.GetCollectionAccessor(),
typeof(IClrCollectionAccessor)));
Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor)));
}

if (extensionExpression is SingleResultShaperExpression singleResultShaperExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,11 @@ protected override Expression VisitExtension(Expression extensionExpression)
_readerColumns)
.ProcessShaper(relationalCollectionShaperExpression.InnerShaper, out _, out _);

var collectionType = relationalCollectionShaperExpression.Type;
var elementType = collectionType.TryGetSequenceType();
var relatedElementType = innerShaper.ReturnType;
var navigation = relationalCollectionShaperExpression.Navigation;
var collectionAccessor = navigation?.GetCollectionAccessor();
var collectionType = collectionAccessor?.CollectionType ?? relationalCollectionShaperExpression.Type;
var elementType = relationalCollectionShaperExpression.ElementType;
var relatedElementType = innerShaper.ReturnType;

_inline = true;

Expand Down Expand Up @@ -700,7 +701,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
_resultCoordinatorParameter,
Expression.Constant(parentIdentifierLambda.Compile()),
Expression.Constant(outerIdentifierLambda.Compile()),
Expression.Constant(navigation?.GetCollectionAccessor(), typeof(IClrCollectionAccessor)))));
Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor)))));

_valuesArrayInitializers.Add(collectionParameter);
accessor = Expression.Convert(
Expand Down Expand Up @@ -735,15 +736,16 @@ protected override Expression VisitExtension(Expression extensionExpression)
if (!_variableShaperMapping.TryGetValue(key, out var accessor))
{
var innerProcessor = new ShaperProcessingExpressionVisitor(_parentVisitor, _resultCoordinatorParameter,
_executionStrategyParameter, relationalSplitCollectionShaperExpression.SelectExpression, _tags);
_executionStrategyParameter, relationalSplitCollectionShaperExpression.SelectExpression, _tags);
var innerShaper = innerProcessor.ProcessShaper(relationalSplitCollectionShaperExpression.InnerShaper,
out var relationalCommandCache,
out var relatedDataLoaders);

var collectionType = relationalSplitCollectionShaperExpression.Type;
var elementType = collectionType.TryGetSequenceType();
var relatedElementType = innerShaper.ReturnType;
var navigation = relationalSplitCollectionShaperExpression.Navigation;
var collectionAccessor = navigation?.GetCollectionAccessor();
var collectionType = collectionAccessor?.CollectionType ?? relationalSplitCollectionShaperExpression.Type;
var elementType = relationalSplitCollectionShaperExpression.ElementType;
var relatedElementType = innerShaper.ReturnType;

_inline = true;

Expand Down Expand Up @@ -777,7 +779,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
_dataReaderParameter,
_resultCoordinatorParameter,
Expression.Constant(parentIdentifierLambda.Compile()),
Expression.Constant(navigation?.GetCollectionAccessor(), typeof(IClrCollectionAccessor)))));
Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor)))));

_valuesArrayInitializers.Add(collectionParameter);
accessor = Expression.Convert(
Expand Down Expand Up @@ -1053,7 +1055,9 @@ private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : TEntity
where TEntity : class
where TIncludingEntity : class, TEntity
where TIncludedEntity : class
{
if (entity is TIncludingEntity includingEntity)
{
Expand Down Expand Up @@ -1093,7 +1097,8 @@ private static void InitializeIncludeCollection<TParent, TNavigationEntity>(
INavigationBase navigation,
IClrCollectionAccessor clrCollectionAccessor,
bool trackingQuery)
where TNavigationEntity : TParent
where TParent : class
where TNavigationEntity : class, TParent
{
object collection = null;
if (entity is TNavigationEntity)
Expand Down Expand Up @@ -1133,6 +1138,8 @@ private static void PopulateIncludeCollection<TIncludingEntity, TIncludedEntity>
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class
where TIncludedEntity : class
{
var collectionMaterializationContext = resultCoordinator.Collections[collectionId];
if (collectionMaterializationContext.Parent is TIncludingEntity entity)
Expand Down Expand Up @@ -1242,7 +1249,8 @@ private static void InitializeSplitIncludeCollection<TParent, TNavigationEntity>
INavigationBase navigation,
IClrCollectionAccessor clrCollectionAccessor,
bool trackingQuery)
where TNavigationEntity : TParent
where TParent : class
where TNavigationEntity : class, TParent
{
object collection = null;
if (entity is TNavigationEntity)
Expand Down Expand Up @@ -1279,6 +1287,8 @@ private static void PopulateSplitIncludeCollection<TIncludingEntity, TIncludedEn
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class
where TIncludedEntity : class
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down Expand Up @@ -1357,6 +1367,8 @@ private static async Task PopulateSplitIncludeCollectionAsync<TIncludingEntity,
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class
where TIncludedEntity : class
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down Expand Up @@ -1435,7 +1447,8 @@ private static TCollection InitializeCollection<TElement, TCollection>(
Func<QueryContext, DbDataReader, object[]> parentIdentifier,
Func<QueryContext, DbDataReader, object[]> outerIdentifier,
IClrCollectionAccessor clrCollectionAccessor)
where TCollection : class, IEnumerable<TElement>
where TElement : class
where TCollection : class, ICollection<TElement>
{
var collection = clrCollectionAccessor?.Create() ?? new List<TElement>();

Expand All @@ -1461,7 +1474,8 @@ private static void PopulateCollection<TCollection, TElement, TRelatedEntity>(
IReadOnlyList<ValueComparer> outerIdentifierValueComparers,
IReadOnlyList<ValueComparer> selfIdentifierValueComparers,
Func<QueryContext, DbDataReader, ResultContext, SingleQueryResultCoordinator, TRelatedEntity> innerShaper)
where TRelatedEntity : TElement
where TElement : class
where TRelatedEntity : class, TElement
where TCollection : class, ICollection<TElement>
{
var collectionMaterializationContext = resultCoordinator.Collections[collectionId];
Expand Down Expand Up @@ -1566,7 +1580,8 @@ private static TCollection InitializeSplitCollection<TElement, TCollection>(
SplitQueryResultCoordinator resultCoordinator,
Func<QueryContext, DbDataReader, object[]> parentIdentifier,
IClrCollectionAccessor clrCollectionAccessor)
where TCollection : class, IEnumerable<TElement>
where TElement : class
where TCollection : class, ICollection<TElement>
{
var collection = clrCollectionAccessor?.Create() ?? new List<TElement>();
var parentKey = parentIdentifier(queryContext, parentDataReader);
Expand All @@ -1587,8 +1602,9 @@ private static void PopulateSplitCollection<TCollection, TElement, TRelatedEntit
IReadOnlyList<ValueComparer> identifierValueComparers,
Func<QueryContext, DbDataReader, ResultContext, SplitQueryResultCoordinator, TRelatedEntity> innerShaper,
Action<QueryContext, IExecutionStrategy, SplitQueryResultCoordinator> relatedDataLoaders)
where TRelatedEntity : TElement
where TCollection : class, ICollection<TElement>
where TElement : class
where TRelatedEntity : class, TElement
where TCollection : class, ICollection<TElement>
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down Expand Up @@ -1659,8 +1675,9 @@ private static async Task PopulateSplitCollectionAsync<TCollection, TElement, TR
IReadOnlyList<ValueComparer> identifierValueComparers,
Func<QueryContext, DbDataReader, ResultContext, SplitQueryResultCoordinator, TRelatedEntity> innerShaper,
Func<QueryContext, IExecutionStrategy, SplitQueryResultCoordinator, Task> relatedDataLoaders)
where TRelatedEntity : TElement
where TCollection : class, ICollection<TElement>
where TElement : class
where TRelatedEntity : class, TElement
where TCollection : class, ICollection<TElement>
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,61 @@ protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)

#endregion

#region Issue21803

[ConditionalFact]
public virtual void Select_enumerable_navigation_backed_by_collection()
{
using (CreateScratch<MyContext21803>(Seed21803, "21803"))
{
using var context = new MyContext21803();

var query = context.Set<AppEntity21803>().Select(appEntity => appEntity.OtherEntities);

query.ToList();
}
}

private static void Seed21803(MyContext21803 context)
{
var appEntity = new AppEntity21803();
context.AddRange(
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity });

context.SaveChanges();
}

public class AppEntity21803
{
private readonly List<OtherEntity21803> _otherEntities = new List<OtherEntity21803>();

public int Id { get; private set; }
public IEnumerable<OtherEntity21803> OtherEntities => _otherEntities;
}

public class OtherEntity21803
{
public int Id { get; private set; }
public AppEntity21803 AppEntity { get; set; }
}

private class MyContext21803 : DbContext
{
public DbSet<AppEntity21803> Entities { get; set; }

protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
optionsBuilder
.UseInternalServiceProvider(InMemoryFixture.DefaultServiceProvider)
.UseInMemoryDatabase("21803");
}
}

#endregion

#region SharedHelper

private static InMemoryTestStore CreateScratch<TContext>(Action<TContext> seed, string databaseName)
Expand Down
98 changes: 96 additions & 2 deletions test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8103,8 +8103,7 @@ private SqlServerTestStore CreateDatabase19206()

#endregion


#region Issue 18510
#region Issue18510

[ConditionalFact]
public virtual void Invoke_inside_query_filter_gets_correctly_evaluated_during_translation()
Expand Down Expand Up @@ -8198,6 +8197,101 @@ private SqlServerTestStore CreateDatabase18510()

#endregion

#region Issue21803

[ConditionalTheory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public virtual async Task Select_enumerable_navigation_backed_by_collection(bool async, bool split)
{
using (CreateDatabase21803())
{
using var context = new MyContext21803(_options);

var query = context.Set<AppEntity21803>().Select(appEntity => appEntity.OtherEntities);

if (split)
{
query = query.AsSplitQuery();
}

if (async)
{
await query.ToListAsync();
}
else
{
query.ToList();
}

if (split)
{
AssertSql(
@"SELECT [e].[Id]
FROM [Entities] AS [e]
ORDER BY [e].[Id]",
//
@"SELECT [o].[Id], [o].[AppEntityId], [e].[Id]
FROM [Entities] AS [e]
INNER JOIN [OtherEntity21803] AS [o] ON [e].[Id] = [o].[AppEntityId]
ORDER BY [e].[Id]");
}
else
{
AssertSql(
@"SELECT [e].[Id], [o].[Id], [o].[AppEntityId]
FROM [Entities] AS [e]
LEFT JOIN [OtherEntity21803] AS [o] ON [e].[Id] = [o].[AppEntityId]
ORDER BY [e].[Id], [o].[Id]");
}
}
}

public class AppEntity21803
{
private readonly List<OtherEntity21803> _otherEntities = new List<OtherEntity21803>();

public int Id { get; private set; }
public IEnumerable<OtherEntity21803> OtherEntities => _otherEntities;
}

public class OtherEntity21803
{
public int Id { get; private set; }
public AppEntity21803 AppEntity { get; set; }
}

private class MyContext21803 : DbContext
{
public DbSet<AppEntity21803> Entities { get; set; }

public MyContext21803(DbContextOptions options)
: base(options)
{
}
}

private SqlServerTestStore CreateDatabase21803()
=> CreateTestStore(
() => new MyContext21803(_options),
context =>
{
var appEntity = new AppEntity21803();
context.AddRange(
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity });
context.SaveChanges();
ClearLog();
});

#endregion

private DbContextOptions _options;

private SqlServerTestStore CreateTestStore<TContext>(
Expand Down

0 comments on commit cbb4d35

Please sign in to comment.