Skip to content

Commit

Permalink
Fix to #19499 - Query: null semantics not applied on subquery.Contain…
Browse files Browse the repository at this point in the history
…s(null)

We used to translate this into `NULL IN subquery` pattern, but this doesn't work because it doesn't match if the subquery also contains null.
Fix is to convert this into subquery.Any(e => e == NULL), which translates to EXISTS with predicate and we can correctly apply null semantics.
Also it allows us to translate contains on entities with composite keys.

Also made several small fixes:
- marked EXISTS and IN expressions as never nullable for purpose of null semantics,
- optimized EXISITS (subquery) with predicate that resolves to false, directly into false, since empty subquery never exisits,
- improves expression printer output for IN expression in the subquery scenario.
  • Loading branch information
maumar committed Jan 30, 2020
1 parent ac7a352 commit b603ce4
Show file tree
Hide file tree
Showing 25 changed files with 757 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,15 @@ protected override Expression VisitExists(ExistsExpression existsExpression)
{
Check.NotNull(existsExpression, nameof(existsExpression));

return existsExpression.Update(
VisitInternal<SelectExpression>(existsExpression.Subquery).ResultExpression);
var subquery = VisitInternal<SelectExpression>(existsExpression.Subquery).ResultExpression;
_nullable = false;

// if subquery has predicate which evaluates to false, we can simply return false
return subquery.Predicate is SqlConstantExpression predicateConstant
&& predicateConstant.Value is bool boolValue
&& !boolValue
? (SqlExpression)predicateConstant
: existsExpression.Update(subquery);
}

protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression)
Expand All @@ -222,8 +229,8 @@ protected override Expression VisitIn(InExpression inExpression)

if (inExpression.Subquery != null)
{
var (subquery, subqueryNullable) = VisitInternal<SelectExpression>(inExpression.Subquery);
_nullable = itemNullable || subqueryNullable;
var subquery = VisitInternal<SelectExpression>(inExpression.Subquery).ResultExpression;
_nullable = false;

return inExpression.Update(item, values: null, subquery);
}
Expand All @@ -233,8 +240,8 @@ protected override Expression VisitIn(InExpression inExpression)
if (UseRelationalNulls
|| !(inExpression.Values is SqlConstantExpression || inExpression.Values is SqlParameterExpression))
{
var (values, valuesNullable) = VisitInternal<SqlExpression>(inExpression.Values);
_nullable = itemNullable || valuesNullable;
var values = VisitInternal<SqlExpression>(inExpression.Values).ResultExpression;
_nullable = false;

return inExpression.Update(item, values, subquery: null);
}
Expand Down Expand Up @@ -265,7 +272,7 @@ protected override Expression VisitIn(InExpression inExpression)
if (!itemNullable
|| (_allowOptimizedExpansion && !inExpression.IsNegated && !hasNullValue))
{
_nullable = itemNullable;
_nullable = false;

// non_nullable IN (1, 2) -> non_nullable IN (1, 2)
// non_nullable IN (1, 2, NULL) -> non_nullable IN (1, 2)
Expand All @@ -275,7 +282,6 @@ protected override Expression VisitIn(InExpression inExpression)
return inExpression.Update(item, inValuesExpression, subquery: null);
}

// adding null comparison term to remove nulls completely from the resulting expression
_nullable = false;

// nullable IN (1, 2) -> nullable IN (1, 2) AND nullable IS NOT NULL (full)
Expand Down
9 changes: 8 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressions/InExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ public override void Print(ExpressionPrinter expressionPrinter)
expressionPrinter.Append(IsNegated ? " NOT IN " : " IN ");
expressionPrinter.Append("(");

if (Values is SqlConstantExpression constantValuesExpression
if (Subquery != null)
{
using (expressionPrinter.Indent())
{
expressionPrinter.Visit(Subquery);
}
}
else if (Values is SqlConstantExpression constantValuesExpression
&& constantValuesExpression.Value is IEnumerable constantValues)
{
var first = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class AllAnyToContainsRewritingExpressionVisitor : ExpressionVisitor
public class AllAnyContainsRewritingExpressionVisitor : ExpressionVisitor
{
private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2)
=> type.IsGenericType
Expand Down Expand Up @@ -44,6 +44,28 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo containsMethodInfo
&& containsMethodInfo.Equals(QueryableMethods.Contains)
&& methodCallExpression.Arguments[0].NodeType is ExpressionType containsFirstArgumentNodeType
&& (containsFirstArgumentNodeType != ExpressionType.Constant || ((ConstantExpression)methodCallExpression.Arguments[0]).IsEntityQueryable())
&& containsFirstArgumentNodeType != ExpressionType.Parameter
// special case byte_array.Contains() - we don't want those to be rewritten
&& methodCallExpression.Arguments[1].Type != typeof(byte))
{
var typeArgument = methodCallExpression.Method.GetGenericArguments()[0];
var anyMethod = QueryableMethods.AnyWithPredicate.MakeGenericMethod(typeArgument);

var anyLambdaParameter = Expression.Parameter(typeArgument, "p");
var anyLambda = Expression.Lambda(
Expression.Equal(
anyLambdaParameter,
methodCallExpression.Arguments[1]),
anyLambdaParameter);

return Expression.Call(null, anyMethod, new[] { methodCallExpression.Arguments[0], anyLambda });
}

return base.VisitMethodCall(methodCallExpression);
}

Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public virtual Expression Process([NotNull] Expression query)
query = new QueryMetadataExtractingExpressionVisitor(_queryCompilationContext).Visit(query);
query = new InvocationExpressionRemovingExpressionVisitor().Visit(query);
query = new LanguageNormalizingExpressionVisitor().Visit(query);
query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query);
query = new AllAnyContainsRewritingExpressionVisitor().Visit(query);
query = new GroupJoinFlatteningExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1355,15 +1355,10 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND c[""CustomerID""] IN (""ALFKI""))");
}

[ConditionalFact(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_false(bool async)
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
return base.Contains_over_entityType_with_null_should_rewrite_to_false(async);
}

public override async Task String_FirstOrDefault_in_projection_does_client_eval(bool async)
Expand Down Expand Up @@ -1528,6 +1523,54 @@ public override Task Sum_over_explicit_cast_over_column(bool async)
return base.Sum_over_explicit_cast_over_column(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_scalar_with_null_should_rewrite_to_identity_equality_subquery(bool async)
{
return base.Contains_over_scalar_with_null_should_rewrite_to_identity_equality_subquery(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_nullable_scalar_with_null_in_subquery_translated_correctly(bool async)
{
return base.Contains_over_nullable_scalar_with_null_in_subquery_translated_correctly(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_non_nullable_scalar_with_null_in_subquery_simplifies_to_false(bool async)
{
return base.Contains_over_non_nullable_scalar_with_null_in_subquery_simplifies_to_false(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery(bool async)
{
return base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_complex(bool async)
{
return base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_complex(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_negated(bool async)
{
return base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_negated(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_should_materialize_when_composite(bool async)
{
return base.Contains_over_entityType_should_materialize_when_composite(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_should_materialize_when_composite2(bool async)
{
return base.Contains_over_entityType_should_materialize_when_composite2(async);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,12 @@ public override Task Where_collection_navigation_ToArray_Length_member(bool asyn
return base.Where_collection_navigation_ToArray_Length_member(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported")]
public override Task Where_Queryable_AsEnumerable_Contains_negated(bool async)
{
return base.Where_Queryable_AsEnumerable_Contains_negated(async);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,17 @@ public override Task Select_subquery_single_nested_subquery2(bool async)
{
return base.Select_subquery_single_nested_subquery2(async);
}

[ConditionalTheory(Skip = "issue #19742")]
public override Task Contains_over_optional_navigation_with_null_column(bool async)
{
return base.Contains_over_optional_navigation_with_null_column(async);
}

[ConditionalTheory(Skip = "issue #19742")]
public override Task Contains_over_optional_navigation_with_null_entity_reference(bool async)
{
return base.Contains_over_optional_navigation_with_null_entity_reference(async);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5677,5 +5677,71 @@ public virtual Task Select_subquery_single_nested_subquery2(bool async)
});
});
}

[ConditionalFact]
public virtual void Contains_over_optional_navigation_with_null_constant()
{
using var ctx = CreateContext();
var result = ctx.Set<Level1>().Select(l1 => l1.OneToOne_Optional_FK1).Contains(null);
var expected = Fixture.QueryAsserter.ExpectedData.Set<Level1>().Select(l1 => l1.OneToOne_Optional_FK1).Contains(null);

Assert.Equal(expected, result);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_over_optional_navigation_with_null_parameter(bool async)
{
return AssertContains(
async,
ss => ss.Set<Level1>().Select(l1 => l1.OneToOne_Optional_FK1),
null);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_over_optional_navigation_with_null_column(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Level1>().Select(l1 => new
{
l1.Name,
OptionalName = l1.OneToOne_Optional_FK1.Name,
Contains = ss.Set<Level1>().Select(x => x.OneToOne_Optional_FK1.Name).Contains(l1.OneToOne_Optional_FK1.Name)

}),
ss => ss.Set<Level1>().Select(l1 => new
{
l1.Name,
OptionalName = Maybe(l1.OneToOne_Optional_FK1, () => l1.OneToOne_Optional_FK1.Name),
Contains = ss.Set<Level1>().Select(x => Maybe(
x.OneToOne_Optional_FK1,
() => x.OneToOne_Optional_FK1.Name)).Contains(Maybe(l1.OneToOne_Optional_FK1, () => l1.OneToOne_Optional_FK1.Name))
}),

elementSorter: e => (e.Name, e.OptionalName, e.Contains));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_over_optional_navigation_with_null_entity_reference(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Level1>().Select(l1 => new
{
l1.Name,
OptionalName = l1.OneToOne_Optional_FK1.Name,
Contains = ss.Set<Level1>().Select(x => x.OneToOne_Optional_FK1).Contains(l1.OneToOne_Optional_PK1)
}),
ss => ss.Set<Level1>().Select(l1 => new
{
l1.Name,
OptionalName = Maybe(l1.OneToOne_Optional_FK1, () => l1.OneToOne_Optional_FK1.Name),
Contains = ss.Set<Level1>().Select(x => x.OneToOne_Optional_FK1).Contains(l1.OneToOne_Optional_PK1)
}),
elementSorter: e => (e.Name, e.OptionalName, e.Contains));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ protected GearsOfWarQueryFixtureBase()
Assert.Equal(ee.Name, aa.Name);
Assert.Equal(ee.ThreatLevel, aa.ThreatLevel);
Assert.Equal(ee.ThreatLevelByte, aa.ThreatLevelByte);
Assert.Equal(ee.ThreatLevelNullableByte, aa.ThreatLevelNullableByte);
if (e is LocustCommander locustCommander)
{
var actualLocustCommander = (LocustCommander)aa;
Expand All @@ -183,6 +186,8 @@ protected GearsOfWarQueryFixtureBase()
Assert.Equal(ee.Name, aa.Name);
Assert.Equal(ee.ThreatLevel, aa.ThreatLevel);
Assert.Equal(ee.ThreatLevelByte, aa.ThreatLevelByte);
Assert.Equal(ee.ThreatLevelNullableByte, aa.ThreatLevelNullableByte);
Assert.Equal(ee.DefeatedByNickname, aa.DefeatedByNickname);
Assert.Equal(ee.DefeatedBySquadId, aa.DefeatedBySquadId);
}
Expand Down
56 changes: 56 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7581,6 +7581,62 @@ public virtual Task Group_by_nullable_property_and_project_the_grouping_key_HasV
.Select(g => g.Key.HasValue));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_collection_of_byte_subquery(bool async)
{
return AssertQuery(
async,
ss => ss.Set<LocustLeader>().Where(l => ss.Set<LocustLeader>().Select(ll => ll.ThreatLevelByte).Contains(l.ThreatLevelByte)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_collection_of_nullable_byte_subquery(bool async)
{
return AssertQuery(
async,
ss => ss.Set<LocustLeader>().Where(l => ss.Set<LocustLeader>().Select(ll => ll.ThreatLevelNullableByte).Contains(l.ThreatLevelNullableByte)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_collection_of_nullable_byte_subquery_null_constant(bool async)
{
return AssertQuery(
async,
ss => ss.Set<LocustLeader>().Where(l => ss.Set<LocustLeader>().Select(ll => ll.ThreatLevelNullableByte).Contains(null)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_collection_of_nullable_byte_subquery_null_parameter(bool async)
{
var prm = default(byte?);

return AssertQuery(
async,
ss => ss.Set<LocustLeader>().Where(l => ss.Set<LocustLeader>().Select(ll => ll.ThreatLevelNullableByte).Contains(prm)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_byte_array_property_using_byte_column(bool async)
{
return AssertQuery(
async,
ss => from s in ss.Set<Squad>()
from l in ss.Set<LocustLeader>()
where s.Banner.Contains(l.ThreatLevelByte)
select new { s, l },
elementSorter: e => (e.s.Id, e.l.Name),
elementAsserter: (e, a) =>
{
AssertEqual(e.s, a.s);
AssertEqual(e.l, a.l);
});
}

protected GearsOfWarContext CreateContext() => Fixture.CreateContext();

protected virtual void ClearLog()
Expand Down
Loading

0 comments on commit b603ce4

Please sign in to comment.