diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs b/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs index e6007b2344f..37e983cbd81 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs @@ -37,7 +37,7 @@ public CosmosMethodCallTranslatorProvider( new IMethodCallTranslator[] { new EqualsTranslator(sqlExpressionFactory), - //new StringMethodTranslator(sqlExpressionFactory), + new StringMethodTranslator(sqlExpressionFactory), new ContainsTranslator(sqlExpressionFactory) //new LikeTranslator(sqlExpressionFactory), //new EnumHasFlagTranslator(sqlExpressionFactory), diff --git a/src/EFCore.Cosmos/Query/Internal/StringMethodTranslator.cs b/src/EFCore.Cosmos/Query/Internal/StringMethodTranslator.cs new file mode 100644 index 00000000000..03cd6bf809e --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/StringMethodTranslator.cs @@ -0,0 +1,81 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class StringMethodTranslator : IMethodCallTranslator + { + private static readonly MethodInfo _containsMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.Contains), new[] { typeof(string) }); + + private static readonly MethodInfo _startsWithMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), new[] { typeof(string) }); + + private static readonly MethodInfo _endsWithMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) }); + + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public StringMethodTranslator([NotNull] ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList arguments) + { + Check.NotNull(method, nameof(method)); + Check.NotNull(arguments, nameof(arguments)); + + if (_containsMethodInfo.Equals(method)) + { + return TranslateSystemFunction("CONTAINS", instance, arguments[0], typeof(bool)); + } + + if (_startsWithMethodInfo.Equals(method)) + { + return TranslateSystemFunction("STARTSWITH", instance, arguments[0], typeof(bool)); + } + + if (_endsWithMethodInfo.Equals(method)) + { + return TranslateSystemFunction("ENDSWITH", instance, arguments[0], typeof(bool)); + } + + return null; + } + + private SqlExpression TranslateSystemFunction(string function, SqlExpression instance, SqlExpression pattern, Type returnType) + { + Check.NotNull(instance, nameof(instance)); + return _sqlExpressionFactory.Function( + function, + new[] { instance, pattern }, + returnType, + ExpressionExtensions.InferTypeMapping(instance, pattern)); + } + } +}