diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs index 7e9f457d763b0a..7676ec1ae2c7a0 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs @@ -41,6 +41,7 @@ public class TypePreinit private readonly Dictionary _fieldValues = new Dictionary(); private readonly Dictionary _internedStrings = new Dictionary(); private readonly Dictionary _internedTypes = new Dictionary(); + private readonly Dictionary _nestedPreinitResults = new Dictionary(); private TypePreinit(MetadataType owningType, CompilationModuleGroup compilationGroup, ILProvider ilProvider, TypePreinitializationPolicy policy, ReadOnlyFieldPolicy readOnlyPolicy, FlowAnnotations flowAnnotations) { @@ -109,6 +110,40 @@ public static PreinitializationInfo ScanType(CompilationModuleGroup compilationG return new PreinitializationInfo(type, status.FailureReason); } + private bool TryGetNestedPreinitResult(MethodDesc callingMethod, MetadataType type, Stack recursionProtect, ref int instructionCounter, out NestedPreinitResult result) + { + if (!_nestedPreinitResults.TryGetValue(type, out result)) + { + TypePreinit nestedPreinit = new TypePreinit(type, _compilationGroup, _ilProvider, _policy, _readOnlyPolicy, _flowAnnotations); + recursionProtect ??= new Stack(); + recursionProtect.Push(callingMethod); + + // Since we don't reset the instruction counter as we interpret the nested cctor, + // remember the instruction counter before we start interpreting so that we can subtract + // the instructions later when we convert object instances allocated in the nested + // cctor to foreign instances in the currently analyzed cctor. + // E.g. if the nested cctor allocates a new object at the beginning of the cctor, + // we should treat it as a ForeignTypeInstance with allocation site ID 0, not allocation + // site ID of `instructionCounter + 0`. + // We could also reset the counter, but we use the instruction counter as a complexity cutoff + // and resetting it would lead to unpredictable analysis durations. + int baseInstructionCounter = instructionCounter; + Status status = nestedPreinit.TryScanMethod(type.GetStaticConstructor(), null, recursionProtect, ref instructionCounter, out Value _); + if (!status.IsSuccessful) + { + result = default; + return false; + } + recursionProtect.Pop(); + + result = new NestedPreinitResult(nestedPreinit._fieldValues, baseInstructionCounter); + + _nestedPreinitResults.Add(type, result); + } + + return true; + } + private Status TryScanMethod(MethodDesc method, Value[] parameters, Stack recursionProtect, ref int instructionCounter, out Value returnValue) { MethodIL methodIL = _ilProvider.GetMethodIL(method); @@ -344,6 +379,7 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack(); - recursionProtect.Push(methodIL.OwningMethod); - - // Since we don't reset the instruction counter as we interpret the nested cctor, - // remember the instruction counter before we start interpreting so that we can subtract - // the instructions later when we convert object instances allocated in the nested - // cctor to foreign instances in the currently analyzed cctor. - // E.g. if the nested cctor allocates a new object at the beginning of the cctor, - // we should treat it as a ForeignTypeInstance with allocation site ID 0, not allocation - // site ID of `instructionCounter + 0`. - // We could also reset the counter, but we use the instruction counter as a complexity cutoff - // and resetting it would lead to unpredictable analysis durations. - int baseInstructionCounter = instructionCounter; - Status status = nestedPreinit.TryScanMethod(field.OwningType.GetStaticConstructor(), null, recursionProtect, ref instructionCounter, out Value _); - if (!status.IsSuccessful) + if (!TryGetNestedPreinitResult(methodIL.OwningMethod, (MetadataType)field.OwningType, recursionProtect, ref instructionCounter, out NestedPreinitResult nestedPreinitResult)) { return Status.Fail(methodIL.OwningMethod, opcode, "Nested cctor failed to preinit"); } - recursionProtect.Pop(); - Value value = nestedPreinit._fieldValues[field]; - if (value is BaseValueTypeValue) - stack.PushFromLocation(field.FieldType, value); - else if (value is ReferenceTypeValue referenceType) - stack.PushFromLocation(field.FieldType, referenceType.ToForeignInstance(baseInstructionCounter, this)); - else + + if (!nestedPreinitResult.TryGetFieldValue(this, field, out fieldValue)) return Status.Fail(methodIL.OwningMethod, opcode); } else if (_readOnlyPolicy.IsReadOnly(field) + && opcode != ILOpcode.ldsflda // We need to intern these for correctness in ldsfda scenarios && !field.OwningType.HasStaticConstructor) { // (Effectively) read only field but no static constructor to set it: the value is default-initialized. - stack.PushFromLocation(field.FieldType, NewUninitializedLocationValue(field.FieldType)); + fieldValue = NewUninitializedLocationValue(field.FieldType); } else { return Status.Fail(methodIL.OwningMethod, opcode, "Load from other non-initonly static"); } - } - break; - - case ILOpcode.ldsflda: - { - FieldDesc field = (FieldDesc)methodIL.GetObject(reader.ReadILToken()); - if (!field.IsStatic || field.IsLiteral) - { - ThrowHelper.ThrowInvalidProgramException(); - } - if (field.OwningType != _type) + if (opcode == ILOpcode.ldsfld) { - return Status.Fail(methodIL.OwningMethod, opcode, "Address of other static"); + stack.PushFromLocation(field.FieldType, fieldValue); } - - if (field.IsThreadStatic || field.HasRva) + else { - return Status.Fail(methodIL.OwningMethod, opcode, "Unsupported static"); - } + Debug.Assert(opcode == ILOpcode.ldsflda); + if (fieldValue == null || !fieldValue.TryCreateByRef(out Value byRefValue)) + { + return Status.Fail(methodIL.OwningMethod, opcode, "Unsupported byref"); + } - if (_flowAnnotations.RequiresDataflowAnalysisDueToSignature(field)) - { - return Status.Fail(methodIL.OwningMethod, opcode, "Needs dataflow analysis"); + stack.Push(StackValueKind.ByRef, byRefValue); } - - Value fieldValue = _fieldValues[field]; - if (fieldValue == null || !fieldValue.TryCreateByRef(out Value byRefValue)) - { - return Status.Fail(methodIL.OwningMethod, opcode, "Unsupported byref"); - } - - stack.Push(StackValueKind.ByRef, byRefValue); } break; @@ -465,9 +471,13 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack ILOpcode.ldind_i1, - TypeFlags.Boolean or TypeFlags.Byte => ILOpcode.ldind_u1, - TypeFlags.Int16 => ILOpcode.ldind_i2, - TypeFlags.Char or TypeFlags.UInt16 => ILOpcode.ldind_u2, - TypeFlags.Int32 => ILOpcode.ldind_i4, - TypeFlags.UInt32 => ILOpcode.ldind_u4, - TypeFlags.Int64 or TypeFlags.UInt64 => ILOpcode.ldind_i8, - TypeFlags.Single => ILOpcode.ldind_r4, - TypeFlags.Double => ILOpcode.ldind_r8, - _ => ILOpcode.ldobj, - }; - - if (opcode == ILOpcode.ldobj) - { - return Status.Fail(methodIL.OwningMethod, opcode); - } - } + TypeDesc type = opcode switch + { + ILOpcode.ldind_i1 => context.GetWellKnownType(WellKnownType.SByte), + ILOpcode.ldind_u1 => context.GetWellKnownType(WellKnownType.Byte), + ILOpcode.ldind_i2 => context.GetWellKnownType(WellKnownType.Int16), + ILOpcode.ldind_u2 => context.GetWellKnownType(WellKnownType.UInt16), + ILOpcode.ldind_i4 => context.GetWellKnownType(WellKnownType.Int32), + ILOpcode.ldind_u4 => context.GetWellKnownType(WellKnownType.UInt32), + ILOpcode.ldind_i8 => context.GetWellKnownType(WellKnownType.Int64), + _ /* ldobj */ => (TypeDesc)methodIL.GetObject(reader.ReadILToken()), + }; StackEntry entry = stack.Pop(); - if (entry.Value is ByRefValue byRefVal) + if (entry.ValueKind != StackValueKind.ByRef && entry.ValueKind != StackValueKind.NativeInt) + ThrowHelper.ThrowInvalidProgramException(); + + if (entry.Value is ByRefValueBase byRefVal + && byRefVal.TryLoad(type, out Value dereferenced)) { - switch (opcode) - { - case ILOpcode.ldind_i1: - stack.Push(StackValueKind.Int32, ValueTypeValue.FromInt32(byRefVal.DereferenceAsSByte())); - break; - case ILOpcode.ldind_u1: - stack.Push(StackValueKind.Int32, ValueTypeValue.FromInt32((byte)byRefVal.DereferenceAsSByte())); - break; - case ILOpcode.ldind_i2: - stack.Push(StackValueKind.Int32, ValueTypeValue.FromInt32(byRefVal.DereferenceAsInt16())); - break; - case ILOpcode.ldind_u2: - stack.Push(StackValueKind.Int32, ValueTypeValue.FromInt32((ushort)byRefVal.DereferenceAsInt16())); - break; - case ILOpcode.ldind_i4: - case ILOpcode.ldind_u4: - stack.Push(StackValueKind.Int32, ValueTypeValue.FromInt32(byRefVal.DereferenceAsInt32())); - break; - case ILOpcode.ldind_i8: - stack.Push(StackValueKind.Int64, ValueTypeValue.FromInt64(byRefVal.DereferenceAsInt64())); - break; - case ILOpcode.ldind_r4: - stack.Push(StackValueKind.Float, ValueTypeValue.FromDouble(byRefVal.DereferenceAsSingle())); - break; - case ILOpcode.ldind_r8: - stack.Push(StackValueKind.Float, ValueTypeValue.FromDouble(byRefVal.DereferenceAsDouble())); - break; - } + stack.PushFromLocation(type, dereferenced); } else { - ThrowHelper.ThrowInvalidProgramException(); + return Status.Fail(methodIL.OwningMethod, "Ldind from unsupported byref"); } } break; @@ -2600,6 +2579,22 @@ public override bool TryStore(Value value) return false; } + public override bool TryLoad(TypeDesc type, out Value value) + { + if (!VTableLikeStructValue.IsCompatible(type) + || type is not MetadataType mdType + || mdType.InstanceFieldSize.AsInt > (_methods.Length - _index) * _pointerSize) + { + value = null; + return false; + } + + MethodDesc[] slots = new MethodDesc[GetFieldCount(mdType)]; + Array.Copy(_methods, _index, slots, 0, slots.Length); + value = new VTableLikeStructValue(mdType, slots); + return true; + } + public override Value Clone() => this; // The reference is immutable private int GetFieldIndex(FieldDesc field) @@ -2918,6 +2913,11 @@ public override bool GetRawData(NodeFactory factory, out object data) private abstract class ByRefValueBase : Value, INativeIntConvertibleValue { public virtual bool TryStore(Value value) => false; + public virtual bool TryLoad(TypeDesc type, out Value value) + { + value = null; + return false; + } public virtual bool TryInitialize(int size) => false; } @@ -2978,6 +2978,21 @@ public override bool TryStore(Value value) return true; } + public override bool TryLoad(TypeDesc type, out Value value) + { + if (!type.IsPrimitive + || ((MetadataType)type).InstanceFieldSize.AsInt > PointedToBytes.Length - PointedToOffset) + { + value = null; + return false; + } + + var result = new ValueTypeValue(type); + Array.Copy(PointedToBytes, PointedToOffset, result.InstanceBytes, 0, result.InstanceBytes.Length); + value = result; + return true; + } + public override Value Clone() => this; // Immutable public override void WriteFieldData(ref ObjectDataBuilder builder, NodeFactory factory) @@ -2991,20 +3006,6 @@ public override bool GetRawData(NodeFactory factory, out object data) data = null; return false; } - - private ReadOnlySpan AsExactByteCount(int count) - { - if (PointedToOffset + count > PointedToBytes.Length) - ThrowHelper.ThrowInvalidProgramException(); - return new ReadOnlySpan(PointedToBytes, PointedToOffset, count); - } - - public sbyte DereferenceAsSByte() => (sbyte)AsExactByteCount(1)[0]; - public short DereferenceAsInt16() => BitConverter.ToInt16(AsExactByteCount(2)); - public int DereferenceAsInt32() => BitConverter.ToInt32(AsExactByteCount(4)); - public long DereferenceAsInt64() => BitConverter.ToInt64(AsExactByteCount(8)); - public float DereferenceAsSingle() => BitConverter.ToSingle(AsExactByteCount(4)); - public double DereferenceAsDouble() => BitConverter.ToDouble(AsExactByteCount(8)); } private abstract class ReferenceTypeValue : Value @@ -3542,6 +3543,34 @@ public static Status Fail(MethodDesc method, string detail) } } + private readonly struct NestedPreinitResult + { + private readonly Dictionary _fieldValues; + private readonly int _baseInstructionCounter; + + public NestedPreinitResult(Dictionary fieldValues, int baseInstructionCounter) + => (_fieldValues, _baseInstructionCounter) = (fieldValues, baseInstructionCounter); + + public bool TryGetFieldValue(TypePreinit context, FieldDesc field, out Value value) + { + Value fieldValue = _fieldValues[field]; + + if (fieldValue is ReferenceTypeValue referenceType) + { + value = referenceType.ToForeignInstance(_baseInstructionCounter, context); + return true; + } + else if (fieldValue is BaseValueTypeValue) + { + value = fieldValue; + return true; + } + + value = null; + return false; + } + } + public class PreinitializationInfo { private readonly Dictionary _fieldValues; diff --git a/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs b/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs index bec5a112b52db1..abb073117d53b6 100644 --- a/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs +++ b/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs @@ -63,6 +63,9 @@ private static int Main() TestDataflow.Run(); TestConversions.Run(); TestVTables.Run(); + TestVTableManipulation.Run(); + TestVTableNegativeScenarios.Run(); + TestByRefFieldAddressEquality.Run(); #else Console.WriteLine("Preinitialization is disabled in multimodule builds for now. Skipping test."); #endif @@ -1682,6 +1685,12 @@ public static unsafe class IUnknownImpl [FixedAddressValueType] public static readonly IUnknownVftbl Vtbl; + public static nint AbiToProjectionVftablePtr + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in Vtbl)); + } + static IUnknownImpl() { ComWrappers.GetIUnknownImpl( @@ -1696,9 +1705,15 @@ public static unsafe class IInspectableImpl [FixedAddressValueType] public static readonly IInspectableVftbl Vtbl; + public static nint AbiToProjectionVftablePtr + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in Vtbl)); + } + static IInspectableImpl() { - *(IUnknownVftbl*)Unsafe.AsPointer(ref Vtbl) = IUnknownImpl.Vtbl; + *(IUnknownVftbl*)Unsafe.AsPointer(ref Vtbl) = *(IUnknownVftbl*)IUnknownImpl.AbiToProjectionVftablePtr; Vtbl.GetIids = &GetIids; Vtbl.GetRuntimeClassName = &GetRuntimeClassName; @@ -1715,6 +1730,21 @@ static IInspectableImpl() public static int GetTrustLevel(void* thisPtr, int* trustLevel) => 0; } + internal static unsafe class IStringableImpl + { + public static readonly IStringableVftbl Vtbl; + + static IStringableImpl() + { + *(IInspectableVftbl*)Unsafe.AsPointer(ref Vtbl) = *(IInspectableVftbl*)IInspectableImpl.AbiToProjectionVftablePtr; + + Vtbl.ToString = &ToString; + } + + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvMemberFunction)])] + public static int ToString(void* thisPtr, nint* value) => 0; + } + public unsafe struct IUnknownVftbl { public delegate* unmanaged[MemberFunction] QueryInterface; @@ -1732,6 +1762,17 @@ public unsafe struct IInspectableVftbl public delegate* unmanaged[MemberFunction] GetTrustLevel; } + internal unsafe struct IStringableVftbl + { + public delegate* unmanaged[MemberFunction] QueryInterface; + public delegate* unmanaged[MemberFunction] AddRef; + public delegate* unmanaged[MemberFunction] Release; + public delegate* unmanaged[MemberFunction] GetIids; + public delegate* unmanaged[MemberFunction] GetRuntimeClassName; + public delegate* unmanaged[MemberFunction] GetTrustLevel; + public new delegate* unmanaged[MemberFunction] ToString; + } + public static unsafe void Run() { Assert.IsPreinitialized(typeof(IUnknownImpl)); @@ -1749,6 +1790,200 @@ public static unsafe void Run() Assert.AreEqual((nuint)release, (nuint)IInspectableImpl.Vtbl.Release); Assert.AreEqual((nuint)(delegate* unmanaged[MemberFunction])&IInspectableImpl.GetIids, (nuint)IInspectableImpl.Vtbl.GetIids); Assert.AreEqual((nuint)(delegate* unmanaged[MemberFunction])&IInspectableImpl.GetTrustLevel, (nuint)IInspectableImpl.Vtbl.GetTrustLevel); + + Assert.IsPreinitialized(typeof(IStringableImpl)); + Assert.AreEqual((nuint)qi, (nuint)IStringableImpl.Vtbl.QueryInterface); + Assert.AreEqual((nuint)addref, (nuint)IStringableImpl.Vtbl.AddRef); + Assert.AreEqual((nuint)release, (nuint)IStringableImpl.Vtbl.Release); + Assert.AreEqual((nuint)(delegate* unmanaged[MemberFunction])&IInspectableImpl.GetIids, (nuint)IStringableImpl.Vtbl.GetIids); + Assert.AreEqual((nuint)(delegate* unmanaged[MemberFunction])&IInspectableImpl.GetTrustLevel, (nuint)IStringableImpl.Vtbl.GetTrustLevel); + Assert.AreEqual((nuint)(delegate* unmanaged[MemberFunction])&IStringableImpl.ToString, (nuint)IStringableImpl.Vtbl.ToString); + } +} + +class TestVTableManipulation +{ + public unsafe class TinyVtableAImpl + { + [FixedAddressValueType] + public static readonly ITinyVtableA Vtbl = Initialize(); + + private static ITinyVtableA Initialize() + { + ITinyVtableA result = default; + result.First = &First; + result.Second = &Second; + return result; + } + } + + public unsafe class TinyVtableBImpl + { + [FixedAddressValueType] + public static readonly ITinyVtableB Vtbl; + + public static nint AbiToProjectionVftablePtr => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in Vtbl)); + + static TinyVtableBImpl() + { + *(ITinyVtableA*)Unsafe.AsPointer(ref Vtbl) = TinyVtableAImpl.Vtbl; + Vtbl.Third = &Third; + } + } + + public unsafe class TinyVtableCImpl + { + [FixedAddressValueType] + public static readonly ITinyVtableC Vtbl; + + static TinyVtableCImpl() + { + *(ITinyVtableB*)Unsafe.AsPointer(ref Vtbl) = *(ITinyVtableB*)TinyVtableBImpl.AbiToProjectionVftablePtr; + Vtbl.Fourth = &Fourth; + } + } + + public unsafe struct ITinyVtableA + { + public delegate* First; + public delegate* Second; + } + + public unsafe struct ITinyVtableB + { + public delegate* First; + public delegate* Second; + public delegate* Third; + } + + public unsafe struct ITinyVtableC + { + public delegate* First; + public delegate* Second; + public delegate* Third; + public delegate* Fourth; + } + + static void First() { } + static void Second() { } + static void Third() { } + static void Fourth() { } + + public static unsafe void Run() + { + Assert.IsPreinitialized(typeof(TinyVtableAImpl)); + Assert.AreEqual((nuint)(delegate*)&First, (nuint)TinyVtableAImpl.Vtbl.First); + Assert.AreEqual((nuint)(delegate*)&Second, (nuint)TinyVtableAImpl.Vtbl.Second); + + Assert.IsPreinitialized(typeof(TinyVtableBImpl)); + Assert.AreEqual((nuint)(delegate*)&First, (nuint)TinyVtableBImpl.Vtbl.First); + Assert.AreEqual((nuint)(delegate*)&Second, (nuint)TinyVtableBImpl.Vtbl.Second); + Assert.AreEqual((nuint)(delegate*)&Third, (nuint)TinyVtableBImpl.Vtbl.Third); + + Assert.IsPreinitialized(typeof(TinyVtableCImpl)); + Assert.AreEqual((nuint)(delegate*)&First, (nuint)TinyVtableCImpl.Vtbl.First); + Assert.AreEqual((nuint)(delegate*)&Second, (nuint)TinyVtableCImpl.Vtbl.Second); + Assert.AreEqual((nuint)(delegate*)&Third, (nuint)TinyVtableCImpl.Vtbl.Third); + Assert.AreEqual((nuint)(delegate*)&Fourth, (nuint)TinyVtableCImpl.Vtbl.Fourth); + } +} + +class TestVTableNegativeScenarios +{ + class StoreIntoNint + { + public static readonly nint Field; + + unsafe static StoreIntoNint() + { + ITinyVtable result = default; + Field = (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in result)); + } + } + + class WriteNonMethodPointer + { + public static readonly ITinyVtable Vtbl; + + static unsafe WriteNonMethodPointer() + { + Vtbl.First = (delegate*)123; + Vtbl.Second = (delegate*)456; + } + } + + unsafe class WriteNonMethodIndirect + { + public static readonly ITinyVtable Vtbl; + + static void Write(ref delegate* f, int val) => f = (delegate*)val; + + static unsafe WriteNonMethodIndirect() + { + Write(ref Vtbl.First, 123); + Write(ref Vtbl.Second, 456); + } + } + + static void First() { } + static void Second() { } + + public unsafe struct ITinyVtable + { + public delegate* First; + public delegate* Second; + } + + public static unsafe void Run() + { + Assert.IsLazyInitialized(typeof(StoreIntoNint)); + if (StoreIntoNint.Field == 0) + throw new Exception(); + + Assert.IsLazyInitialized(typeof(WriteNonMethodPointer)); + Assert.AreEqual(WriteNonMethodPointer.Vtbl.First, (void*)123); + Assert.AreEqual(WriteNonMethodPointer.Vtbl.Second, (void*)456); + + Assert.IsLazyInitialized(typeof(WriteNonMethodIndirect)); + Assert.AreEqual(WriteNonMethodIndirect.Vtbl.First, (void*)123); + Assert.AreEqual(WriteNonMethodIndirect.Vtbl.Second, (void*)456); + } +} + +unsafe class TestByRefFieldAddressEquality +{ + class ClassWithInitializedByRefs + { + [FixedAddressValueType] + public static readonly int MyByRef = 1234; + + public static nint HiddenGetAddress() => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in MyByRef)); + } + + class ClassWithUninitializedByRefs + { + [FixedAddressValueType] + public static readonly int MyByRef; + + public static nint HiddenGetAddress() => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in MyByRef)); + } + + class ClassTakingAddressOfInitialized + { + public static bool AreEqual = ClassWithInitializedByRefs.HiddenGetAddress() == ClassWithInitializedByRefs.HiddenGetAddress(); + } + + class ClassTakingAddressOfUninitialized + { + public static bool AreEqual = ClassWithUninitializedByRefs.HiddenGetAddress() == ClassWithUninitializedByRefs.HiddenGetAddress(); + } + + public static void Run() + { + Assert.IsPreinitialized(typeof(ClassTakingAddressOfInitialized)); + Assert.AreEqual(true, ClassTakingAddressOfInitialized.AreEqual); + + Assert.AreEqual(true, ClassTakingAddressOfUninitialized.AreEqual); } }