Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to #19731 - Discuss default null propagation strategy for functions #20025

Merged
merged 1 commit into from
Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ SqlFunctionExpression Function(
SqlFunctionExpression Function(
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);
Expand All @@ -143,7 +143,7 @@ SqlFunctionExpression Function(
[CanBeNull] string schema,
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);
Expand All @@ -152,29 +152,29 @@ SqlFunctionExpression Function(
[CanBeNull] SqlExpression instance,
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string name,
bool nullResultAllowed,
bool nullable,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string schema,
[NotNull] string name,
bool nullResultAllowed,
bool nullable,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[CanBeNull] SqlExpression instance,
[NotNull] string name,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ protected virtual SqlExpression ProcessNullNotNull(
sqlUnaryExpression.TypeMapping));
}

if (!sqlFunctionExpression.NullResultAllowed)
if (!sqlFunctionExpression.IsNullable)
{
// when we know that function can't be nullable:
// non_nullable_function() is null-> false
Expand All @@ -1306,39 +1306,33 @@ protected virtual SqlExpression ProcessNullNotNull(
nullabilityPropagationElements.Add(sqlFunctionExpression.Instance);
}

for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
if (!sqlFunctionExpression.IsNiladic)
{
if (sqlFunctionExpression.ArgumentsPropagateNullability[i])
for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
{
nullabilityPropagationElements.Add(sqlFunctionExpression.Arguments[i]);
if (sqlFunctionExpression.ArgumentsPropagateNullability[i])
{
nullabilityPropagationElements.Add(sqlFunctionExpression.Arguments[i]);
}
}
}

// function(a, b) IS NULL -> a IS NULL || b IS NULL
// function(a, b) IS NOT NULL -> a IS NOT NULL && b IS NOT NULL
if (nullabilityPropagationElements.Count > 0)
{
var result = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
nullabilityPropagationElements[0],
sqlUnaryExpression.Type,
sqlUnaryExpression.TypeMapping),
operandNullable: null);

foreach (var element in nullabilityPropagationElements.Skip(1))
{
result = SimplifyLogicalSqlBinaryExpression(
var result = nullabilityPropagationElements
.Select(e => ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
e,
sqlUnaryExpression.Type,
sqlUnaryExpression.TypeMapping),
operandNullable: null))
.Aggregate((r, e) => SimplifyLogicalSqlBinaryExpression(
sqlUnaryExpression.OperatorType == ExpressionType.Equal
? SqlExpressionFactory.OrElse(
result,
ProcessNullNotNull(
SqlExpressionFactory.IsNull(element),
operandNullable: null))
: SqlExpressionFactory.AndAlso(
result,
ProcessNullNotNull(
SqlExpressionFactory.IsNotNull(element),
operandNullable: null)));
}
? SqlExpressionFactory.OrElse(r, e)
: SqlExpressionFactory.AndAlso(r, e)));

return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public virtual SqlExpression Translate(
dbFunction.Schema,
dbFunction.Name,
arguments,
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: arguments.Select(a => false).ToList(),
method.ReturnType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
SqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
sqlExpression.Type,
sqlExpression.TypeMapping)
: (SqlExpression)SqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping);
Expand All @@ -127,7 +127,7 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression =
SqlExpressionFactory.Function(
"COUNT",
new[] { SqlExpressionFactory.Fragment("*") },
nullResultAllowed: false,
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));
}
Expand All @@ -144,7 +144,7 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio
SqlExpressionFactory.Function(
"COUNT",
new[] { SqlExpressionFactory.Fragment("*") },
nullResultAllowed: false,
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
}
Expand All @@ -162,7 +162,7 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression)
? SqlExpressionFactory.Function(
"MAX",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
Expand All @@ -182,7 +182,7 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression)
? SqlExpressionFactory.Function(
"MIN",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
Expand Down Expand Up @@ -210,15 +210,15 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression)
SqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
inputType,
sqlExpression.TypeMapping)
: (SqlExpression)SqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
inputType,
sqlExpression.TypeMapping);
Expand Down
42 changes: 23 additions & 19 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ public virtual SqlFunctionExpression Coalesce(SqlExpression left, SqlExpression
return SqlFunctionExpression.Create(
"COALESCE",
typeMappedArguments,
nullResultAllowed: true,
nullable: true,
// COALESCE is handled separately since it's only nullable if *both* arguments are null
argumentsPropagateNullability: new[] { false, false },
resultType,
Expand Down Expand Up @@ -487,21 +487,24 @@ public virtual CaseExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, Sq
return new CaseExpression(typeMappedWhenClauses, elseResult);
}

[Obsolete("Use overload that explicitly specifies value for 'argumentsPropagateNullability' argument.")]
public virtual SqlFunctionExpression Function(
string name,
IEnumerable<SqlExpression> arguments,
Type returnType,
RelationalTypeMapping typeMapping = null)
=> Function(name, arguments, nullResultAllowed: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);
=> Function(name, arguments, nullable: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);

[Obsolete("Use overload that explicitly specifies value for 'argumentsPropagateNullability' argument.")]
public virtual SqlFunctionExpression Function(
string schema,
string name,
IEnumerable<SqlExpression> arguments,
Type returnType,
RelationalTypeMapping typeMapping = null)
=> Function(schema, name, arguments, nullResultAllowed: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);
=> Function(schema, name, arguments, nullable: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);

[Obsolete("Use overload that explicitly specifies values for 'instancePropagatesNullability' and 'argumentsPropagateNullability' arguments.")]
public virtual SqlFunctionExpression Function(
SqlExpression instance,
string name,
Expand All @@ -512,25 +515,26 @@ public virtual SqlFunctionExpression Function(
instance,
name,
arguments,
nullResultAllowed: true,
nullable: true,
instancePropagatesNullability: false,
argumentsPropagateNullability: arguments.Select(a => false),
returnType,
typeMapping);

public virtual SqlFunctionExpression Function(string name, Type returnType, RelationalTypeMapping typeMapping = null)
=> Function(name, nullResultAllowed: true, returnType, typeMapping);
=> Function(name, nullable: true, returnType, typeMapping);

public virtual SqlFunctionExpression Function(string schema, string name, Type returnType, RelationalTypeMapping typeMapping = null)
=> Function(schema, name, nullResultAllowed: true, returnType, typeMapping);
=> Function(schema, name, nullable: true, returnType, typeMapping);

[Obsolete("Use overload that explicitly specifies value for 'instancePropagatesNullability' argument.")]
public virtual SqlFunctionExpression Function(SqlExpression instance, string name, Type returnType, RelationalTypeMapping typeMapping = null)
=> Function(instance, name, nullResultAllowed: true, instancePropagatesNullability: false, returnType, typeMapping);
=> Function(instance, name, nullable: true, instancePropagatesNullability: false, returnType, typeMapping);

public virtual SqlFunctionExpression Function(
string name,
IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
RelationalTypeMapping typeMapping = null)
Expand All @@ -549,7 +553,7 @@ public virtual SqlFunctionExpression Function(
return SqlFunctionExpression.Create(
name,
typeMappedArguments,
nullResultAllowed,
nullable,
argumentsPropagateNullability,
returnType,
typeMapping);
Expand All @@ -559,7 +563,7 @@ public virtual SqlFunctionExpression Function(
string schema,
string name,
IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
RelationalTypeMapping typeMapping = null)
Expand All @@ -578,7 +582,7 @@ public virtual SqlFunctionExpression Function(
schema,
name,
typeMappedArguments,
nullResultAllowed,
nullable,
argumentsPropagateNullability,
returnType,
typeMapping);
Expand All @@ -588,7 +592,7 @@ public virtual SqlFunctionExpression Function(
SqlExpression instance,
string name,
IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
Expand All @@ -609,34 +613,34 @@ public virtual SqlFunctionExpression Function(
instance,
name,
typeMappedArguments,
nullResultAllowed,
nullable,
instancePropagatesNullability,
argumentsPropagateNullability,
returnType,
typeMapping);
}

public virtual SqlFunctionExpression Function(string name, bool nullResultAllowed, Type returnType, RelationalTypeMapping typeMapping = null)
public virtual SqlFunctionExpression Function(string name, bool nullable, Type returnType, RelationalTypeMapping typeMapping = null)
{
Check.NotEmpty(name, nameof(name));
Check.NotNull(returnType, nameof(returnType));

return SqlFunctionExpression.CreateNiladic(name, nullResultAllowed, returnType, typeMapping);
return SqlFunctionExpression.CreateNiladic(name, nullable, returnType, typeMapping);
}

public virtual SqlFunctionExpression Function(string schema, string name, bool nullResultAllowed, Type returnType, RelationalTypeMapping typeMapping = null)
public virtual SqlFunctionExpression Function(string schema, string name, bool nullable, Type returnType, RelationalTypeMapping typeMapping = null)
{
Check.NotEmpty(schema, nameof(schema));
Check.NotEmpty(name, nameof(name));
Check.NotNull(returnType, nameof(returnType));

return SqlFunctionExpression.CreateNiladic(schema, name, nullResultAllowed, returnType, typeMapping);
return SqlFunctionExpression.CreateNiladic(schema, name, nullable, returnType, typeMapping);
}

public virtual SqlFunctionExpression Function(
SqlExpression instance,
string name,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
Type returnType,
RelationalTypeMapping typeMapping = null)
Expand All @@ -647,7 +651,7 @@ public virtual SqlFunctionExpression Function(
return SqlFunctionExpression.CreateNiladic(
ApplyDefaultTypeMapping(instance),
name,
nullResultAllowed,
nullable,
instancePropagatesNullability,
returnType,
typeMapping);
Expand Down
Loading