Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,11 @@ internal static object CreateDefaultComparer(Type type)
{
result = CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(GenericComparer<int>), runtimeType);
}
// Nullable does not implement IComparable<T?> directly because that would add an extra interface call per comparison.
// Instead, it relies on Comparer<T?>.Default to specialize for nullables and do the lifted comparisons if T implements IComparable.
else if (type.IsGenericType)
else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>))
{
if (type.GetGenericTypeDefinition() == typeof(Nullable<>))
{
result = TryCreateNullableComparer(runtimeType);
}
// Nullable does not implement IComparable<T?> directly because that would add an extra interface call per comparison.
var embeddedType = (RuntimeType)type.GetGenericArguments()[0];
result = RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableComparer<int>), embeddedType);
}
// The comparer for enums is specialized to avoid boxing.
else if (type.IsEnum)
Expand All @@ -55,25 +52,6 @@ internal static object CreateDefaultComparer(Type type)
return result ?? CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(ObjectComparer<object>), runtimeType);
}

/// <summary>
/// Creates the default <see cref="Comparer{T}"/> for a nullable type.
/// </summary>
/// <param name="nullableType">The nullable type to create the default comparer for.</param>
private static object? TryCreateNullableComparer(RuntimeType nullableType)
{
Debug.Assert(nullableType != null);
Debug.Assert(nullableType.IsGenericType && nullableType.GetGenericTypeDefinition() == typeof(Nullable<>));

var embeddedType = (RuntimeType)nullableType.GetGenericArguments()[0];

if (typeof(IComparable<>).MakeGenericType(embeddedType).IsAssignableFrom(embeddedType))
{
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableComparer<int>), embeddedType);
}

return null;
}

/// <summary>
/// Creates the default <see cref="Comparer{T}"/> for an enum type.
/// </summary>
Expand Down Expand Up @@ -135,14 +113,11 @@ internal static object CreateDefaultEqualityComparer(Type type)
// If T implements IEquatable<T> return a GenericEqualityComparer<T>
result = CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(GenericEqualityComparer<string>), runtimeType);
}
else if (type.IsGenericType)
else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>))
{
// Nullable does not implement IEquatable<T?> directly because that would add an extra interface call per comparison.
// Instead, it relies on EqualityComparer<T?>.Default to specialize for nullables and do the lifted comparisons if T implements IEquatable.
if (type.GetGenericTypeDefinition() == typeof(Nullable<>))
{
result = TryCreateNullableEqualityComparer(runtimeType);
}
var embeddedType = (RuntimeType)type.GetGenericArguments()[0];
result = CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableEqualityComparer<int>), embeddedType);
}
else if (type.IsEnum)
{
Expand All @@ -153,25 +128,6 @@ internal static object CreateDefaultEqualityComparer(Type type)
return result ?? CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(ObjectEqualityComparer<object>), runtimeType);
}

/// <summary>
/// Creates the default <see cref="EqualityComparer{T}"/> for a nullable type.
/// </summary>
/// <param name="nullableType">The nullable type to create the default equality comparer for.</param>
private static object? TryCreateNullableEqualityComparer(RuntimeType nullableType)
{
Debug.Assert(nullableType != null);
Debug.Assert(nullableType.IsGenericType && nullableType.GetGenericTypeDefinition() == typeof(Nullable<>));

var embeddedType = (RuntimeType)nullableType.GetGenericArguments()[0];

if (typeof(IEquatable<>).MakeGenericType(embeddedType).IsAssignableFrom(embeddedType))
{
return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableEqualityComparer<int>), embeddedType);
}

return null;
}

/// <summary>
/// Creates the default <see cref="EqualityComparer{T}"/> for an enum type.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ internal override int LastIndexOf(T[] array, T value, int startIndex, int count)
}
}

public sealed partial class NullableEqualityComparer<T> : EqualityComparer<T?> where T : struct, IEquatable<T>
public sealed partial class NullableEqualityComparer<T> : EqualityComparer<T?> where T : struct
{
internal override int IndexOf(T?[] array, T? value, int startIndex, int count)
{
Expand All @@ -73,7 +73,7 @@ internal override int IndexOf(T?[] array, T? value, int startIndex, int count)
{
for (int i = startIndex; i < endIndex; i++)
{
if (array[i].HasValue && array[i].value.Equals(value.value)) return i;
if (array[i].HasValue && EqualityComparer<T>.Default.Equals(array[i].value, value.value)) return i;
}
}
return -1;
Expand All @@ -93,7 +93,7 @@ internal override int LastIndexOf(T?[] array, T? value, int startIndex, int coun
{
for (int i = startIndex; i >= endIndex; i--)
{
if (array[i].HasValue && array[i].value.Equals(value.value)) return i;
if (array[i].HasValue && EqualityComparer<T>.Default.Equals(array[i].value, value.value)) return i;
}
}
return -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@ internal static object GetComparer(RuntimeTypeHandle t)
if (RuntimeAugments.IsNullable(t))
{
RuntimeTypeHandle nullableType = RuntimeAugments.GetNullableType(t);
if (ImplementsIComparable(nullableType))
{
openComparerType = typeof(NullableComparer<>).TypeHandle;
comparerTypeArgument = nullableType;
}
openComparerType = typeof(NullableComparer<>).TypeHandle;
comparerTypeArgument = nullableType;
}
if (EqualityComparerHelpers.IsEnum(t))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ internal static object GetComparer(RuntimeTypeHandle t)
if (RuntimeAugments.IsNullable(t))
{
RuntimeTypeHandle nullableType = RuntimeAugments.GetNullableType(t);
if (ImplementsIEquatable(nullableType))
{
openComparerType = typeof(NullableEqualityComparer<>).TypeHandle;
comparerTypeArgument = nullableType;
}
openComparerType = typeof(NullableEqualityComparer<>).TypeHandle;
comparerTypeArgument = nullableType;
}
if (IsEnum(t))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,9 @@ private static TypeDesc GetComparerForType(TypeDesc type, string flavor, string
// We can't tell at compile time either.
return null;
}
else if (ImplementsInterfaceOfSelf(nullableType, interfaceName))
{
return context.SystemModule.GetKnownType("System.Collections.Generic", $"Nullable{flavor}`1")
.MakeInstantiatedType(nullableType);
}

return context.SystemModule.GetKnownType("System.Collections.Generic", $"Nullable{flavor}`1")
.MakeInstantiatedType(nullableType);
}
else if (type.IsEnum)
{
Expand Down
8 changes: 2 additions & 6 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8931,12 +8931,8 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultEqualityComparerClassHelper(CORINFO_CLAS
if (Nullable::IsNullableType(elemTypeHnd))
{
Instantiation nullableInst = elemTypeHnd.AsMethodTable()->GetInstantiation();
TypeHandle iequatable = TypeHandle(CoreLibBinder::GetClass(CLASS__IEQUATABLEGENERIC)).Instantiate(nullableInst);
if (nullableInst[0].CanCastTo(iequatable))
{
TypeHandle resultTh = ((TypeHandle)CoreLibBinder::GetClass(CLASS__NULLABLE_EQUALITYCOMPARER)).Instantiate(nullableInst);
return CORINFO_CLASS_HANDLE(resultTh.GetMethodTable());
}
TypeHandle resultTh = ((TypeHandle)CoreLibBinder::GetClass(CLASS__NULLABLE_EQUALITYCOMPARER)).Instantiate(nullableInst);
return CORINFO_CLASS_HANDLE(resultTh.GetMethodTable());
}

// Enum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ public class HashData<T> : TheoryData<T, int> { }
[MemberData(nameof(Int16EnumData))]
[MemberData(nameof(SByteEnumData))]
[MemberData(nameof(Int32EnumData))]
[MemberData(nameof(NullableInt32EnumData))]
[MemberData(nameof(Int64EnumData))]
[MemberData(nameof(NonEquatableValueTypeData))]
[MemberData(nameof(NullableNonEquatableValueTypeData))]
[MemberData(nameof(ObjectData))]
public void EqualsTest<T>(T left, T right, bool expected)
{
Expand Down Expand Up @@ -253,6 +255,22 @@ public static EqualsData<Int32Enum> Int32EnumData()
};
}

public static EqualsData<Int32Enum?> NullableInt32EnumData()
{
return new EqualsData<Int32Enum?>
{
{ (Int32Enum)(-2), (Int32Enum)(-4), false },
{ Int32Enum.Two, Int32Enum.Two, true },
{ Int32Enum.Min, Int32Enum.Max, false },
{ Int32Enum.Min, Int32Enum.Min, true },
{ Int32Enum.One, Int32Enum.Min + 1, false },
{ (Int32Enum)(-2), null, false },
{ Int32Enum.Two, null, false },
{ null, Int32Enum.Max, false },
{ null, Int32Enum.Min + 1, false }
};
}

public static EqualsData<Int64Enum> Int64EnumData()
{
return new EqualsData<Int64Enum>
Expand Down Expand Up @@ -281,6 +299,24 @@ public static EqualsData<NonEquatableValueType> NonEquatableValueTypeData()
};
}

public static EqualsData<NonEquatableValueType?> NullableNonEquatableValueTypeData()
{
// Comparisons for structs that do not override ValueType.Equals or
// ValueType.GetHashCode should still work as expected.

var one = new NonEquatableValueType { Value = 1 };

return new EqualsData<NonEquatableValueType?>
{
{ new NonEquatableValueType(), new NonEquatableValueType(), true },
{ one, one, true },
{ new NonEquatableValueType(-1), new NonEquatableValueType(), false },
{ new NonEquatableValueType(2), new NonEquatableValueType(2), true },
{ new NonEquatableValueType(-1), null, false },
{ null, new NonEquatableValueType(2), false }
};
}

public static EqualsData<object> ObjectData()
{
var obj = new object();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,24 @@ public override int GetHashCode() =>
[Serializable]
[TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
// Needs to be public to support binary serialization compatibility
public sealed partial class NullableComparer<T> : Comparer<T?> where T : struct, IComparable<T>
public sealed class NullableComparer<T> : Comparer<T?>, ISerializable where T : struct
{
public NullableComparer() { }
private NullableComparer(SerializationInfo info, StreamingContext context) { }
public void GetObjectData(SerializationInfo info, StreamingContext context)
{
if (!typeof(T).IsAssignableTo(typeof(IComparable<T>)))
{
// We used to use NullableComparer only for types implementing IComparable<T>
info.SetType(typeof(ObjectComparer<T?>));
}
}

public override int Compare(T? x, T? y)
{
if (x.HasValue)
{
if (y.HasValue) return x.value.CompareTo(y.value);
if (y.HasValue) return Comparer<T>.Default.Compare(x.value, y.value);
return 1;
}
if (y.HasValue) return -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,25 @@ public override int GetHashCode() =>
[Serializable]
[TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
// Needs to be public to support binary serialization compatibility
public sealed partial class NullableEqualityComparer<T> : EqualityComparer<T?> where T : struct, IEquatable<T>
public sealed partial class NullableEqualityComparer<T> : EqualityComparer<T?>, ISerializable where T : struct
{
public NullableEqualityComparer() { }
private NullableEqualityComparer(SerializationInfo info, StreamingContext context) { }
public void GetObjectData(SerializationInfo info, StreamingContext context)
{
if (!typeof(T).IsAssignableTo(typeof(IEquatable<T>)))
{
// We used to use NullableComparer only for types implementing IEquatable<T>
info.SetType(typeof(ObjectEqualityComparer<T?>));
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals(T? x, T? y)
{
if (x.HasValue)
{
if (y.HasValue) return x.value.Equals(y.value);
if (y.HasValue) return EqualityComparer<T>.Default.Equals(x.value, y.value);
return false;
}
if (y.HasValue) return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -727,5 +727,37 @@ private class DelegateBinder : SerializationBinder
public Func<string, string, Type> BindToTypeDelegate = null;
public override Type BindToType(string assemblyName, string typeName) => BindToTypeDelegate?.Invoke(assemblyName, typeName);
}

public struct MyStruct
{
public int A;
}

public static IEnumerable<object[]> NullableComparersTestData()
{
yield return new object[] { "NullableEqualityComparer`1", EqualityComparer<byte?>.Default };
yield return new object[] { "NullableEqualityComparer`1", EqualityComparer<int?>.Default };
yield return new object[] { "NullableEqualityComparer`1", EqualityComparer<float?>.Default };
yield return new object[] { "NullableEqualityComparer`1", EqualityComparer<Guid?>.Default }; // implements IEquatable<>

yield return new object[] { "ObjectEqualityComparer`1", EqualityComparer<MyStruct?>.Default }; // doesn't implement IEquatable<>
yield return new object[] { "ObjectEqualityComparer`1", EqualityComparer<DayOfWeek?>.Default };

yield return new object[] { "NullableComparer`1", Comparer<byte?>.Default };
yield return new object[] { "NullableComparer`1", Comparer<int?>.Default };
yield return new object[] { "NullableComparer`1", Comparer<float?>.Default };
yield return new object[] { "NullableComparer`1", Comparer<Guid?>.Default };

yield return new object[] { "ObjectComparer`1", Comparer<MyStruct?>.Default };
yield return new object[] { "ObjectComparer`1", Comparer<DayOfWeek?>.Default };
}

[Theory]
[MemberData(nameof(NullableComparersTestData))]
public void NullableComparersRoundtrip(string expectedType, object obj)
{
string serialized = BinaryFormatterHelpers.ToBase64String(obj);
Assert.Equal(expectedType, BinaryFormatterHelpers.FromBase64String(serialized).GetType().Name);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ private static Comparer<T> CreateComparer()
if (typeof(IComparable<T>).IsAssignableFrom(t))
return (Comparer<T>)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(GenericComparer<>), t);

// If T is a Nullable<U> where U implements IComparable<U> return a NullableComparer<U>
// If T is a Nullable<U> return a NullableComparer<U>
if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>))
{
RuntimeType u = (RuntimeType)t.GetGenericArguments()[0];
if (typeof(IComparable<>).MakeGenericType(u).IsAssignableFrom(u))
return (Comparer<T>)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableComparer<>), u);
return (Comparer<T>)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableComparer<>), u);
}

if (t.IsEnum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ private static EqualityComparer<T> CreateComparer()
if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>))
{
RuntimeType u = (RuntimeType)t.GetGenericArguments()[0];
if (typeof(IEquatable<>).MakeGenericType(u).IsAssignableFrom(u))
{
return (EqualityComparer<T>)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableEqualityComparer<>), u);
}
return (EqualityComparer<T>)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableEqualityComparer<>), u);
}

if (t.IsEnum)
Expand Down
3 changes: 2 additions & 1 deletion src/mono/mono/mini/method-to-ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -5422,7 +5422,8 @@ handle_call_res_devirt (MonoCompile *cfg, MonoMethod *cmethod, MonoInst *call_re

/* EqualityComparer<T>.Default returns specific types depending on T */
// FIXME: Add more
/* 1. Implements IEquatable<T> */
// 1. Implements IEquatable<T>
// 2. Nullable<T>
/*
* Can't use this for string/byte as it might use a different comparer:
*
Expand Down
5 changes: 5 additions & 0 deletions src/tests/JIT/opt/Devirtualization/Comparer_get_Default.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ private static void GetTypeTests()
AssertEquals("System.Collections.Generic.EnumComparer`1[System.Runtime.CompilerServices.MethodImplOptions]", Comparer<MethodImplOptions>.Default.GetType().ToString());
AssertEquals("System.Collections.Generic.NullableComparer`1[System.Byte]", Comparer<byte?>.Default.GetType().ToString());
AssertEquals("System.Collections.Generic.ObjectComparer`1[Struct1]", Comparer<Struct1>.Default.GetType().ToString());

AssertEquals("System.Collections.Generic.NullableComparer`1[System.Runtime.CompilerServices.MethodImplOptions]", Comparer<MethodImplOptions?>.Default.GetType().ToString());
AssertEquals("System.Collections.Generic.NullableEqualityComparer`1[System.Runtime.CompilerServices.MethodImplOptions]", EqualityComparer<MethodImplOptions?>.Default.GetType().ToString());
AssertEquals("System.Collections.Generic.NullableComparer`1[Struct1]", Comparer<Struct1?>.Default.GetType().ToString());
AssertEquals("System.Collections.Generic.NullableEqualityComparer`1[Struct1]", EqualityComparer<Struct1?>.Default.GetType().ToString());
}
private static int GetHashCodeTests()
{
Expand Down