diff --git a/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs b/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs index 6812a6c92faf26..b020c127629413 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/ComparerHelpers.cs @@ -37,14 +37,11 @@ internal static object CreateDefaultComparer(Type type) { result = CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(GenericComparer), runtimeType); } - // Nullable does not implement IComparable directly because that would add an extra interface call per comparison. - // Instead, it relies on Comparer.Default to specialize for nullables and do the lifted comparisons if T implements IComparable. - else if (type.IsGenericType) + else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) { - if (type.GetGenericTypeDefinition() == typeof(Nullable<>)) - { - result = TryCreateNullableComparer(runtimeType); - } + // Nullable does not implement IComparable directly because that would add an extra interface call per comparison. + var embeddedType = (RuntimeType)type.GetGenericArguments()[0]; + result = RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableComparer), embeddedType); } // The comparer for enums is specialized to avoid boxing. else if (type.IsEnum) @@ -55,25 +52,6 @@ internal static object CreateDefaultComparer(Type type) return result ?? CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(ObjectComparer), runtimeType); } - /// - /// Creates the default for a nullable type. - /// - /// The nullable type to create the default comparer for. - private static object? TryCreateNullableComparer(RuntimeType nullableType) - { - Debug.Assert(nullableType != null); - Debug.Assert(nullableType.IsGenericType && nullableType.GetGenericTypeDefinition() == typeof(Nullable<>)); - - var embeddedType = (RuntimeType)nullableType.GetGenericArguments()[0]; - - if (typeof(IComparable<>).MakeGenericType(embeddedType).IsAssignableFrom(embeddedType)) - { - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableComparer), embeddedType); - } - - return null; - } - /// /// Creates the default for an enum type. /// @@ -135,14 +113,11 @@ internal static object CreateDefaultEqualityComparer(Type type) // If T implements IEquatable return a GenericEqualityComparer result = CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(GenericEqualityComparer), runtimeType); } - else if (type.IsGenericType) + else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) { // Nullable does not implement IEquatable directly because that would add an extra interface call per comparison. - // Instead, it relies on EqualityComparer.Default to specialize for nullables and do the lifted comparisons if T implements IEquatable. - if (type.GetGenericTypeDefinition() == typeof(Nullable<>)) - { - result = TryCreateNullableEqualityComparer(runtimeType); - } + var embeddedType = (RuntimeType)type.GetGenericArguments()[0]; + result = CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableEqualityComparer), embeddedType); } else if (type.IsEnum) { @@ -153,25 +128,6 @@ internal static object CreateDefaultEqualityComparer(Type type) return result ?? CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(ObjectEqualityComparer), runtimeType); } - /// - /// Creates the default for a nullable type. - /// - /// The nullable type to create the default equality comparer for. - private static object? TryCreateNullableEqualityComparer(RuntimeType nullableType) - { - Debug.Assert(nullableType != null); - Debug.Assert(nullableType.IsGenericType && nullableType.GetGenericTypeDefinition() == typeof(Nullable<>)); - - var embeddedType = (RuntimeType)nullableType.GetGenericArguments()[0]; - - if (typeof(IEquatable<>).MakeGenericType(embeddedType).IsAssignableFrom(embeddedType)) - { - return RuntimeTypeHandle.CreateInstanceForAnotherGenericParameter((RuntimeType)typeof(NullableEqualityComparer), embeddedType); - } - - return null; - } - /// /// Creates the default for an enum type. /// diff --git a/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.CoreCLR.cs index ef5ec59482110a..9d35ae81313b6e 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.CoreCLR.cs @@ -57,7 +57,7 @@ internal override int LastIndexOf(T[] array, T value, int startIndex, int count) } } - public sealed partial class NullableEqualityComparer : EqualityComparer where T : struct, IEquatable + public sealed partial class NullableEqualityComparer : EqualityComparer where T : struct { internal override int IndexOf(T?[] array, T? value, int startIndex, int count) { @@ -73,7 +73,7 @@ internal override int IndexOf(T?[] array, T? value, int startIndex, int count) { for (int i = startIndex; i < endIndex; i++) { - if (array[i].HasValue && array[i].value.Equals(value.value)) return i; + if (array[i].HasValue && EqualityComparer.Default.Equals(array[i].value, value.value)) return i; } } return -1; @@ -93,7 +93,7 @@ internal override int LastIndexOf(T?[] array, T? value, int startIndex, int coun { for (int i = startIndex; i >= endIndex; i--) { - if (array[i].HasValue && array[i].value.Equals(value.value)) return i; + if (array[i].HasValue && EqualityComparer.Default.Equals(array[i].value, value.value)) return i; } } return -1; diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/ComparerHelpers.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/ComparerHelpers.cs index 2a74b9560d4f48..2800fc3735a495 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/ComparerHelpers.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/ComparerHelpers.cs @@ -59,11 +59,8 @@ internal static object GetComparer(RuntimeTypeHandle t) if (RuntimeAugments.IsNullable(t)) { RuntimeTypeHandle nullableType = RuntimeAugments.GetNullableType(t); - if (ImplementsIComparable(nullableType)) - { - openComparerType = typeof(NullableComparer<>).TypeHandle; - comparerTypeArgument = nullableType; - } + openComparerType = typeof(NullableComparer<>).TypeHandle; + comparerTypeArgument = nullableType; } if (EqualityComparerHelpers.IsEnum(t)) { diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/EqualityComparerHelpers.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/EqualityComparerHelpers.cs index 2e12a3cab28057..783a70395ef10e 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/EqualityComparerHelpers.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/IntrinsicSupport/EqualityComparerHelpers.cs @@ -62,11 +62,8 @@ internal static object GetComparer(RuntimeTypeHandle t) if (RuntimeAugments.IsNullable(t)) { RuntimeTypeHandle nullableType = RuntimeAugments.GetNullableType(t); - if (ImplementsIEquatable(nullableType)) - { - openComparerType = typeof(NullableEqualityComparer<>).TypeHandle; - comparerTypeArgument = nullableType; - } + openComparerType = typeof(NullableEqualityComparer<>).TypeHandle; + comparerTypeArgument = nullableType; } if (IsEnum(t)) { diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ComparerIntrinsics.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ComparerIntrinsics.cs index 289556d9e58195..91856de2d880d5 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ComparerIntrinsics.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ComparerIntrinsics.cs @@ -113,11 +113,9 @@ private static TypeDesc GetComparerForType(TypeDesc type, string flavor, string // We can't tell at compile time either. return null; } - else if (ImplementsInterfaceOfSelf(nullableType, interfaceName)) - { - return context.SystemModule.GetKnownType("System.Collections.Generic", $"Nullable{flavor}`1") - .MakeInstantiatedType(nullableType); - } + + return context.SystemModule.GetKnownType("System.Collections.Generic", $"Nullable{flavor}`1") + .MakeInstantiatedType(nullableType); } else if (type.IsEnum) { diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index 4bfbb7b71e55ff..2e1e04607c05ad 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -8931,12 +8931,8 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultEqualityComparerClassHelper(CORINFO_CLAS if (Nullable::IsNullableType(elemTypeHnd)) { Instantiation nullableInst = elemTypeHnd.AsMethodTable()->GetInstantiation(); - TypeHandle iequatable = TypeHandle(CoreLibBinder::GetClass(CLASS__IEQUATABLEGENERIC)).Instantiate(nullableInst); - if (nullableInst[0].CanCastTo(iequatable)) - { - TypeHandle resultTh = ((TypeHandle)CoreLibBinder::GetClass(CLASS__NULLABLE_EQUALITYCOMPARER)).Instantiate(nullableInst); - return CORINFO_CLASS_HANDLE(resultTh.GetMethodTable()); - } + TypeHandle resultTh = ((TypeHandle)CoreLibBinder::GetClass(CLASS__NULLABLE_EQUALITYCOMPARER)).Instantiate(nullableInst); + return CORINFO_CLASS_HANDLE(resultTh.GetMethodTable()); } // Enum diff --git a/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs b/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs index 4d4ce87f02b28d..c8b2581da92571 100644 --- a/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs @@ -33,8 +33,10 @@ public class HashData : TheoryData { } [MemberData(nameof(Int16EnumData))] [MemberData(nameof(SByteEnumData))] [MemberData(nameof(Int32EnumData))] + [MemberData(nameof(NullableInt32EnumData))] [MemberData(nameof(Int64EnumData))] [MemberData(nameof(NonEquatableValueTypeData))] + [MemberData(nameof(NullableNonEquatableValueTypeData))] [MemberData(nameof(ObjectData))] public void EqualsTest(T left, T right, bool expected) { @@ -253,6 +255,22 @@ public static EqualsData Int32EnumData() }; } + public static EqualsData NullableInt32EnumData() + { + return new EqualsData + { + { (Int32Enum)(-2), (Int32Enum)(-4), false }, + { Int32Enum.Two, Int32Enum.Two, true }, + { Int32Enum.Min, Int32Enum.Max, false }, + { Int32Enum.Min, Int32Enum.Min, true }, + { Int32Enum.One, Int32Enum.Min + 1, false }, + { (Int32Enum)(-2), null, false }, + { Int32Enum.Two, null, false }, + { null, Int32Enum.Max, false }, + { null, Int32Enum.Min + 1, false } + }; + } + public static EqualsData Int64EnumData() { return new EqualsData @@ -281,6 +299,24 @@ public static EqualsData NonEquatableValueTypeData() }; } + public static EqualsData NullableNonEquatableValueTypeData() + { + // Comparisons for structs that do not override ValueType.Equals or + // ValueType.GetHashCode should still work as expected. + + var one = new NonEquatableValueType { Value = 1 }; + + return new EqualsData + { + { new NonEquatableValueType(), new NonEquatableValueType(), true }, + { one, one, true }, + { new NonEquatableValueType(-1), new NonEquatableValueType(), false }, + { new NonEquatableValueType(2), new NonEquatableValueType(2), true }, + { new NonEquatableValueType(-1), null, false }, + { null, new NonEquatableValueType(2), false } + }; + } + public static EqualsData ObjectData() { var obj = new object(); diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs index 1b5d6116587109..22cc2f3aac02cd 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Comparer.cs @@ -77,13 +77,24 @@ public override int GetHashCode() => [Serializable] [TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] // Needs to be public to support binary serialization compatibility - public sealed partial class NullableComparer : Comparer where T : struct, IComparable + public sealed class NullableComparer : Comparer, ISerializable where T : struct { + public NullableComparer() { } + private NullableComparer(SerializationInfo info, StreamingContext context) { } + public void GetObjectData(SerializationInfo info, StreamingContext context) + { + if (!typeof(T).IsAssignableTo(typeof(IComparable))) + { + // We used to use NullableComparer only for types implementing IComparable + info.SetType(typeof(ObjectComparer)); + } + } + public override int Compare(T? x, T? y) { if (x.HasValue) { - if (y.HasValue) return x.value.CompareTo(y.value); + if (y.HasValue) return Comparer.Default.Compare(x.value, y.value); return 1; } if (y.HasValue) return -1; diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs index 6f93085fffa270..60a65072a41f90 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs @@ -97,14 +97,25 @@ public override int GetHashCode() => [Serializable] [TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] // Needs to be public to support binary serialization compatibility - public sealed partial class NullableEqualityComparer : EqualityComparer where T : struct, IEquatable + public sealed partial class NullableEqualityComparer : EqualityComparer, ISerializable where T : struct { + public NullableEqualityComparer() { } + private NullableEqualityComparer(SerializationInfo info, StreamingContext context) { } + public void GetObjectData(SerializationInfo info, StreamingContext context) + { + if (!typeof(T).IsAssignableTo(typeof(IEquatable))) + { + // We used to use NullableComparer only for types implementing IEquatable + info.SetType(typeof(ObjectEqualityComparer)); + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public override bool Equals(T? x, T? y) { if (x.HasValue) { - if (y.HasValue) return x.value.Equals(y.value); + if (y.HasValue) return EqualityComparer.Default.Equals(x.value, y.value); return false; } if (y.HasValue) return false; diff --git a/src/libraries/System.Runtime.Serialization.Formatters/tests/BinaryFormatterTests.cs b/src/libraries/System.Runtime.Serialization.Formatters/tests/BinaryFormatterTests.cs index 250a84819b82e0..2a04351cfa9eed 100644 --- a/src/libraries/System.Runtime.Serialization.Formatters/tests/BinaryFormatterTests.cs +++ b/src/libraries/System.Runtime.Serialization.Formatters/tests/BinaryFormatterTests.cs @@ -727,5 +727,37 @@ private class DelegateBinder : SerializationBinder public Func BindToTypeDelegate = null; public override Type BindToType(string assemblyName, string typeName) => BindToTypeDelegate?.Invoke(assemblyName, typeName); } + + public struct MyStruct + { + public int A; + } + + public static IEnumerable NullableComparersTestData() + { + yield return new object[] { "NullableEqualityComparer`1", EqualityComparer.Default }; + yield return new object[] { "NullableEqualityComparer`1", EqualityComparer.Default }; + yield return new object[] { "NullableEqualityComparer`1", EqualityComparer.Default }; + yield return new object[] { "NullableEqualityComparer`1", EqualityComparer.Default }; // implements IEquatable<> + + yield return new object[] { "ObjectEqualityComparer`1", EqualityComparer.Default }; // doesn't implement IEquatable<> + yield return new object[] { "ObjectEqualityComparer`1", EqualityComparer.Default }; + + yield return new object[] { "NullableComparer`1", Comparer.Default }; + yield return new object[] { "NullableComparer`1", Comparer.Default }; + yield return new object[] { "NullableComparer`1", Comparer.Default }; + yield return new object[] { "NullableComparer`1", Comparer.Default }; + + yield return new object[] { "ObjectComparer`1", Comparer.Default }; + yield return new object[] { "ObjectComparer`1", Comparer.Default }; + } + + [Theory] + [MemberData(nameof(NullableComparersTestData))] + public void NullableComparersRoundtrip(string expectedType, object obj) + { + string serialized = BinaryFormatterHelpers.ToBase64String(obj); + Assert.Equal(expectedType, BinaryFormatterHelpers.FromBase64String(serialized).GetType().Name); + } } } diff --git a/src/mono/System.Private.CoreLib/src/System/Collections/Generic/Comparer.Mono.cs b/src/mono/System.Private.CoreLib/src/System/Collections/Generic/Comparer.Mono.cs index 94066bccfc5436..06044427a30e61 100644 --- a/src/mono/System.Private.CoreLib/src/System/Collections/Generic/Comparer.Mono.cs +++ b/src/mono/System.Private.CoreLib/src/System/Collections/Generic/Comparer.Mono.cs @@ -32,12 +32,11 @@ private static Comparer CreateComparer() if (typeof(IComparable).IsAssignableFrom(t)) return (Comparer)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(GenericComparer<>), t); - // If T is a Nullable where U implements IComparable return a NullableComparer + // If T is a Nullable return a NullableComparer if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) { RuntimeType u = (RuntimeType)t.GetGenericArguments()[0]; - if (typeof(IComparable<>).MakeGenericType(u).IsAssignableFrom(u)) - return (Comparer)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableComparer<>), u); + return (Comparer)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableComparer<>), u); } if (t.IsEnum) diff --git a/src/mono/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.Mono.cs b/src/mono/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.Mono.cs index 73a5edbf74539b..796c946eebb8d8 100644 --- a/src/mono/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.Mono.cs +++ b/src/mono/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.Mono.cs @@ -48,10 +48,7 @@ private static EqualityComparer CreateComparer() if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) { RuntimeType u = (RuntimeType)t.GetGenericArguments()[0]; - if (typeof(IEquatable<>).MakeGenericType(u).IsAssignableFrom(u)) - { - return (EqualityComparer)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableEqualityComparer<>), u); - } + return (EqualityComparer)RuntimeType.CreateInstanceForAnotherGenericParameter(typeof(NullableEqualityComparer<>), u); } if (t.IsEnum) diff --git a/src/mono/mono/mini/method-to-ir.c b/src/mono/mono/mini/method-to-ir.c index e5588087f8b18d..10432e49149929 100644 --- a/src/mono/mono/mini/method-to-ir.c +++ b/src/mono/mono/mini/method-to-ir.c @@ -5422,7 +5422,8 @@ handle_call_res_devirt (MonoCompile *cfg, MonoMethod *cmethod, MonoInst *call_re /* EqualityComparer.Default returns specific types depending on T */ // FIXME: Add more - /* 1. Implements IEquatable */ + // 1. Implements IEquatable + // 2. Nullable /* * Can't use this for string/byte as it might use a different comparer: * diff --git a/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.cs b/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.cs index 888c303a59b822..e754e79a072e20 100644 --- a/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.cs +++ b/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.cs @@ -221,6 +221,11 @@ private static void GetTypeTests() AssertEquals("System.Collections.Generic.EnumComparer`1[System.Runtime.CompilerServices.MethodImplOptions]", Comparer.Default.GetType().ToString()); AssertEquals("System.Collections.Generic.NullableComparer`1[System.Byte]", Comparer.Default.GetType().ToString()); AssertEquals("System.Collections.Generic.ObjectComparer`1[Struct1]", Comparer.Default.GetType().ToString()); + + AssertEquals("System.Collections.Generic.NullableComparer`1[System.Runtime.CompilerServices.MethodImplOptions]", Comparer.Default.GetType().ToString()); + AssertEquals("System.Collections.Generic.NullableEqualityComparer`1[System.Runtime.CompilerServices.MethodImplOptions]", EqualityComparer.Default.GetType().ToString()); + AssertEquals("System.Collections.Generic.NullableComparer`1[Struct1]", Comparer.Default.GetType().ToString()); + AssertEquals("System.Collections.Generic.NullableEqualityComparer`1[Struct1]", EqualityComparer.Default.GetType().ToString()); } private static int GetHashCodeTests() {