Skip to content

Commit

Permalink
Fix to #19499 - Query: null semantics not applied on subquery.Contain…
Browse files Browse the repository at this point in the history
…s(null)

We used to translate this into `NULL IN subquery` pattern, but this doesn't work because it doesn't match if the subquery also contains null.
Fix is to convert this into subquery.Any(e => e == NULL), which translates to EXISTS with predicate and we can correctly apply null semantics.
Also it allows us to translate contains on entities with composite keys.

Also made several small fixes:
- marked EXISTS as never nullable for purpose of null semantics,
- marked IN as never nullable for purpose of null semantics when both the subquery projection element and the item expression are not nullable,
- optimized EXISITS (subquery) and IN (subquery) with predicate that resolves to false, directly into false, since empty subquery never exisits and doesn't contain any results,
- improves expression printer output for IN expression in the subquery scenario.
  • Loading branch information
maumar committed Mar 4, 2020
1 parent 350e26c commit b84eec1
Show file tree
Hide file tree
Showing 28 changed files with 987 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,20 @@ protected override Expression VisitExists(ExistsExpression existsExpression)
{
Check.NotNull(existsExpression, nameof(existsExpression));

return existsExpression.Update(
VisitInternal<SelectExpression>(existsExpression.Subquery).ResultExpression);
var subquery = VisitInternal<SelectExpression>(existsExpression.Subquery).ResultExpression;
_nullable = false;

// if subquery has predicate which evaluates to false, we can simply return false
return IsConstantFalse(subquery.Predicate)
? subquery.Predicate
: existsExpression.Update(subquery);
}

private static bool IsConstantFalse(SqlExpression expression)
=> expression is SqlConstantExpression constantExpression
&& constantExpression.Value is bool boolValue
&& !boolValue;

protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression)
=> Check.NotNull(fromSqlExpression, nameof(fromSqlExpression));

Expand All @@ -223,8 +233,23 @@ protected override Expression VisitIn(InExpression inExpression)

if (inExpression.Subquery != null)
{
var (subquery, subqueryNullable) = VisitInternal<SelectExpression>(inExpression.Subquery);
_nullable = itemNullable || subqueryNullable;
var subquery = VisitInternal<SelectExpression>(inExpression.Subquery).ResultExpression;

// a IN (SELECT * FROM table WHERE false) => false
if (IsConstantFalse(subquery.Predicate))
{
_nullable = false;

return subquery.Predicate;
}

// if item is not nullable, and subquery contains a non-nullable column we know the result can never be null
// note: in this case we could broaden the optimization if we knew the nullability of the projection
// but we don't keep that information and we want to avoid double visitation
_nullable = !(!itemNullable
&& subquery.Projection.Count == 1
&& subquery.Projection[0].Expression is ColumnExpression columnProjection
&& !columnProjection.IsNullable);

return inExpression.Update(item, values: null, subquery);
}
Expand All @@ -234,8 +259,8 @@ protected override Expression VisitIn(InExpression inExpression)
if (UseRelationalNulls
|| !(inExpression.Values is SqlConstantExpression || inExpression.Values is SqlParameterExpression))
{
var (values, valuesNullable) = VisitInternal<SqlExpression>(inExpression.Values);
_nullable = itemNullable || valuesNullable;
var values = VisitInternal<SqlExpression>(inExpression.Values).ResultExpression;
_nullable = false;

return inExpression.Update(item, values, subquery: null);
}
Expand Down Expand Up @@ -266,7 +291,7 @@ protected override Expression VisitIn(InExpression inExpression)
if (!itemNullable
|| (_allowOptimizedExpansion && !inExpression.IsNegated && !hasNullValue))
{
_nullable = itemNullable;
_nullable = false;

// non_nullable IN (1, 2) -> non_nullable IN (1, 2)
// non_nullable IN (1, 2, NULL) -> non_nullable IN (1, 2)
Expand All @@ -276,7 +301,6 @@ protected override Expression VisitIn(InExpression inExpression)
return inExpression.Update(item, inValuesExpression, subquery: null);
}

// adding null comparison term to remove nulls completely from the resulting expression
_nullable = false;

// nullable IN (1, 2) -> nullable IN (1, 2) AND nullable IS NOT NULL (full)
Expand Down
9 changes: 8 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressions/InExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ public override void Print(ExpressionPrinter expressionPrinter)
expressionPrinter.Append(IsNegated ? " NOT IN " : " IN ");
expressionPrinter.Append("(");

if (Values is SqlConstantExpression constantValuesExpression
if (Subquery != null)
{
using (expressionPrinter.Indent())
{
expressionPrinter.Visit(Subquery);
}
}
else if (Values is SqlConstantExpression constantValuesExpression
&& constantValuesExpression.Value is IEnumerable constantValues)
{
var first = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class AllAnyToContainsRewritingExpressionVisitor : ExpressionVisitor
public class AllAnyContainsRewritingExpressionVisitor : ExpressionVisitor
{
private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2)
=> type.IsGenericType
Expand Down Expand Up @@ -44,6 +44,27 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo containsMethodInfo
&& containsMethodInfo.Equals(QueryableMethods.Contains)
&& !(methodCallExpression.Arguments[0] is ParameterExpression)
&& (!(methodCallExpression.Arguments[0] is ConstantExpression) || ((ConstantExpression)methodCallExpression.Arguments[0]).IsEntityQueryable())
// special case Queryable.Contains(byte_array, byte) - we don't want those to be rewritten
&& methodCallExpression.Arguments[1].Type != typeof(byte))
{
var typeArgument = methodCallExpression.Method.GetGenericArguments()[0];
var anyMethod = QueryableMethods.AnyWithPredicate.MakeGenericMethod(typeArgument);

var anyLambdaParameter = Expression.Parameter(typeArgument, "p");
var anyLambda = Expression.Lambda(
Expression.Equal(
anyLambdaParameter,
methodCallExpression.Arguments[1]),
anyLambdaParameter);

return Expression.Call(null, anyMethod, new[] { methodCallExpression.Arguments[0], anyLambda });
}

return base.VisitMethodCall(methodCallExpression);
}

Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public virtual Expression Process([NotNull] Expression query)
query = NormalizeQueryableMethodCall(query);

query = new VBToCSharpConvertingExpressionVisitor().Visit(query);
query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query);
query = new AllAnyContainsRewritingExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query);
query = new SubqueryMemberPushdownExpressionVisitor(_queryCompilationContext.Model).Visit(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1366,15 +1366,10 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND c[""CustomerID""] IN (""ALFKI""))");
}

[ConditionalFact(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_false(bool async)
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
return base.Contains_over_entityType_with_null_should_rewrite_to_false(async);
}

public override async Task String_FirstOrDefault_in_projection_does_client_eval(bool async)
Expand Down Expand Up @@ -1539,6 +1534,54 @@ public override Task Sum_over_explicit_cast_over_column(bool async)
return base.Sum_over_explicit_cast_over_column(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_scalar_with_null_should_rewrite_to_identity_equality_subquery(bool async)
{
return base.Contains_over_scalar_with_null_should_rewrite_to_identity_equality_subquery(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_nullable_scalar_with_null_in_subquery_translated_correctly(bool async)
{
return base.Contains_over_nullable_scalar_with_null_in_subquery_translated_correctly(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_non_nullable_scalar_with_null_in_subquery_simplifies_to_false(bool async)
{
return base.Contains_over_non_nullable_scalar_with_null_in_subquery_simplifies_to_false(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery(bool async)
{
return base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_complex(bool async)
{
return base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_complex(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_negated(bool async)
{
return base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality_subquery_negated(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_should_materialize_when_composite(bool async)
{
return base.Contains_over_entityType_should_materialize_when_composite(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override Task Contains_over_entityType_should_materialize_when_composite2(bool async)
{
return base.Contains_over_entityType_should_materialize_when_composite2(async);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,12 @@ public override Task Where_collection_navigation_ToArray_Length_member(bool asyn
return base.Where_collection_navigation_ToArray_Length_member(async);
}

[ConditionalTheory(Skip = "Issue#17246 (Contains over subquery is not supported")]
public override Task Where_Queryable_AsEnumerable_Contains_negated(bool async)
{
return base.Where_Queryable_AsEnumerable_Contains_negated(async);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,17 @@ public override Task Nested_SelectMany_correlated_with_join_table_correctly_tran
{
return base.Nested_SelectMany_correlated_with_join_table_correctly_translated_to_apply(async);
}

[ConditionalTheory(Skip = "issue #19742")]
public override Task Contains_over_optional_navigation_with_null_column(bool async)
{
return base.Contains_over_optional_navigation_with_null_column(async);
}

[ConditionalTheory(Skip = "issue #19742")]
public override Task Contains_over_optional_navigation_with_null_entity_reference(bool async)
{
return base.Contains_over_optional_navigation_with_null_entity_reference(async);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,24 @@ public virtual async Task String_concat_with_both_arguments_being_null(bool asyn
await AssertQuery(async, ss => ss.Set<NullSemanticsEntity1>().Select(x => x.NullableStringB + x.NullableStringA));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Empty_subquery_with_contains_returns_false(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => ss.Set<NullSemanticsEntity2>().Where(x => false).Select(x => x.NullableIntA).Contains(e.NullableIntA)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Empty_subquery_with_contains_negated_returns_true(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => !ss.Set<NullSemanticsEntity2>().Where(x => false).Select(x => x.NullableIntA).Contains(e.NullableIntA)));
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,16 @@ public virtual Task Contains_with_subquery_optional_navigation_and_constant_item
l1 => l1.OneToOne_Optional_FK1.OneToMany_Optional2.MaybeScalar(x => x.Distinct().Select(l3 => l3.Id).Contains(1)) == true));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_with_subquery_optional_navigation_scalar_distinct_and_constant_item(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Level1>().Where(l1 => l1.OneToOne_Optional_FK1.OneToMany_Optional2.Select(l3 => l3.Name.Length).Distinct().Contains(1)),
ss => ss.Set<Level1>().Where(l1 => l1.OneToOne_Optional_FK1.OneToMany_Optional2.MaybeScalar(x => x.Select(l3 => l3.Name.Length).Distinct().Contains(1)) == true));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Complex_query_with_optional_navigations_and_client_side_evaluation(bool async)
Expand Down Expand Up @@ -4920,5 +4930,55 @@ public virtual Task Nested_SelectMany_correlated_with_join_table_correctly_trans
l2 => l2.OneToOne_Required_PK2.OneToMany_Optional3.DefaultIfEmpty()
.Select(l4 => new { l1Name = l1.Name, l2Name = l2.OneToOne_Required_PK2.Name, l3Name = l4.OneToOne_Optional_PK_Inverse4.Name }))));
}

[ConditionalFact]
public virtual void Contains_over_optional_navigation_with_null_constant()
{
using var ctx = CreateContext();
var result = ctx.Set<Level1>().Select(l1 => l1.OneToOne_Optional_FK1).Contains(null);
var expected = Fixture.QueryAsserter.ExpectedData.Set<Level1>().Select(l1 => l1.OneToOne_Optional_FK1).Contains(null);

Assert.Equal(expected, result);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_over_optional_navigation_with_null_parameter(bool async)
{
return AssertContains(
async,
ss => ss.Set<Level1>().Select(l1 => l1.OneToOne_Optional_FK1),
null);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_over_optional_navigation_with_null_column(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Level1>().Select(l1 => new
{
l1.Name,
OptionalName = l1.OneToOne_Optional_FK1.Name,
Contains = ss.Set<Level1>().Select(x => x.OneToOne_Optional_FK1.Name).Contains(l1.OneToOne_Optional_FK1.Name)
}),
elementSorter: e => (e.Name, e.OptionalName, e.Contains));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_over_optional_navigation_with_null_entity_reference(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Level1>().Select(l1 => new
{
l1.Name,
OptionalName = l1.OneToOne_Optional_FK1.Name,
Contains = ss.Set<Level1>().Select(x => x.OneToOne_Optional_FK1).Contains(l1.OneToOne_Optional_PK1)
}),
elementSorter: e => (e.Name, e.OptionalName, e.Contains));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ protected GearsOfWarQueryFixtureBase()
Assert.Equal(ee.Name, aa.Name);
Assert.Equal(ee.ThreatLevel, aa.ThreatLevel);
Assert.Equal(ee.ThreatLevelByte, aa.ThreatLevelByte);
Assert.Equal(ee.ThreatLevelNullableByte, aa.ThreatLevelNullableByte);
if (e is LocustCommander locustCommander)
{
var actualLocustCommander = (LocustCommander)aa;
Expand All @@ -183,6 +186,8 @@ protected GearsOfWarQueryFixtureBase()
Assert.Equal(ee.Name, aa.Name);
Assert.Equal(ee.ThreatLevel, aa.ThreatLevel);
Assert.Equal(ee.ThreatLevelByte, aa.ThreatLevelByte);
Assert.Equal(ee.ThreatLevelNullableByte, aa.ThreatLevelNullableByte);
Assert.Equal(ee.DefeatedByNickname, aa.DefeatedByNickname);
Assert.Equal(ee.DefeatedBySquadId, aa.DefeatedBySquadId);
}
Expand Down
Loading

0 comments on commit b84eec1

Please sign in to comment.