Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -409,37 +409,51 @@ static bool IsRootedCallback(IntPtr pObj)
#pragma warning restore CS8500
}

private ManagedObjectWrapper* _wrapper;
private object _wrappedObject;
private readonly ManagedObjectWrapper* _wrapper;
private readonly ManagedObjectWrapperReleaser _releaser;
private readonly object _wrappedObject;

public ManagedObjectWrapperHolder(ManagedObjectWrapper* wrapper, object wrappedObject)
{
_wrapper = wrapper;
_wrappedObject = wrappedObject;
}

public void InitializeHandle()
{
IntPtr handle = RuntimeImports.RhHandleAllocRefCounted(this);
IntPtr prev = Interlocked.CompareExchange(ref _wrapper->HolderHandle, handle, IntPtr.Zero);
if (prev != IntPtr.Zero)
{
RuntimeImports.RhHandleFree(handle);
}
_releaser = new ManagedObjectWrapperReleaser(wrapper);
_wrapper->HolderHandle = RuntimeImports.RhHandleAllocRefCounted(this);
}

public unsafe IntPtr ComIp => _wrapper->As(in ComWrappers.IID_IUnknown);

public object WrappedObject => _wrappedObject;

public uint AddRef() => _wrapper->AddRef();
}

internal unsafe class ManagedObjectWrapperReleaser
{
private ManagedObjectWrapper* _wrapper;

public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper)
{
_wrapper = wrapper;
}

~ManagedObjectWrapperHolder()
~ManagedObjectWrapperReleaser()
{
IntPtr refCountedHandle = _wrapper->HolderHandle;
if (refCountedHandle != IntPtr.Zero && RuntimeImports.RhHandleGet(refCountedHandle) != null)
{
// The ManagedObjectWrapperHolder has not been fully collected, so it is still
// potentially reachable via the Conditional Weak Table.
// Keep ourselves alive in case the wrapped object is resurrected.
GC.ReRegisterForFinalize(this);
return;
}

// Release GC handle created when MOW was built.
if (_wrapper->Destroy())
{
NativeMemory.Free(_wrapper);
_wrapper = null;
}
else
{
Expand Down Expand Up @@ -531,7 +545,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
ManagedObjectWrapper* value = CreateCCW(c, flags);
return new ManagedObjectWrapperHolder(value, c);
});
ccwValue.InitializeHandle();
ccwValue.AddRef();
return ccwValue.ComIp;
}

Expand Down Expand Up @@ -581,7 +595,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
}

mow->HolderHandle = IntPtr.Zero;
mow->RefCount = 1;
mow->RefCount = 0;
mow->UserDefinedCount = userDefinedCount;
mow->UserDefined = userDefined;
mow->Flags = (CreateComInterfaceFlagsEx)flags;
Expand Down
66 changes: 61 additions & 5 deletions src/tests/Interop/COM/ComWrappers/API/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -319,20 +319,76 @@ unsafe static void CallSetValue(TestComWrappers wrappers, Test testInstance, int
Assert.NotEqual(IntPtr.Zero, nativeInstance);

var iid = typeof(ITest).GUID;
IntPtr itestPtr;
nint itestPtr;
Assert.Equal(0, Marshal.QueryInterface(nativeInstance, ref iid, out itestPtr));

var inst = Marshal.PtrToStructure<VtblPtr>(itestPtr);
var vtbl = Marshal.PtrToStructure<ITestVtbl>(inst.Vtbl);
var setValue = (delegate* unmanaged<IntPtr, int, int>)vtbl.SetValue;
var inst = (ComWrappers.ComInterfaceDispatch*)itestPtr;
var vtbl = (ITestVtbl*)(inst->Vtable);
var setValue = (delegate* unmanaged<nint, int, int>)vtbl->SetValue;

Assert.Equal(0, setValue(itestPtr, value));
Assert.Equal(value, testInstance.GetValue());

// release for QueryInterface
Assert.Equal(1, Marshal.Release(itestPtr));
// release for GetOrCreateComInterfaceForObject
Assert.Equal(0, Marshal.Release(itestPtr));
Assert.Equal(0, Marshal.Release(nativeInstance));
}
}

[MethodImpl(MethodImplOptions.NoInlining)]
[Fact]
public void ValidateResurrection()
{
Console.WriteLine($"Running {nameof(ValidateResurrection)}...");

var wrappers = new TestComWrappers();

try
{
CreateResurrectingTestInstance(wrappers);

ForceGC();

CallSetValue(wrappers);
}
finally
{
Test.Resurrected = null;
}

[MethodImpl(MethodImplOptions.NoInlining)]
static void CreateResurrectingTestInstance(ComWrappers wrapper)
{
Test testInstance = new Test()
{
EnableResurrection = true,
};
IntPtr nativeInstance = wrapper.GetOrCreateComInterfaceForObject(testInstance, CreateComInterfaceFlags.None);
Assert.Equal(0, Marshal.Release(nativeInstance));
}

unsafe static void CallSetValue(ComWrappers wrappers)
{
Assert.NotEqual(null, Test.Resurrected);
IntPtr nativeInstance = wrappers.GetOrCreateComInterfaceForObject(Test.Resurrected, CreateComInterfaceFlags.None);
Assert.NotEqual(IntPtr.Zero, nativeInstance);

var iid = typeof(ITest).GUID;
nint itestPtr;
Assert.Equal(0, Marshal.QueryInterface(nativeInstance, ref iid, out itestPtr));

var inst = (ComWrappers.ComInterfaceDispatch*)itestPtr;
var vtbl = (ITestVtbl*)(inst->Vtable);
var setValue = (delegate* unmanaged<nint, int, int>)vtbl->SetValue;

Assert.Equal(0, setValue(itestPtr, 42));
Assert.Equal(42, Test.Resurrected.GetValue());

// release for QueryInterface
Assert.Equal(1, Marshal.Release(itestPtr));
// release for GetOrCreateComInterfaceForObject
Assert.Equal(0, Marshal.Release(nativeInstance));
}
}

Expand Down
10 changes: 9 additions & 1 deletion src/tests/Interop/COM/ComWrappers/Common.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,25 @@ interface ITest

class Test : ITest, ICustomQueryInterface
{
public static Test Resurrected;
public static int InstanceCount = 0;

private int id;
private int value = -1;
public Test() { id = Interlocked.Increment(ref InstanceCount); }
~Test() { Interlocked.Decrement(ref InstanceCount); id = -1; }
~Test()
{
Interlocked.Decrement(ref InstanceCount);
id = -1;
if (EnableResurrection)
Resurrected = this;
}

public void SetValue(int i) => this.value = i;
public int GetValue() => this.value;

public bool EnableICustomQueryInterface { get; set; } = false;
public bool EnableResurrection { get; set; } = false;
public Guid ICustomQueryInterface_GetInterfaceIID { get; set; }
public IntPtr ICustomQueryInterface_GetInterfaceResult { get; set; }

Expand Down