diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs index 0bb2fb389dd13d..85fa9a7277f5be 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs @@ -409,23 +409,16 @@ static bool IsRootedCallback(IntPtr pObj) #pragma warning restore CS8500 } - private ManagedObjectWrapper* _wrapper; - private object _wrappedObject; + private readonly ManagedObjectWrapper* _wrapper; + private readonly ManagedObjectWrapperReleaser _releaser; + private readonly object _wrappedObject; public ManagedObjectWrapperHolder(ManagedObjectWrapper* wrapper, object wrappedObject) { _wrapper = wrapper; _wrappedObject = wrappedObject; - } - - public void InitializeHandle() - { - IntPtr handle = RuntimeImports.RhHandleAllocRefCounted(this); - IntPtr prev = Interlocked.CompareExchange(ref _wrapper->HolderHandle, handle, IntPtr.Zero); - if (prev != IntPtr.Zero) - { - RuntimeImports.RhHandleFree(handle); - } + _releaser = new ManagedObjectWrapperReleaser(wrapper); + _wrapper->HolderHandle = RuntimeImports.RhHandleAllocRefCounted(this); } public unsafe IntPtr ComIp => _wrapper->As(in ComWrappers.IID_IUnknown); @@ -433,13 +426,34 @@ public void InitializeHandle() public object WrappedObject => _wrappedObject; public uint AddRef() => _wrapper->AddRef(); + } + + internal unsafe class ManagedObjectWrapperReleaser + { + private ManagedObjectWrapper* _wrapper; + + public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper) + { + _wrapper = wrapper; + } - ~ManagedObjectWrapperHolder() + ~ManagedObjectWrapperReleaser() { + IntPtr refCountedHandle = _wrapper->HolderHandle; + if (refCountedHandle != IntPtr.Zero && RuntimeImports.RhHandleGet(refCountedHandle) != null) + { + // The ManagedObjectWrapperHolder has not been fully collected, so it is still + // potentially reachable via the Conditional Weak Table. + // Keep ourselves alive in case the wrapped object is resurrected. + GC.ReRegisterForFinalize(this); + return; + } + // Release GC handle created when MOW was built. if (_wrapper->Destroy()) { NativeMemory.Free(_wrapper); + _wrapper = null; } else { @@ -531,7 +545,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom ManagedObjectWrapper* value = CreateCCW(c, flags); return new ManagedObjectWrapperHolder(value, c); }); - ccwValue.InitializeHandle(); + ccwValue.AddRef(); return ccwValue.ComIp; } @@ -581,7 +595,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom } mow->HolderHandle = IntPtr.Zero; - mow->RefCount = 1; + mow->RefCount = 0; mow->UserDefinedCount = userDefinedCount; mow->UserDefined = userDefined; mow->Flags = (CreateComInterfaceFlagsEx)flags; diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index 56c2748087e37f..fef40e8d5d0d1d 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -319,12 +319,12 @@ unsafe static void CallSetValue(TestComWrappers wrappers, Test testInstance, int Assert.NotEqual(IntPtr.Zero, nativeInstance); var iid = typeof(ITest).GUID; - IntPtr itestPtr; + nint itestPtr; Assert.Equal(0, Marshal.QueryInterface(nativeInstance, ref iid, out itestPtr)); - var inst = Marshal.PtrToStructure(itestPtr); - var vtbl = Marshal.PtrToStructure(inst.Vtbl); - var setValue = (delegate* unmanaged)vtbl.SetValue; + var inst = (ComWrappers.ComInterfaceDispatch*)itestPtr; + var vtbl = (ITestVtbl*)(inst->Vtable); + var setValue = (delegate* unmanaged)vtbl->SetValue; Assert.Equal(0, setValue(itestPtr, value)); Assert.Equal(value, testInstance.GetValue()); @@ -332,7 +332,63 @@ unsafe static void CallSetValue(TestComWrappers wrappers, Test testInstance, int // release for QueryInterface Assert.Equal(1, Marshal.Release(itestPtr)); // release for GetOrCreateComInterfaceForObject - Assert.Equal(0, Marshal.Release(itestPtr)); + Assert.Equal(0, Marshal.Release(nativeInstance)); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + [Fact] + public void ValidateResurrection() + { + Console.WriteLine($"Running {nameof(ValidateResurrection)}..."); + + var wrappers = new TestComWrappers(); + + try + { + CreateResurrectingTestInstance(wrappers); + + ForceGC(); + + CallSetValue(wrappers); + } + finally + { + Test.Resurrected = null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void CreateResurrectingTestInstance(ComWrappers wrapper) + { + Test testInstance = new Test() + { + EnableResurrection = true, + }; + IntPtr nativeInstance = wrapper.GetOrCreateComInterfaceForObject(testInstance, CreateComInterfaceFlags.None); + Assert.Equal(0, Marshal.Release(nativeInstance)); + } + + unsafe static void CallSetValue(ComWrappers wrappers) + { + Assert.NotEqual(null, Test.Resurrected); + IntPtr nativeInstance = wrappers.GetOrCreateComInterfaceForObject(Test.Resurrected, CreateComInterfaceFlags.None); + Assert.NotEqual(IntPtr.Zero, nativeInstance); + + var iid = typeof(ITest).GUID; + nint itestPtr; + Assert.Equal(0, Marshal.QueryInterface(nativeInstance, ref iid, out itestPtr)); + + var inst = (ComWrappers.ComInterfaceDispatch*)itestPtr; + var vtbl = (ITestVtbl*)(inst->Vtable); + var setValue = (delegate* unmanaged)vtbl->SetValue; + + Assert.Equal(0, setValue(itestPtr, 42)); + Assert.Equal(42, Test.Resurrected.GetValue()); + + // release for QueryInterface + Assert.Equal(1, Marshal.Release(itestPtr)); + // release for GetOrCreateComInterfaceForObject + Assert.Equal(0, Marshal.Release(nativeInstance)); } } diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs index 41fdf348053dba..748c5f9c5c701e 100644 --- a/src/tests/Interop/COM/ComWrappers/Common.cs +++ b/src/tests/Interop/COM/ComWrappers/Common.cs @@ -19,17 +19,25 @@ interface ITest class Test : ITest, ICustomQueryInterface { + public static Test Resurrected; public static int InstanceCount = 0; private int id; private int value = -1; public Test() { id = Interlocked.Increment(ref InstanceCount); } - ~Test() { Interlocked.Decrement(ref InstanceCount); id = -1; } + ~Test() + { + Interlocked.Decrement(ref InstanceCount); + id = -1; + if (EnableResurrection) + Resurrected = this; + } public void SetValue(int i) => this.value = i; public int GetValue() => this.value; public bool EnableICustomQueryInterface { get; set; } = false; + public bool EnableResurrection { get; set; } = false; public Guid ICustomQueryInterface_GetInterfaceIID { get; set; } public IntPtr ICustomQueryInterface_GetInterfaceResult { get; set; }