Skip to content

Commit

Permalink
Annotate IsNiladic for nullability on Arguments (#23409)
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed Nov 20, 2020
1 parent a076e3f commit e260690
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction
if (!sqlFunctionExpression.IsNiladic)
{
_relationalCommandBuilder.Append("(");
GenerateList(sqlFunctionExpression.Arguments!, e => Visit(e));
GenerateList(sqlFunctionExpression.Arguments, e => Visit(e));
_relationalCommandBuilder.Append(")");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;
using CA = System.Diagnostics.CodeAnalysis;

#nullable enable
namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
Expand All @@ -22,7 +23,6 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
/// not used in application code.
/// </para>
/// </summary>
// TODO - NULLABLE for C# 9 (MemberNotNullIf on IsNiladic)
public class SqlFunctionExpression : SqlExpression
{
/// <summary>
Expand Down Expand Up @@ -231,6 +231,7 @@ private SqlFunctionExpression(
/// <summary>
/// A bool value indicating if the function is niladic.
/// </summary>
[CA.MemberNotNullWhen(false, nameof(Arguments), nameof(ArgumentsPropagateNullability))]
public virtual bool IsNiladic { get; }

/// <summary>
Expand Down Expand Up @@ -276,7 +277,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
SqlExpression[]? arguments = default;
if (!IsNiladic)
{
arguments = new SqlExpression[Arguments!.Count];
arguments = new SqlExpression[Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
arguments[i] = (SqlExpression)visitor.Visit(Arguments[i]);
Expand Down Expand Up @@ -367,7 +368,7 @@ protected override void Print(ExpressionPrinter expressionPrinter)
if (!IsNiladic)
{
expressionPrinter.Append("(");
expressionPrinter.VisitCollection(Arguments!);
expressionPrinter.VisitCollection(Arguments);
expressionPrinter.Append(")");
}
}
Expand All @@ -381,6 +382,7 @@ public override bool Equals(object? obj)

private bool Equals(SqlFunctionExpression sqlFunctionExpression)
=> base.Equals(sqlFunctionExpression)
&& IsNiladic == sqlFunctionExpression.IsNiladic
&& Name == sqlFunctionExpression.Name
&& Schema == sqlFunctionExpression.Schema
&& ((Instance == null && sqlFunctionExpression.Instance == null)
Expand Down
5 changes: 2 additions & 3 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ protected virtual SqlExpression VisitSqlFunction(
return sqlFunctionExpression.Update(instance, sqlFunctionExpression.Arguments);
}

var arguments = new SqlExpression[sqlFunctionExpression.Arguments!.Count];
var arguments = new SqlExpression[sqlFunctionExpression.Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
arguments[i] = Visit(sqlFunctionExpression.Arguments[i], out _);
Expand Down Expand Up @@ -1687,8 +1687,7 @@ private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression,
nullabilityPropagationElements.Add(sqlFunctionExpression.Instance);
}

if (sqlFunctionExpression.Arguments != null
&& sqlFunctionExpression.ArgumentsPropagateNullability != null)
if (!sqlFunctionExpression.IsNiladic)
{
for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction
SqlExpression[]? arguments = default;
if (!sqlFunctionExpression.IsNiladic)
{
arguments = new SqlExpression[sqlFunctionExpression.Arguments!.Count];
arguments = new SqlExpression[sqlFunctionExpression.Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
arguments[i] = (SqlExpression)Visit(sqlFunctionExpression.Arguments[i]);
Expand Down

0 comments on commit e260690

Please sign in to comment.