From baf46cac80219c90be994c7553334e47baf3c137 Mon Sep 17 00:00:00 2001 From: Philipp Zech Date: Sat, 22 Feb 2020 23:07:11 +0100 Subject: [PATCH 1/5] WIP: SQLite: Enable ef_compare for decimals --- .../SqliteSqlTranslatingExpressionVisitor.cs | 72 +++++++++++++-- .../Internal/SqliteRelationalConnection.cs | 88 +++++++++++++++++-- 2 files changed, 149 insertions(+), 11 deletions(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 71134edf2e7..b3b72d751c2 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -59,10 +59,7 @@ private static readonly IReadOnlyDictionary - { - typeof(ulong) - }, + [ExpressionType.Modulo] = new HashSet { typeof(ulong) }, [ExpressionType.Multiply] = new HashSet { typeof(decimal), @@ -98,7 +95,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) Check.NotNull(unaryExpression, nameof(unaryExpression)); if (unaryExpression.NodeType == ExpressionType.ArrayLength - && unaryExpression.Operand.Type == typeof(byte[])) + && unaryExpression.Operand.Type == typeof(byte[])) { return base.Visit(unaryExpression.Operand) is SqlExpression sqlExpression ? SqlExpressionFactory.Function( @@ -156,6 +153,71 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) visitedExpression.TypeMapping); } + if (sqlBinary.OperatorType == ExpressionType.GreaterThan + && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) + || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) + { + return SqlExpressionFactory.Function( + "ef_compare_gt", + new[] { sqlBinary.Left, sqlBinary.Right }, + nullable: true, + argumentsPropagateNullability: new[] { true, true }, + visitedExpression.Type, + visitedExpression.TypeMapping); + } + + if (sqlBinary.OperatorType == ExpressionType.GreaterThanOrEqual + && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) + || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) + { + return SqlExpressionFactory.Function( + "ef_compare_geq", + new[] { sqlBinary.Left, sqlBinary.Right }, + nullable: true, + argumentsPropagateNullability: new[] { true, true }, + visitedExpression.Type, + visitedExpression.TypeMapping); + } + + if (sqlBinary.OperatorType == ExpressionType.LessThan + && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) + || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) + { + return SqlExpressionFactory.Function( + "ef_compare_lt", + new[] { sqlBinary.Left, sqlBinary.Right }, + nullable: true, + argumentsPropagateNullability: new[] { true, true }, + visitedExpression.Type, + visitedExpression.TypeMapping); + } + + if (sqlBinary.OperatorType == ExpressionType.LessThanOrEqual + && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) + || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) + { + return SqlExpressionFactory.Function( + "ef_compare_leq", + new[] { sqlBinary.Left, sqlBinary.Right }, + nullable: true, + argumentsPropagateNullability: new[] { true, true }, + visitedExpression.Type, + visitedExpression.TypeMapping); + } + + if (sqlBinary.OperatorType == ExpressionType.Equal + && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) + || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) + { + return SqlExpressionFactory.Function( + "ef_compare_eq", + new[] { sqlBinary.Left, sqlBinary.Right }, + nullable: true, + argumentsPropagateNullability: new[] { true, true }, + visitedExpression.Type, + visitedExpression.TypeMapping); + } + if (_restrictedBinaryExpressions.TryGetValue(sqlBinary.OperatorType, out var restrictedTypes) && (restrictedTypes.Contains(GetProviderType(sqlBinary.Left)) || restrictedTypes.Contains(GetProviderType(sqlBinary.Right)))) diff --git a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs index b8869226c0a..4b117ad63ff 100644 --- a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs +++ b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs @@ -9,6 +9,7 @@ using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Sqlite.Infrastructure.Internal; using Microsoft.EntityFrameworkCore.Sqlite.Internal; using Microsoft.EntityFrameworkCore.Storage; @@ -92,7 +93,8 @@ protected override DbConnection CreateDbConnection() /// public virtual ISqliteRelationalConnection CreateReadOnlyConnection() { - var connectionStringBuilder = new SqliteConnectionStringBuilder(GetCheckedConnectionString()) { Mode = SqliteOpenMode.ReadOnly }; + var connectionStringBuilder = + new SqliteConnectionStringBuilder(GetCheckedConnectionString()) { Mode = SqliteOpenMode.ReadOnly }; var contextOptions = new DbContextOptionsBuilder().UseSqlite(connectionStringBuilder.ToString()).Options; @@ -117,24 +119,98 @@ private void InitializeDbConnection(DbConnection connection) "ef_mod", (dividend, divisor) => { - if (dividend == null || divisor == null) + if (dividend == null + || divisor == null) { return null; } + if (dividend is string s) { - return decimal.Parse(s, CultureInfo.InvariantCulture) % - Convert.ToDecimal(divisor, CultureInfo.InvariantCulture); + return decimal.Parse(s, CultureInfo.InvariantCulture) + % Convert.ToDecimal(divisor, CultureInfo.InvariantCulture); } - return Convert.ToDouble(dividend, CultureInfo.InvariantCulture) % - Convert.ToDouble(divisor, CultureInfo.InvariantCulture); + return Convert.ToDouble(dividend, CultureInfo.InvariantCulture) + % Convert.ToDouble(divisor, CultureInfo.InvariantCulture); }); + + CreateEfCompareFunctions(sqliteConnection); } else { _logger.UnexpectedConnectionTypeWarning(connection.GetType()); } } + + private void CreateEfCompareFunctions(SqliteConnection sqliteConnection) + { + var functions = new[] + { + ("ef_compare_gt", Comparer.Operator.GreaterThan), + ("ef_compare_geq", Comparer.Operator.GreaterThanOrEqual), + ("ef_compare_lt", Comparer.Operator.LessThan), + ("ef_compare_leq", Comparer.Operator.LessThanOrEqual), + ("ef_compare_eq", Comparer.Operator.Equal) + }; + + foreach (var function in functions) + { + sqliteConnection.CreateFunction( + function.Item1, + (left, right) => + { + if (left == null + || right == null) + { + return null; + } + + var leftSide = left is string leftAsString + ? decimal.Parse(leftAsString, CultureInfo.CurrentCulture) + : Convert.ToDecimal(left, CultureInfo.CurrentCulture); + var rightSide = right is string rightAsString + ? decimal.Parse(rightAsString, CultureInfo.CurrentCulture) + : Convert.ToDecimal(right, CultureInfo.CurrentCulture); + + return Comparer.IsTrue( + Convert.ToDecimal(leftSide, CultureInfo.CurrentCulture), function.Item2, + Convert.ToDecimal(rightSide, CultureInfo.CurrentCulture)); + }); + } + } + + internal static class Comparer + { + public static bool IsTrue(T value1, Operator comparisonOperator, U value2) + where T : U + where U : IComparable + { + switch (comparisonOperator) + { + case Operator.GreaterThan: + return value1.CompareTo(value2) > 0; + case Operator.GreaterThanOrEqual: + return value1.CompareTo(value2) >= 0; + case Operator.LessThan: + return value1.CompareTo(value2) < 0; + case Operator.LessThanOrEqual: + return value1.CompareTo(value2) <= 0; + case Operator.Equal: + return value1.CompareTo(value2) == 0; + default: + return false; + } + } + + internal enum Operator + { + GreaterThan = 1, + GreaterThanOrEqual = 2, + LessThan = 3, + LessThanOrEqual = 4, + Equal = 5 + } + } } } From c6a2a3d3ab580cd19183454923c502d195b250e3 Mon Sep 17 00:00:00 2001 From: Philipp Zech Date: Sat, 29 Feb 2020 15:59:28 +0100 Subject: [PATCH 2/5] WIP: Refactoring Code compiles but tests do not yet pass. --- .../SqliteSqlTranslatingExpressionVisitor.cs | 95 ++++++------------- .../Internal/SqliteRelationalConnection.cs | 77 ++------------- 2 files changed, 35 insertions(+), 137 deletions(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index b3b72d751c2..6c5ed8f7f1c 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -34,28 +34,24 @@ private static readonly IReadOnlyDictionary { typeof(DateTimeOffset), - typeof(decimal), typeof(TimeSpan), typeof(ulong) }, [ExpressionType.GreaterThanOrEqual] = new HashSet { typeof(DateTimeOffset), - typeof(decimal), typeof(TimeSpan), typeof(ulong) }, [ExpressionType.LessThan] = new HashSet { typeof(DateTimeOffset), - typeof(decimal), typeof(TimeSpan), typeof(ulong) }, [ExpressionType.LessThanOrEqual] = new HashSet { typeof(DateTimeOffset), - typeof(decimal), typeof(TimeSpan), typeof(ulong) }, @@ -153,69 +149,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) visitedExpression.TypeMapping); } - if (sqlBinary.OperatorType == ExpressionType.GreaterThan - && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) - || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) - { - return SqlExpressionFactory.Function( - "ef_compare_gt", - new[] { sqlBinary.Left, sqlBinary.Right }, - nullable: true, - argumentsPropagateNullability: new[] { true, true }, - visitedExpression.Type, - visitedExpression.TypeMapping); - } - - if (sqlBinary.OperatorType == ExpressionType.GreaterThanOrEqual - && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) - || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) - { - return SqlExpressionFactory.Function( - "ef_compare_geq", - new[] { sqlBinary.Left, sqlBinary.Right }, - nullable: true, - argumentsPropagateNullability: new[] { true, true }, - visitedExpression.Type, - visitedExpression.TypeMapping); - } - - if (sqlBinary.OperatorType == ExpressionType.LessThan - && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) - || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) - { - return SqlExpressionFactory.Function( - "ef_compare_lt", - new[] { sqlBinary.Left, sqlBinary.Right }, - nullable: true, - argumentsPropagateNullability: new[] { true, true }, - visitedExpression.Type, - visitedExpression.TypeMapping); - } - - if (sqlBinary.OperatorType == ExpressionType.LessThanOrEqual - && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) - || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) - { - return SqlExpressionFactory.Function( - "ef_compare_leq", - new[] { sqlBinary.Left, sqlBinary.Right }, - nullable: true, - argumentsPropagateNullability: new[] { true, true }, - visitedExpression.Type, - visitedExpression.TypeMapping); - } - - if (sqlBinary.OperatorType == ExpressionType.Equal - && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) - || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) + if (AttemptDecimalCompare(sqlBinary)) { - return SqlExpressionFactory.Function( - "ef_compare_eq", - new[] { sqlBinary.Left, sqlBinary.Right }, - nullable: true, - argumentsPropagateNullability: new[] { true, true }, - visitedExpression.Type, - visitedExpression.TypeMapping); + return DoDecimalCompare(visitedExpression, sqlBinary.OperatorType, sqlBinary.Left, sqlBinary.Right); } if (_restrictedBinaryExpressions.TryGetValue(sqlBinary.OperatorType, out var restrictedTypes) @@ -295,5 +231,32 @@ private static Type GetProviderType(SqlExpression expression) : (expression.TypeMapping?.Converter?.ProviderClrType ?? expression.TypeMapping?.ClrType ?? expression.Type).UnwrapNullableType(); + + private static bool AttemptDecimalCompare(SqlBinaryExpression sqlBinary) + { + return GetProviderType(sqlBinary.Left) == typeof(decimal) + && GetProviderType(sqlBinary.Right) == typeof(decimal); + } + + private Expression DoDecimalCompare(SqlExpression visitedExpression, ExpressionType op, SqlExpression left, SqlExpression right) + { + var actual = SqlExpressionFactory.Function( + name: "ef_compare", + arguments: new[] { left, right }, + nullable: true, + argumentsPropagateNullability: new[] { true, true }, + visitedExpression.Type, + visitedExpression.TypeMapping); + var oracle = SqlExpressionFactory.Constant(0); + + return op switch + { + ExpressionType.GreaterThan => SqlExpressionFactory.GreaterThan(left: actual, right: oracle), + ExpressionType.GreaterThanOrEqual => SqlExpressionFactory.GreaterThanOrEqual(left: actual, right: oracle), + ExpressionType.LessThan => SqlExpressionFactory.LessThan(left: actual, right: oracle), + ExpressionType.LessThanOrEqual => SqlExpressionFactory.LessThanOrEqual(left: actual, right: oracle), + _ => visitedExpression + }; + } } } diff --git a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs index 4b117ad63ff..2a29378fc44 100644 --- a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs +++ b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs @@ -135,82 +135,17 @@ private void InitializeDbConnection(DbConnection connection) % Convert.ToDouble(divisor, CultureInfo.InvariantCulture); }); - CreateEfCompareFunctions(sqliteConnection); + sqliteConnection.CreateFunction( + "ef_compare", + (decimal? left, decimal? right) => left.HasValue && right.HasValue + ? decimal.Compare(left.Value, right.Value) + : default(int?), + isDeterministic: true); } else { _logger.UnexpectedConnectionTypeWarning(connection.GetType()); } } - - private void CreateEfCompareFunctions(SqliteConnection sqliteConnection) - { - var functions = new[] - { - ("ef_compare_gt", Comparer.Operator.GreaterThan), - ("ef_compare_geq", Comparer.Operator.GreaterThanOrEqual), - ("ef_compare_lt", Comparer.Operator.LessThan), - ("ef_compare_leq", Comparer.Operator.LessThanOrEqual), - ("ef_compare_eq", Comparer.Operator.Equal) - }; - - foreach (var function in functions) - { - sqliteConnection.CreateFunction( - function.Item1, - (left, right) => - { - if (left == null - || right == null) - { - return null; - } - - var leftSide = left is string leftAsString - ? decimal.Parse(leftAsString, CultureInfo.CurrentCulture) - : Convert.ToDecimal(left, CultureInfo.CurrentCulture); - var rightSide = right is string rightAsString - ? decimal.Parse(rightAsString, CultureInfo.CurrentCulture) - : Convert.ToDecimal(right, CultureInfo.CurrentCulture); - - return Comparer.IsTrue( - Convert.ToDecimal(leftSide, CultureInfo.CurrentCulture), function.Item2, - Convert.ToDecimal(rightSide, CultureInfo.CurrentCulture)); - }); - } - } - - internal static class Comparer - { - public static bool IsTrue(T value1, Operator comparisonOperator, U value2) - where T : U - where U : IComparable - { - switch (comparisonOperator) - { - case Operator.GreaterThan: - return value1.CompareTo(value2) > 0; - case Operator.GreaterThanOrEqual: - return value1.CompareTo(value2) >= 0; - case Operator.LessThan: - return value1.CompareTo(value2) < 0; - case Operator.LessThanOrEqual: - return value1.CompareTo(value2) <= 0; - case Operator.Equal: - return value1.CompareTo(value2) == 0; - default: - return false; - } - } - - internal enum Operator - { - GreaterThan = 1, - GreaterThanOrEqual = 2, - LessThan = 3, - LessThanOrEqual = 4, - Equal = 5 - } - } } } From b34afaffffcf1eedfc2c5246f851135ba132a833 Mon Sep 17 00:00:00 2001 From: Philipp Zech Date: Mon, 2 Mar 2020 19:53:55 +0100 Subject: [PATCH 3/5] fix: return types --- .../SqliteSqlTranslatingExpressionVisitor.cs | 17 +++++++---------- .../Internal/SqliteRelationalConnection.cs | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 6c5ed8f7f1c..bfca60be5c9 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -232,22 +232,19 @@ private static Type GetProviderType(SqlExpression expression) ?? expression.TypeMapping?.ClrType ?? expression.Type).UnwrapNullableType(); - private static bool AttemptDecimalCompare(SqlBinaryExpression sqlBinary) - { - return GetProviderType(sqlBinary.Left) == typeof(decimal) - && GetProviderType(sqlBinary.Right) == typeof(decimal); - } + private static bool AttemptDecimalCompare(SqlBinaryExpression sqlBinary) => + GetProviderType(sqlBinary.Left) == typeof(decimal) + && GetProviderType(sqlBinary.Right) == typeof(decimal); private Expression DoDecimalCompare(SqlExpression visitedExpression, ExpressionType op, SqlExpression left, SqlExpression right) { var actual = SqlExpressionFactory.Function( name: "ef_compare", - arguments: new[] { left, right }, + new[] { left, right }, nullable: true, - argumentsPropagateNullability: new[] { true, true }, - visitedExpression.Type, - visitedExpression.TypeMapping); - var oracle = SqlExpressionFactory.Constant(0); + new[] { true, true }, + typeof(int)); + var oracle = SqlExpressionFactory.Constant(value: 0); return op switch { diff --git a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs index 2a29378fc44..d54abd97374 100644 --- a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs +++ b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs @@ -136,7 +136,7 @@ private void InitializeDbConnection(DbConnection connection) }); sqliteConnection.CreateFunction( - "ef_compare", + name: "ef_compare", (decimal? left, decimal? right) => left.HasValue && right.HasValue ? decimal.Compare(left.Value, right.Value) : default(int?), From 849a225853c73da7b8d2d15a3e6b6bdc36d7d706 Mon Sep 17 00:00:00 2001 From: Philipp Zech Date: Tue, 3 Mar 2020 16:53:48 +0100 Subject: [PATCH 4/5] fix: add ExpressionType to check for decimal compare --- .../Internal/SqliteSqlTranslatingExpressionVisitor.cs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index bfca60be5c9..d550fe9e010 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Reflection.Emit; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; @@ -234,7 +235,11 @@ private static Type GetProviderType(SqlExpression expression) private static bool AttemptDecimalCompare(SqlBinaryExpression sqlBinary) => GetProviderType(sqlBinary.Left) == typeof(decimal) - && GetProviderType(sqlBinary.Right) == typeof(decimal); + && GetProviderType(sqlBinary.Right) == typeof(decimal) + && new[] + { + ExpressionType.GreaterThan, ExpressionType.GreaterThanOrEqual, ExpressionType.LessThan, ExpressionType.LessThanOrEqual + }.Contains(sqlBinary.OperatorType); private Expression DoDecimalCompare(SqlExpression visitedExpression, ExpressionType op, SqlExpression left, SqlExpression right) { From 4331176416ea7cf955733f040da2795fcc74b05d Mon Sep 17 00:00:00 2001 From: Philipp Zech Date: Tue, 3 Mar 2020 16:54:56 +0100 Subject: [PATCH 5/5] fix: remove unused import --- .../Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index d550fe9e010..4a1a253afc8 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Reflection.Emit; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query;