Skip to content

Commit

Permalink
Reimplement in VisitBinary
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jun 12, 2020
1 parent f1357da commit ca2e347
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Storage;
Expand Down Expand Up @@ -161,6 +162,25 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

// Remove convert-to-object nodes if both sides have them, or if the other side is null constant
if (TryUnwrapConvertToObject(left, out var leftOperand))
{
if (TryUnwrapConvertToObject(right, out var rightOperand))
{
left = leftOperand;
right = rightOperand;
}
else if (right.IsNullConstantExpression())
{
left = leftOperand;
}
}
else if (TryUnwrapConvertToObject(right, out var rightOperand)
&& left.IsNullConstantExpression())
{
right = rightOperand;
}

var visitedLeft = Visit(left);
var visitedRight = Visit(right);

Expand Down Expand Up @@ -193,6 +213,50 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
sqlLeft,
sqlRight,
null);

static Expression TryRemoveImplicitConvert(Expression expression)
{
if (expression is UnaryExpression unaryExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked))
{
var innerType = unaryExpression.Operand.Type.UnwrapNullableType();
if (innerType.IsEnum)
{
innerType = Enum.GetUnderlyingType(innerType);
}

var convertedType = unaryExpression.Type.UnwrapNullableType();

if (innerType == convertedType
|| (convertedType == typeof(int)
&& (innerType == typeof(byte)
|| innerType == typeof(sbyte)
|| innerType == typeof(char)
|| innerType == typeof(short)
|| innerType == typeof(ushort))))
{
return TryRemoveImplicitConvert(unaryExpression.Operand);
}
}

return expression;
}

static bool TryUnwrapConvertToObject(Expression expression, out Expression operand)
{
if (expression is UnaryExpression unaryExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked)
&& expression.Type == typeof(object))
{
operand = unaryExpression.Operand;
return true;
}

operand = null;
return false;
}
}

/// <summary>
Expand Down Expand Up @@ -593,35 +657,6 @@ ObjectArrayProjectionExpression objectArrayProjectionExpression
};
}

private static Expression TryRemoveImplicitConvert(Expression expression)
{
if (expression is UnaryExpression unaryExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked))
{
var innerType = unaryExpression.Operand.Type.UnwrapNullableType();
if (innerType.IsEnum)
{
innerType = Enum.GetUnderlyingType(innerType);
}

var convertedType = unaryExpression.Type.UnwrapNullableType();

if (innerType == convertedType
|| (convertedType == typeof(int)
&& (innerType == typeof(byte)
|| innerType == typeof(sbyte)
|| innerType == typeof(char)
|| innerType == typeof(short)
|| innerType == typeof(ushort))))
{
return TryRemoveImplicitConvert(unaryExpression.Operand);
}
}

return expression;
}

private bool TryRewriteContainsEntity(Expression source, Expression item, out Expression result)
{
result = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Utilities;
Expand Down Expand Up @@ -341,6 +342,25 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
var left = TryRemoveImplicitConvert(binaryExpression.Left);
var right = TryRemoveImplicitConvert(binaryExpression.Right);

// Remove convert-to-object nodes if both sides have them, or if the other side is null constant
if (TryUnwrapConvertToObject(left, out var leftOperand))
{
if (TryUnwrapConvertToObject(right, out var rightOperand))
{
left = leftOperand;
right = rightOperand;
}
else if (right.IsNullConstantExpression())
{
left = leftOperand;
}
}
else if (TryUnwrapConvertToObject(right, out var rightOperand)
&& left.IsNullConstantExpression())
{
right = rightOperand;
}

var visitedLeft = Visit(left);
var visitedRight = Visit(right);

Expand Down Expand Up @@ -370,6 +390,50 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
sqlLeft,
sqlRight,
null);

static Expression TryRemoveImplicitConvert(Expression expression)
{
if (expression is UnaryExpression unaryExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked))
{
var innerType = unaryExpression.Operand.Type.UnwrapNullableType();
if (innerType.IsEnum)
{
innerType = Enum.GetUnderlyingType(innerType);
}

var convertedType = expression.Type.UnwrapNullableType();

if (innerType == convertedType
|| (convertedType == typeof(int)
&& (innerType == typeof(byte)
|| innerType == typeof(sbyte)
|| innerType == typeof(char)
|| innerType == typeof(short)
|| innerType == typeof(ushort))))
{
return TryRemoveImplicitConvert(unaryExpression.Operand);
}
}

return expression;
}

static bool TryUnwrapConvertToObject(Expression expression, out Expression operand)
{
if (expression is UnaryExpression unaryExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked)
&& expression.Type == typeof(object))
{
operand = unaryExpression.Operand;
return true;
}

operand = null;
return false;
}
}

/// <inheritdoc />
Expand Down Expand Up @@ -884,37 +948,6 @@ private static Expression GetPredicateOnGrouping(
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

private static Expression TryRemoveImplicitConvert(Expression expression)
{
if (expression is UnaryExpression unaryExpression)
{
if (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked)
{
var innerType = unaryExpression.Operand.Type.UnwrapNullableType();
if (innerType.IsEnum)
{
innerType = Enum.GetUnderlyingType(innerType);
}

var convertedType = unaryExpression.Type.UnwrapNullableType();

if (innerType == convertedType
|| (convertedType == typeof(int)
&& (innerType == typeof(byte)
|| innerType == typeof(sbyte)
|| innerType == typeof(char)
|| innerType == typeof(short)
|| innerType == typeof(ushort))))
{
return TryRemoveImplicitConvert(unaryExpression.Operand);
}
}
}

return expression;
}

private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression)
{
var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions;
Expand Down
19 changes: 3 additions & 16 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,9 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
{
inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right)
// We avoid object here since the result does not get typeMapping from outside.
?? (left.Type != typeof(object) ? _typeMappingSource.FindMapping(left.Type) : null)
?? (right.Type != typeof(object) ? _typeMappingSource.FindMapping(right.Type) : null)
// If we still haven't found anything, unwrap convert to object as a last resort
?? (UnwrapConvertToObject(left) is SqlExpression leftOperand
? _typeMappingSource.FindMapping(leftOperand.Type)
: UnwrapConvertToObject(right) is SqlExpression rightOperand
? _typeMappingSource.FindMapping(rightOperand.Type)
: null);

?? (left.Type != typeof(object)
? _typeMappingSource.FindMapping(left.Type)
: _typeMappingSource.FindMapping(right.Type));
resultType = typeof(bool);
resultTypeMapping = _boolTypeMapping;
break;
Expand Down Expand Up @@ -220,13 +214,6 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
ApplyTypeMapping(right, inferredTypeMapping),
resultType,
resultTypeMapping);

static SqlExpression UnwrapConvertToObject(SqlExpression expression)
=> expression.Type == typeof(object)
&& expression is SqlUnaryExpression unaryExpression
&& unaryExpression.OperatorType == ExpressionType.Convert
? unaryExpression.Operand
: null;
}

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public NorthwindWhereQuerySqlServerTest(NorthwindQuerySqlServerFixture<NoopModel
: base(fixture)
{
ClearLog();
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

protected override bool CanExecuteQueryString => true;
Expand Down Expand Up @@ -1530,7 +1530,7 @@ public override async Task Where_compare_with_both_cast_to_object(bool async)
AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE CAST([c].[City] AS nvarchar(max)) = CAST(N'London' AS nvarchar(max))");
WHERE [c].[City] = N'London'");
}

public override async Task Where_Is_on_same_type(bool async)
Expand Down

0 comments on commit ca2e347

Please sign in to comment.