Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#10582 - Translate SequenceEqual for Sqlite and SqlServer byte arrays #19594

Merged
merged 11 commits into from
Jan 28, 2020
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class ByteArraySequenceEqualTranslator: IMethodCallTranslator
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public ByteArraySequenceEqualTranslator([NotNull] ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}

public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.SequenceEqual)
&& arguments[0].Type == typeof(byte[]))
{
return _sqlExpressionFactory.Equal(arguments[0], arguments[1]);
}

return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public RelationalMethodCallTranslatorProvider([NotNull] RelationalMethodCallTran
new LikeTranslator(sqlExpressionFactory),
new EnumHasFlagTranslator(sqlExpressionFactory),
new GetValueOrDefaultTranslator(sqlExpressionFactory),
new ComparisonTranslator(sqlExpressionFactory)
new ComparisonTranslator(sqlExpressionFactory),
new ByteArraySequenceEqualTranslator(sqlExpressionFactory)
});
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand Down
3 changes: 3 additions & 0 deletions src/Shared/EnumerableMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ internal static class EnumerableMethods
public static MethodInfo AnyWithoutPredicate { get; }
public static MethodInfo AnyWithPredicate { get; }
public static MethodInfo Contains { get; }
public static MethodInfo SequenceEqual { get; }

public static MethodInfo ToList { get; }
public static MethodInfo ToArray { get; }
Expand Down Expand Up @@ -151,6 +152,8 @@ static EnumerableMethods()
&& IsFunc(mi.GetParameters()[1].ParameterType));
Contains = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.Contains) && mi.GetParameters().Length == 2);
SequenceEqual = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.SequenceEqual) && mi.GetParameters().Length == 2);

ToList = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.ToList) && mi.GetParameters().Length == 1);
Expand Down
11 changes: 11 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7548,6 +7548,17 @@ public virtual Task Projecting_required_string_column_compared_to_null_parameter
ss => ss.Set<Gear>().Select(g => g.Nickname == nullParameter));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_filter_by_SequenceEqual(bool async)
{
var byteArrayParam = new byte[] { 0x04, 0x05, 0x06, 0x07, 0x08 };

return AssertQuery(
async,
ss => ss.Set<Squad>().Where(s => s.Banner5.SequenceEqual(byteArrayParam)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_nullable_property_HasValue_and_project_the_grouping_key(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7527,6 +7527,17 @@ FROM [Gears] AS [g]
WHERE [g].[Discriminator] IN (N'Gear', N'Officer')");
}

public override async Task Byte_array_filter_by_SequenceEqual(bool isAsync)
{
await base.Byte_array_filter_by_SequenceEqual(isAsync);

AssertSql(@"@__byteArrayParam_0='0x0405060708' (Size = 5)

SELECT [s].[Id], [s].[Banner], [s].[Banner5], [s].[InternalNumber], [s].[Name]
FROM [Squads] AS [s]
WHERE [s].[Banner5] = @__byteArrayParam_0");
}

public override async Task Group_by_nullable_property_HasValue_and_project_the_grouping_key(bool async)
{
await base.Group_by_nullable_property_HasValue_and_project_the_grouping_key(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ SELECT COUNT(*)
WHERE length(""s"".""Banner"") = length(@__byteArrayParam)");
}

public override async Task Byte_array_filter_by_SequenceEqual(bool async)
{
await base.Byte_array_filter_by_SequenceEqual(async);

AssertSql(@"@__byteArrayParam_0='0x0405060708' (Size = 5) (DbType = String)

SELECT ""s"".""Id"", ""s"".""Banner"", ""s"".""Banner5"", ""s"".""InternalNumber"", ""s"".""Name""
FROM ""Squads"" AS ""s""
WHERE ""s"".""Banner5"" = @__byteArrayParam_0");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down