diff --git a/src/EFCore.SqlServer/Extensions/SqlServerDbFunctionsExtensions.cs b/src/EFCore.SqlServer/Extensions/SqlServerDbFunctionsExtensions.cs index 794ee92f4c9..30ab9fb03ee 100644 --- a/src/EFCore.SqlServer/Extensions/SqlServerDbFunctionsExtensions.cs +++ b/src/EFCore.SqlServer/Extensions/SqlServerDbFunctionsExtensions.cs @@ -966,5 +966,104 @@ public static TimeSpan TimeFromParts( int fractions, int precision) => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(TimeFromParts))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + [CanBeNull] string arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + bool? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + double? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + decimal? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + DateTime? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + TimeSpan? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + DateTimeOffset? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + [CanBeNull] byte[] arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); + + /// + /// Returns the number of bytes used to represent any expression. + /// + /// The DbFunctions instance. + /// The value to be examined for data length. + /// The number of bytes in the input value. + public static int? DataLength( + [CanBeNull] this DbFunctions _, + Guid? arg) + => throw new InvalidOperationException(CoreStrings.FunctionOnClient(nameof(DataLength))); } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerDataLengthFunctionTranslator.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerDataLengthFunctionTranslator.cs new file mode 100644 index 00000000000..7915666a7e5 --- /dev/null +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerDataLengthFunctionTranslator.cs @@ -0,0 +1,102 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal +{ + public class SqlServerDataLengthFunctionTranslator : IMethodCallTranslator + { + private static readonly List _longReturningTypes = new List { "nvarchar(max)", "varchar(max)", "varbinary(max)" }; + + private static readonly HashSet _methodInfoDataLengthMapping + = new HashSet + { + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(string) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(bool?) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(double?) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(decimal?) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(DateTime?) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(TimeSpan?) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(DateTimeOffset?) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(byte[]) }), + + typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( + nameof(SqlServerDbFunctionsExtensions.DataLength), + new[] { typeof(DbFunctions), typeof(Guid?) }) + }; + + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + public SqlServerDataLengthFunctionTranslator([NotNull] ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList arguments) + { + Check.NotNull(method, nameof(method)); + Check.NotNull(arguments, nameof(arguments)); + + if (_methodInfoDataLengthMapping.Contains(method)) + { + var argument = arguments[1]; + if (argument.TypeMapping == null) + { + argument = _sqlExpressionFactory.ApplyDefaultTypeMapping(argument); + } + + if (_longReturningTypes.Contains(argument.TypeMapping.StoreType)) + { + var result = _sqlExpressionFactory.Function( + "DATALENGTH", + arguments.Skip(1), + nullable: true, + argumentsPropagateNullability: new[] { true }, + typeof(long)); + + return _sqlExpressionFactory.Convert(result, method.ReturnType); + } + + return _sqlExpressionFactory.Function( + "DATALENGTH", + arguments.Skip(1), + nullable: true, + argumentsPropagateNullability: new[] { true }, + method.ReturnType); + } + + return null; + } + } +} diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerMethodCallTranslatorProvider.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerMethodCallTranslatorProvider.cs index 502db911571..332d03b6125 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerMethodCallTranslatorProvider.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerMethodCallTranslatorProvider.cs @@ -18,6 +18,7 @@ public SqlServerMethodCallTranslatorProvider([NotNull] RelationalMethodCallTrans { new SqlServerByteArrayMethodTranslator(sqlExpressionFactory), new SqlServerConvertTranslator(sqlExpressionFactory), + new SqlServerDataLengthFunctionTranslator(sqlExpressionFactory), new SqlServerDateDiffFunctionsTranslator(sqlExpressionFactory), new SqlServerDateTimeMethodTranslator(sqlExpressionFactory), new SqlServerFromPartsFunctionTranslator(sqlExpressionFactory, typeMappingSource), @@ -26,7 +27,7 @@ public SqlServerMethodCallTranslatorProvider([NotNull] RelationalMethodCallTrans new SqlServerMathTranslator(sqlExpressionFactory), new SqlServerNewGuidTranslator(sqlExpressionFactory), new SqlServerObjectToStringTranslator(sqlExpressionFactory), - new SqlServerStringMethodTranslator(sqlExpressionFactory), + new SqlServerStringMethodTranslator(sqlExpressionFactory) }); } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs index d30fafa4a1a..a7cb439b0d3 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs @@ -2,7 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Linq; using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.TestModels.GearsOfWarModel; +using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -6976,6 +6979,20 @@ FROM [Gears] AS [g] ORDER BY [g].[Nickname]"); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task DataLength_function_for_string_parameter(bool async) + { + await AssertQueryScalar( + async, + ss => ss.Set().Select(m => EF.Functions.DataLength(m.CodeName)), + ss => ss.Set().Select(m => (int?)(m.CodeName.Length * 2))); + + AssertSql( + @"SELECT CAST(DATALENGTH([m].[CodeName]) AS int) +FROM [Missions] AS [m]"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs index 0a4fd174a8c..002bdb5c910 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs @@ -967,6 +967,60 @@ FROM [Orders] AS [o] } } + [ConditionalFact] + public virtual void DataLength_column_compare() + { + using (var context = CreateContext()) + { + var count = context.Orders + .Count(c => c.OrderID < EF.Functions.DataLength(c.OrderDate)); + + Assert.Equal(0, count); + + AssertSql( + @"SELECT COUNT(*) +FROM [Orders] AS [o] +WHERE [o].[OrderID] < DATALENGTH([o].[OrderDate])"); + } + } + + [ConditionalFact] + public virtual void DataLength_constant_compare() + { + using (var context = CreateContext()) + { + var count = context.Orders + .Count(c => 100 < EF.Functions.DataLength(c.OrderDate)); + + Assert.Equal(0, count); + + AssertSql( + @"SELECT COUNT(*) +FROM [Orders] AS [o] +WHERE 100 < DATALENGTH([o].[OrderDate])"); + } + } + + [ConditionalFact] + public virtual void DataLength_compare_with_local_variable() + { + int? lenght = 100; + using (var context = CreateContext()) + { + var count = context.Orders + .Count(c => lenght < EF.Functions.DataLength(c.OrderDate)); + + Assert.Equal(0, count); + + AssertSql( + @$"@__lenght_0='100' (Nullable = true) + +SELECT COUNT(*) +FROM [Orders] AS [o] +WHERE @__lenght_0 < DATALENGTH([o].[OrderDate])"); + } + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); }