Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add structural equality and order comparison for common collections #80025

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/libraries/System.Collections/ref/System.Collections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,25 @@ protected EqualityComparer() { }
bool System.Collections.IEqualityComparer.Equals(object? x, object? y) { throw null; }
int System.Collections.IEqualityComparer.GetHashCode(object obj) { throw null; }
}

public static class EqualityComparer
{
public static System.Collections.Generic.IEqualityComparer<TEnumerable> CreateEnumerableComparer<TEnumerable, T>(System.Collections.Generic.IEqualityComparer<T>? elementComparer = null)
where TEnumerable : System.Collections.Generic.IEnumerable<T> { throw null; }

public static System.Collections.Generic.IEqualityComparer<TSet> CreateSetComparer<TSet, T>(System.Collections.Generic.IEqualityComparer<T>? elementComparer = null)
where TSet : System.Collections.Generic.IReadOnlySet<T> { throw null; }

public static System.Collections.Generic.IEqualityComparer<TDictionary> CreateDictionaryComparer<TDictionary, TKey, TValue>(System.Collections.Generic.IEqualityComparer<TKey>? keyComparer = null, System.Collections.Generic.IEqualityComparer<TValue>? valueComparer = null)
where TDictionary : System.Collections.Generic.IReadOnlyDictionary<TKey, TValue> { throw null; }
}

public static class ComparerFactory
Copy link
Contributor Author

@manandre manandre Dec 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The approved API uses the Comparer class in the System.Collections.Generic namespace, but it is already present in the System.Collections namespace
https://github.com/dotnet/runtime/blob/5b8ebeabb32f7f4118d0cc8b8db28705b62469ee/src/libraries/System.Runtime/ref/System.Runtime.cs#L7251-L7258

Issue reported in #77209 (comment)

{
public static System.Collections.Generic.IComparer<TEnumerable> CreateEnumerableComparer<TEnumerable, T>(System.Collections.Generic.IComparer<T>? elementComparer = null)
where TEnumerable : System.Collections.Generic.IEnumerable<T> { throw null; }
}

public partial class HashSet<T> : System.Collections.Generic.ICollection<T>, System.Collections.Generic.IEnumerable<T>, System.Collections.Generic.IReadOnlyCollection<T>, System.Collections.Generic.ISet<T>, System.Collections.Generic.IReadOnlySet<T>, System.Collections.IEnumerable, System.Runtime.Serialization.IDeserializationCallback, System.Runtime.Serialization.ISerializable
{
public HashSet() { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,5 +342,92 @@ public void NullableOfIComparableComparisonsShouldTryToCallLeftHandCompareToFirs
Assert.Equal(1, comparer.Compare(left, right));
Assert.Equal(0, comparer.Compare(right, left));
}

[Theory]
[MemberData(nameof(StringComparisonsData))]
public void Comparer_CreateEnumerableComparer(string left, string right, int expected)
{
var comparer = ComparerFactory.CreateEnumerableComparer<string, char>();
Assert.Equal(expected, comparer.Compare(left, right));
}

public static IEnumerable<object[]> StringComparisonsData()
{
var str = new string(""); // this needs to be cached into a local so we can pass the same ref in twice

var testCases = new[]
{
Tuple.Create(str, str, 0),
Tuple.Create(default(string), str, -1),
Tuple.Create("a", "a", 0),
Tuple.Create("", "a", -1),
Tuple.Create("a", "b", -1),
Tuple.Create("a", "ab", -1),
Tuple.Create("ab", "b", -1),
Tuple.Create("ab", "ba", -1),
Tuple.Create("ab", "a" + "b", 0)
};

foreach (var testCase in testCases)
{
yield return new object[] { testCase.Item1, testCase.Item2, testCase.Item3 };
yield return new object[] { testCase.Item2, testCase.Item1, -testCase.Item3 };
}
}

[Theory]
[MemberData(nameof(StringComparisonsData_IgnoreCase))]
public void Comparer_CreateEnumerableComparer_IgnoreCase(string left, string right, int expected)
{
var comparer = ComparerFactory.CreateEnumerableComparer<string, char>(IgnoreCaseComparer());
Assert.Equal(expected, comparer.Compare(left, right));

static Comparer<char> IgnoreCaseComparer() => Comparer<char>.Create((x, y) => char.ToUpperInvariant(x).CompareTo(char.ToUpperInvariant(y)));
}

public static IEnumerable<object[]> StringComparisonsData_IgnoreCase()
{
foreach(var data in StringComparisonsData())
{
yield return data;
}

var testCases = new[]
{
Tuple.Create("a", "A", 0),
Tuple.Create("", "A", -1),
Tuple.Create("a", "B", -1),
Tuple.Create("a", "Ab", -1),
Tuple.Create("a", "AB", -1),
Tuple.Create("ab", "b", -1),
Tuple.Create("ab", "B", -1),
Tuple.Create("ab", "BA", -1),
Tuple.Create("ab", "Ba", -1),
Tuple.Create("ab", "bA", -1),
Tuple.Create("ab", "A" + "B", 0)
};

foreach (var testCase in testCases)
{
yield return new object[] { testCase.Item1, testCase.Item2, testCase.Item3 };
yield return new object[] { testCase.Item2, testCase.Item1, -testCase.Item3 };
}
}

[Fact]
public void Comparer_CreateEnumerableComparer_EqualsGetHashCodeOverridden()
{
var comparer = ComparerFactory.CreateEnumerableComparer<string, char>();
Assert.True(comparer.Equals(comparer));
Assert.Equal(comparer.GetHashCode(), comparer.GetHashCode());

var ec1 = Comparer<char>.Create((x, y) => x.CompareTo(y));
var ec2 = Comparer<char>.Create((x, y) => x.CompareTo(y));

Assert.True(ComparerFactory.CreateEnumerableComparer<string, char>(ec1).Equals(ComparerFactory.CreateEnumerableComparer<string, char>(ec1)));
Assert.True(ComparerFactory.CreateEnumerableComparer<string, char>(ec1).GetHashCode().Equals(ComparerFactory.CreateEnumerableComparer<string, char>(ec1).GetHashCode()));

Assert.False(ComparerFactory.CreateEnumerableComparer<string, char>(ec1).Equals(ComparerFactory.CreateEnumerableComparer<string, char>(ec2)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,18 @@ public class HashData<T> : TheoryData<T, int> { }
public void EqualsTest<T>(T left, T right, bool expected)
{
var comparer = EqualityComparer<T>.Default;
IEqualityComparer nonGenericComparer = comparer;
AssertEqualsGeneric(comparer, left, right, expected);
AssertEqualsNonGeneric(comparer, left, right, expected);
}

private void AssertEqualsGeneric<T>(IEqualityComparer<T> comparer, T left, T right, bool expected)
{
Assert.Equal(expected, comparer.Equals(left, right));
Assert.Equal(expected, comparer.Equals(right, left)); // Should be commutative.

Assert.True(comparer.Equals(left, left)); // Should be reflexive.
Assert.True(comparer.Equals(right, right));

// If both sides are Ts then the explicit implementation of
// IEqualityComparer.Equals should also succeed, with the same results
Assert.Equal(expected, nonGenericComparer.Equals(left, right));
Assert.Equal(expected, nonGenericComparer.Equals(right, left));

Assert.True(nonGenericComparer.Equals(left, left));
Assert.True(nonGenericComparer.Equals(right, right));

// All comparers returned by EqualityComparer<T>.Default should be
// able to handle nulls before dispatching to IEquatable<T>.Equals()
if (default(T) == null)
Expand All @@ -70,8 +66,35 @@ public void EqualsTest<T>(T left, T right, bool expected)

Assert.Equal(right == null, comparer.Equals(right, nil));
Assert.Equal(right == null, comparer.Equals(nil, right));
}

// GetHashCode: If 2 objects are equal, then their hash code should be the same.

if (expected)
{
int hash = comparer.GetHashCode(left);

Assert.Equal(hash, comparer.GetHashCode(left)); // Should return the same result across multiple invocations
Assert.Equal(hash, comparer.GetHashCode(right));
}
}

private void AssertEqualsNonGeneric<T>(IEqualityComparer nonGenericComparer, T left, T right, bool expected)
{
// If both sides are Ts then the explicit implementation of
// IEqualityComparer.Equals should also succeed, with the same results
Assert.Equal(expected, nonGenericComparer.Equals(left, right));
Assert.Equal(expected, nonGenericComparer.Equals(right, left));

Assert.True(nonGenericComparer.Equals(left, left));
Assert.True(nonGenericComparer.Equals(right, right));

// All comparers returned by EqualityComparer<T>.Default should be
// able to handle nulls before dispatching to IEquatable<T>.Equals()
if (default(T) == null)
{
T nil = default(T);

// IEqualityComparer.Equals explicit implementation
Assert.True(nonGenericComparer.Equals(nil, nil));

Assert.Equal(left == null, nonGenericComparer.Equals(left, nil));
Expand All @@ -85,12 +108,9 @@ public void EqualsTest<T>(T left, T right, bool expected)

if (expected)
{
int hash = comparer.GetHashCode(left);

Assert.Equal(hash, comparer.GetHashCode(left)); // Should return the same result across multiple invocations
Assert.Equal(hash, comparer.GetHashCode(right));
int hash = nonGenericComparer.GetHashCode(left);

Assert.Equal(hash, nonGenericComparer.GetHashCode(left));
Assert.Equal(hash, nonGenericComparer.GetHashCode(left)); // Should return the same result across multiple invocations
Assert.Equal(hash, nonGenericComparer.GetHashCode(right));
}
}
Expand Down Expand Up @@ -206,6 +226,22 @@ public static EqualsData<string> StringData()
};
}

public static EqualsData<string> StringData_IgnoreCase()
{
return new EqualsData<string>
{
{ "foo", "foo", true },
{ string.Empty, null, false },
{ "bar", new string("bar".ToCharArray()), true },
{ "foo", "bar", false },
{ "foo", "Foo", true },
{ "foo", "FOO", true },
{ "foo", "Bar", false },
{ "foo", "BAR", false },
{ "bar", new string("BAR".ToCharArray()), true }
};
}

public static EqualsData<Equatable> IEquatableData()
{
var one = new Equatable(1);
Expand Down Expand Up @@ -548,5 +584,136 @@ public void EqualityComparerCreate_EqualsGetHashCodeOverridden()
Assert.False(EqualityComparer<int>.Create(equals1, getHashCode1).Equals(EqualityComparer<int>.Create(equals1, getHashCode2)));
Assert.False(EqualityComparer<int>.Create(equals1, getHashCode1).Equals(EqualityComparer<int>.Create(equals2, getHashCode1)));
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateEnumerableComparer_EqualsTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateEnumerableComparer<string, char>();
AssertEqualsGeneric(comparer, left, right, expected);
}

[Theory]
[MemberData(nameof(StringData_IgnoreCase))]
public void EqualityComparer_CreateEnumerableComparer_EqualsIgnoreCaseTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateEnumerableComparer<string, char>(IgnoreCaseEqualityComparer());
AssertEqualsGeneric(comparer, left, right, expected);
}

[Fact]
public void EqualityComparer_CreateEnumerableComparer_EqualsGetHashCodeOverridden()
{
var comparer = EqualityComparer.CreateEnumerableComparer<string, char>();
Assert.True(comparer.Equals(comparer));
Assert.Equal(comparer.GetHashCode(), comparer.GetHashCode());

var ec1 = EqualityComparer<char>.Create((x, y) => x == y);
var ec2 = EqualityComparer<char>.Create((x, y) => x == y);

Assert.True(EqualityComparer.CreateEnumerableComparer<string, char>(ec1).Equals(EqualityComparer.CreateEnumerableComparer<string, char>(ec1)));
Assert.True(EqualityComparer.CreateEnumerableComparer<string, char>(ec1).GetHashCode().Equals(EqualityComparer.CreateEnumerableComparer<string, char>(ec1).GetHashCode()));

Assert.False(EqualityComparer.CreateEnumerableComparer<string, char>(ec1).Equals(EqualityComparer.CreateEnumerableComparer<string, char>(ec2)));
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateSetComparer_EqualsTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateSetComparer<HashSet<char>, char>();
AssertEqualsGeneric(comparer, left?.ToHashSet(), right?.ToHashSet(), expected);
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateSetComparer_EqualsIgnoreCaseTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateSetComparer<HashSet<char>, char>(IgnoreCaseEqualityComparer());
AssertEqualsGeneric(comparer, left?.ToHashSet(), right?.ToHashSet(), expected);
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateSetComparer_SortedSet_EqualsTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateSetComparer<SortedSet<char>, char>();
AssertEqualsGeneric(comparer, ToSortedSet(left), ToSortedSet(right), expected);
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateSetComparer_SortedSet_EqualsIgnoreCaseTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateSetComparer<SortedSet<char>, char>(IgnoreCaseEqualityComparer());
AssertEqualsGeneric(comparer, ToSortedSet(left), ToSortedSet(right), expected);
}

[Fact]
public void EqualityComparer_CreateSetComparer_EqualsGetHashCodeOverridden()
{
var comparer = EqualityComparer.CreateSetComparer<HashSet<char>, char>();
Assert.True(comparer.Equals(comparer));
Assert.Equal(comparer.GetHashCode(), comparer.GetHashCode());

var ec1 = EqualityComparer<char>.Create((x, y) => x == y);
var ec2 = EqualityComparer<char>.Create((x, y) => x == y);

Assert.True(EqualityComparer.CreateSetComparer<HashSet<char>, char>(ec1).Equals(EqualityComparer.CreateSetComparer<HashSet<char>, char>(ec1)));
Assert.True(EqualityComparer.CreateSetComparer<HashSet<char>, char>(ec1).GetHashCode().Equals(EqualityComparer.CreateSetComparer<HashSet<char>, char>(ec1).GetHashCode()));

Assert.False(EqualityComparer.CreateSetComparer<HashSet<char>, char>(ec1).Equals(EqualityComparer.CreateSetComparer<HashSet<char>, char>(ec2)));
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateDictionaryComparer_EqualsTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>();
AssertEqualsGeneric(comparer, ToIndexedDictionary(left), ToIndexedDictionary(right), expected);
}

[Theory]
[MemberData(nameof(StringData))]
public void EqualityComparer_CreateDictionaryComparer_EqualsIgnoreCaseTest(string left, string right, bool expected)
{
var comparer = EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(valueComparer: IgnoreCaseEqualityComparer());
AssertEqualsGeneric(comparer, ToIndexedDictionary(left), ToIndexedDictionary(right), expected);
}

[Fact]
public void EqualityComparer_CreateDictionaryComparer_EqualsGetHashCodeOverridden()
{
var comparer = EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>();
Assert.True(comparer.Equals(comparer));
Assert.Equal(comparer.GetHashCode(), comparer.GetHashCode());

var ec11 = EqualityComparer<int>.Create((x, y) => x == y);
var ec12 = EqualityComparer<int>.Create((x, y) => x == y);
var ec21 = EqualityComparer<char>.Create((x, y) => x == y);
var ec22 = EqualityComparer<char>.Create((x, y) => x == y);

Assert.True(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11)));
Assert.True(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(null, ec21).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(null, ec21)));
Assert.True(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21)));
Assert.True(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11).GetHashCode().Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11).GetHashCode()));
Assert.True(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(null, ec21).GetHashCode().Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(null, ec21).GetHashCode()));
Assert.True(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21).GetHashCode().Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21).GetHashCode()));

Assert.False(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec12)));
Assert.False(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(null, ec21).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(null, ec22)));
Assert.False(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec22)));
Assert.False(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec12, ec21)));
Assert.False(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec11, ec21).Equals(EqualityComparer.CreateDictionaryComparer<Dictionary<int, char>, int, char>(ec12, ec22)));
}

private SortedSet<char> ToSortedSet(string left) => left is not null ? new SortedSet<char>(left) : null;

private static Dictionary<int, char> ToIndexedDictionary(string value) => value?
.Select((value, index) => (value, index))
.ToDictionary(t => t.index, t => t.value);

private static EqualityComparer<char> IgnoreCaseEqualityComparer() =>
EqualityComparer<char>.Create((x, y) => char.ToUpperInvariant(x).Equals(char.ToUpperInvariant(y)));
}
}
Loading