From 616d53fca1be7ca03c658e5c07ad6201e0537cc9 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 17 May 2023 16:39:47 -0700 Subject: [PATCH 1/9] Only call Free in unmanaged->managed stubs when ownership has been transfered to the callee Fixes #85795 --- ...ributedMarshallingModelGeneratorFactory.cs | 7 +- .../MarshalAsMarshallingGeneratorFactory.cs | 2 +- .../StatelessMarshallingStrategy.cs | 87 ++++++++++ .../NativeToManagedStubCodeContext.cs | 2 +- .../IDerivedTests.cs | 9 +- .../ImplicitThisTests.cs | 25 ++- ...nmanagedToManagedCustomMarshallingTests.cs | 163 ++++++++++++++++++ .../NativeExports/VirtualMethodTables.cs | 19 ++ 8 files changed, 298 insertions(+), 16 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 03b90462ea3013..2266edee6d10b4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -245,7 +245,12 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false); if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + { + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + else if (info.RefKind == RefKind.Ref) + marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } } IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs index 6bf87839e5846d..1790e588a8e790 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs @@ -88,7 +88,7 @@ public IMarshallingGenerator Create( return s_delegate; case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }: - if (!context.AdditionalTemporaryStateLivesAcrossStages) + if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged) { throw new MarshallingNotSupportedException(info, context); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 94ea82a31b6d0a..000cba627ef811 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -295,6 +295,93 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); } + internal sealed class StatelessByRefFreeMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + private readonly TypeSyntax _marshallerType; + + public StatelessByRefFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType) + { + _innerMarshaller = innerMarshaller; + _marshallerType = marshallerType; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } + // if () + // .Free(); + yield return IfStatement( + IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerType, + IdentifierName(ShapeMemberNames.Free)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(context.GetAdditionalIdentifier(info, "original")))))))); + } + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context)) + { + yield return statement; + } + + // bool = false; + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.BoolKeyword)), + SingletonSeparatedList( + VariableDeclarator( + Identifier(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + null, + EqualsValueClause( + LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + + // = ; + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info).Syntax, + SingletonSeparatedList( + VariableDeclarator( + Identifier(context.GetAdditionalIdentifier(info, "original")), + null, + EqualsValueClause( + IdentifierName(context.GetIdentifiers(info).native)))))); + } + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context)) + { + yield return statement; + } + + // Now that we've captured the new value to pass to the caller, we need to make sure that we free the old one. + + // = true; + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + LiteralExpression(SyntaxKind.TrueLiteralExpression))); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + } + /// /// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs index 3dbd182406643d..7940f6a15fd233 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs @@ -11,7 +11,7 @@ public sealed record NativeToManagedStubCodeContext : StubCodeContext { public override bool SingleFrameSpansNativeContext => false; - public override bool AdditionalTemporaryStateLivesAcrossStages => false; + public override bool AdditionalTemporaryStateLivesAcrossStages => true; private readonly TargetFramework _framework; private readonly Version _frameworkVersion; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs index 937ba3de524ef2..68c7420d0b3cbc 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs @@ -48,10 +48,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() iface.SetInt(5); Assert.Equal(5, iface.GetInt()); - // https://github.com/dotnet/runtime/issues/85795 - //Assert.Equal("myName", iface.GetName()); - //iface.SetName("updated"); - //Assert.Equal("updated", iface.GetName()); + Assert.Equal("myName", iface.GetName()); + iface.SetName("updated"); + Assert.Equal("updated", iface.GetName()); var iUnknownStrategyProperty = typeof(ComObject).GetProperty("IUnknownStrategy", BindingFlags.NonPublic | BindingFlags.Instance); @@ -67,7 +66,7 @@ partial class DerivedImpl : IDerived { int data = 3; string myName = "myName"; - public void DoThingWithString([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => throw new NotImplementedException(); + public void DoThingWithString(string name) => throw new NotImplementedException(); public int GetInt() => data; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index e6214e759450e0..d1995087f9279f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -5,6 +5,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; +using System.Threading; using ComInterfaceGenerator.Tests; using Xunit; @@ -35,6 +36,8 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation int GetData(); [VirtualMethodIndex(1, ImplicitThisParameter = true)] void SetData(int x); + [VirtualMethodIndex(2, ImplicitThisParameter = true)] + void ExchangeData(ref int x); } [NativeMarshalling(typeof(NativeObjectMarshaller))] @@ -105,16 +108,21 @@ public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() void* wrapper = VTableGCHandlePair.Allocate(impl); - Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); - NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue); - Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); - // Verify that we actually updated the managed instance. - Assert.Equal(newValue, impl.GetData()); - - VTableGCHandlePair.Free(wrapper); + try + { + Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); + NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue); + Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); + // Verify that we actually updated the managed instance. + Assert.Equal(newValue, impl.GetData()); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } } - class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject + sealed class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject { private int _data; @@ -123,6 +131,7 @@ public ManagedObjectImplementation(int value) _data = value; } + public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x); public int GetData() => _data; public void SetData(int x) => _data = x; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs new file mode 100644 index 00000000000000..d13c0544dd5fe9 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -0,0 +1,163 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using SharedTypes; +using Xunit; +using static ComInterfaceGenerator.Tests.UnmanagedToManagedCustomMarshallingTests; + +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal partial class UnmanagedToManagedCustomMarshalling + { + [UnmanagedObjectUnwrapper>] + internal partial interface INativeObject : IUnmanagedInterfaceType + { + + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2); + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + if (s_vtable[0] == null) + { + Native.PopulateUnmanagedVirtualMethodTable(s_vtable); + } + return s_vtable; + } + } + + [VirtualMethodIndex(0, ImplicitThisParameter = true)] + [return: MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] + IntWrapper GetData(); + [VirtualMethodIndex(1, ImplicitThisParameter = true)] + void SetData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] IntWrapper x); + [VirtualMethodIndex(2, ImplicitThisParameter = true)] + void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data); + } + + [NativeMarshalling(typeof(NativeObjectMarshaller))] + public class NativeObject : INativeObject.Native, IUnmanagedVirtualMethodTableProvider, IDisposable + { + private readonly void* _pointer; + + public NativeObject(void* pointer) + { + _pointer = pointer; + } + + public VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type) + { + Assert.Equal(typeof(INativeObject), type); + return new VirtualMethodTableInfo(_pointer, *(void***)_pointer); + } + + public void Dispose() + { + DeleteNativeObject(_pointer); + } + } + + [CustomMarshaller(typeof(NativeObject), MarshalMode.ManagedToUnmanagedOut, typeof(NativeObjectMarshaller))] + static class NativeObjectMarshaller + { + public static NativeObject ConvertToManaged(void* value) => new NativeObject(value); + } + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_native_object")] + public static partial NativeObject NewNativeObject(); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "delete_native_object")] + public static partial void DeleteNativeObject(void* obj); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_native_object_data")] + public static partial void SetNativeObjectData(void* obj, int data); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_native_object_data")] + public static partial int GetNativeObjectData(void* obj); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "exchange_native_object_data")] + public static partial int ExchangeNativeObjectData(void* obj, ref int x); + } + } + public class UnmanagedToManagedCustomMarshallingTests + { + [Fact] + public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() + { + const int startingValue = 13; + const int newValue = 42; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + NativeExportsNE.UnmanagedToManagedCustomMarshalling.GetNativeObjectData(wrapper); + + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SetNativeObjectData(wrapper, newValue); + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + + int finalValue = 10; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.ExchangeNativeObjectData(wrapper, ref finalValue); + Assert.Equal(freeCalls + 1, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + sealed class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject + { + private IntWrapper _data; + + public ManagedObjectImplementation(int value) + { + _data = new() { i = value }; + } + + public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x); + public IntWrapper GetData() => _data; + public void SetData(IntWrapper x) => _data = x; + } + + + [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerToIntWithFreeCounts))] + public static unsafe class IntWrapperMarshallerToIntWithFreeCounts + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + public static int ConvertToUnmanaged(IntWrapper managed) + { + return managed.i; + } + + public static IntWrapper ConvertToManaged(int unmanaged) + { + return new IntWrapper { i = unmanaged }; + } + + public static void Free(int unmanaged) + { + NumCallsToFree++; + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index a286dd4de36147..f091146b7ed737 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -8,6 +8,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace NativeExports @@ -48,6 +49,7 @@ public struct VirtualFunctionTable { public delegate* unmanaged getData; public delegate* unmanaged setData; + public delegate* unmanaged exchangeData; } public readonly VirtualFunctionTable* VTable; @@ -66,12 +68,14 @@ public struct VirtualFunctionTable // The order of functions here should match NativeObjectInterface.VirtualFunctionTable's members. public delegate* unmanaged getData; public delegate* unmanaged setData; + public delegate* unmanaged exchangeData; } static NativeObject() { VTablePointer = (VirtualFunctionTable*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NativeObject), sizeof(VirtualFunctionTable)); VTablePointer->getData = &GetData; VTablePointer->setData = &SetData; + VTablePointer->exchangeData = &ExchangeData; } private static readonly VirtualFunctionTable* VTablePointer; @@ -95,6 +99,14 @@ private static void SetData(NativeObject* obj, int value) { obj->Data = value; } + + [UnmanagedCallersOnly] + private static void ExchangeData(NativeObject* obj, int* value) + { + var temp = obj->Data; + obj->Data = *value; + *value = temp; + } } [UnmanagedCallersOnly(EntryPoint = "new_native_object")] @@ -127,5 +139,12 @@ public static int GetNativeObjectData([DNNE.C99Type("struct INativeObject*")] Na { return obj->VTable->getData(obj); } + + [UnmanagedCallersOnly(EntryPoint = "exchange_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void ExchangeNativeObjectData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* x) + { + obj->VTable->exchangeData(obj, x); + } } } From 53ea459758442c052b9ca8ff5e6eca9af405df8f Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 17 May 2023 16:42:38 -0700 Subject: [PATCH 2/9] Add missing condition --- .../Marshalling/AttributedMarshallingModelGeneratorFactory.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 2266edee6d10b4..49f96702cd5fc2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -246,7 +246,7 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) { - if (context.Direction == MarshalDirection.ManagedToUnmanaged) + if (context.Direction == MarshalDirection.ManagedToUnmanaged || !context.AdditionalTemporaryStateLivesAcrossStages) marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); else if (info.RefKind == RefKind.Ref) marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); From 0d8cf4cd86ce7c961c660f1a78f24244454faf49 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 19 May 2023 09:28:47 -0700 Subject: [PATCH 3/9] Fix ref parameter marshalling and by-value collection marshalling. Disable by-value [In,Out] on unmanged->managed for now as it's difficult to reason about. --- ...CollectionElementMarshallingCodeContext.cs | 1 + ...ributedMarshallingModelGeneratorFactory.cs | 33 +++++++++++-- .../CustomTypeMarshallingGenerator.cs | 2 +- .../StatelessMarshallingStrategy.cs | 46 +++++++++++-------- 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs index afa27b08e0e6d0..a11ce0ecefdef9 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs @@ -34,6 +34,7 @@ public LinearCollectionElementMarshallingCodeContext( _managedSpanIdentifier = managedSpanIdentifier; _nativeSpanIdentifier = nativeSpanIdentifier; ParentContext = parentContext; + Direction = ParentContext.Direction; } public override (TargetFramework framework, Version version) GetTargetFramework() diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 49f96702cd5fc2..79d6d50e5203f7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -246,10 +246,21 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) { - if (context.Direction == MarshalDirection.ManagedToUnmanaged || !context.AdditionalTemporaryStateLivesAcrossStages) + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + { marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } else if (info.RefKind == RefKind.Ref) - marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + { + if (!context.AdditionalTemporaryStateLivesAcrossStages) + { + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } + else + { + marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } + } } } @@ -340,7 +351,23 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( } if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); + { + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + { + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); + } + else if (info.RefKind == RefKind.Ref) + { + if (!context.AdditionalTemporaryStateLivesAcrossStages) + { + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); + } + else + { + marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); + } + } + } } IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs index d5b2da82134908..2b54ccc5c72b70 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs @@ -107,7 +107,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) { - return _enableByValueContentsMarshalling; + return _enableByValueContentsMarshalling && context.Direction == MarshalDirection.ManagedToUnmanaged; } public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 000cba627ef811..3b322d7e311f6c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -297,6 +297,9 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i internal sealed class StatelessByRefFreeMarshalling : ICustomTypeMarshallingStrategy { + private const string FreeUnmanagedIdentifier = "freeUnmanaged"; + private const string OriginalValueIdentifier = "original"; + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; private readonly TypeSyntax _marshallerType; @@ -317,20 +320,37 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i // if () // .Free(); yield return IfStatement( - IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + IdentifierName(context.GetAdditionalIdentifier(info, FreeUnmanagedIdentifier)), ExpressionStatement( InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, _marshallerType, IdentifierName(ShapeMemberNames.Free)), ArgumentList(SingletonSeparatedList( - Argument(IdentifierName(context.GetAdditionalIdentifier(info, "original")))))))); + Argument(IdentifierName(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)))))))); } public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); - public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context)) + { + yield return statement; + } + + // Now that we've set the new value to pass to the caller on the identifier, we need to make sure that we free the old one. + // The caller will not see the old one any more, so it won't be able to free it. + + // = true; + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetAdditionalIdentifier(info, FreeUnmanagedIdentifier)), + LiteralExpression(SyntaxKind.TrueLiteralExpression))); + } + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) { @@ -345,7 +365,7 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf PredefinedType(Token(SyntaxKind.BoolKeyword)), SingletonSeparatedList( VariableDeclarator( - Identifier(context.GetAdditionalIdentifier(info, "freeUnmanaged")), + Identifier(context.GetAdditionalIdentifier(info, FreeUnmanagedIdentifier)), null, EqualsValueClause( LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); @@ -356,27 +376,13 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf AsNativeType(info).Syntax, SingletonSeparatedList( VariableDeclarator( - Identifier(context.GetAdditionalIdentifier(info, "original")), + Identifier(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)), null, EqualsValueClause( IdentifierName(context.GetIdentifiers(info).native)))))); } - public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) - { - foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context)) - { - yield return statement; - } - - // Now that we've captured the new value to pass to the caller, we need to make sure that we free the old one. - - // = true; - yield return ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(context.GetAdditionalIdentifier(info, "freeUnmanaged")), - LiteralExpression(SyntaxKind.TrueLiteralExpression))); - } + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); From c3a6fe4fadb8552c3e8c60d6de7390220a1c4630 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 19 May 2023 09:34:45 -0700 Subject: [PATCH 4/9] Add test for by-ref collections --- .../ImplicitThisTests.cs | 20 ++++- ...nmanagedToManagedCustomMarshallingTests.cs | 78 ++++++++++++++++++- .../NativeExports/VirtualMethodTables.cs | 48 ++++++++++++ 3 files changed, 143 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index d1995087f9279f..e4b68e0e407e65 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; using System.Threading; -using ComInterfaceGenerator.Tests; using Xunit; namespace ComInterfaceGenerator.Tests @@ -38,6 +38,16 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation void SetData(int x); [VirtualMethodIndex(2, ImplicitThisParameter = true)] void ExchangeData(ref int x); + [VirtualMethodIndex(3, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues))] int[] values, + int numValues, + out int oldValue); + [VirtualMethodIndex(4, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues))] ref int[] values, + int numValues, + out int oldValue); } [NativeMarshalling(typeof(NativeObjectMarshaller))] @@ -134,6 +144,14 @@ public ManagedObjectImplementation(int value) public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x); public int GetData() => _data; public void SetData(int x) => _data = x; + public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues, out int oldValue) + { + int value = values.Sum(); + oldValue = _data; + _data = value; + } + + public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] ref int[] values, int numValues, out int oldValue) => SumAndSetData(values, numValues, out oldValue); } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs index d13c0544dd5fe9..6d9f30cdbbd822 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -44,6 +44,16 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation void SetData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] IntWrapper x); [VirtualMethodIndex(2, ImplicitThisParameter = true)] void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data); + [VirtualMethodIndex(3, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values123, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + [VirtualMethodIndex(4, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values123, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); } [NativeMarshalling(typeof(NativeObjectMarshaller))] @@ -88,12 +98,18 @@ static class NativeObjectMarshaller [LibraryImport(NativeExportsNE_Binary, EntryPoint = "exchange_native_object_data")] public static partial int ExchangeNativeObjectData(void* obj, ref int x); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data")] + public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues, out int oldValue); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data_wth_ref")] + public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] ref int[] arr, int numValues, out int oldValue); } } public class UnmanagedToManagedCustomMarshallingTests { [Fact] - public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() + public unsafe void ValidateOnlyByRefStatelessFreed() { const int startingValue = 13; const int newValue = 42; @@ -123,6 +139,56 @@ public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() } } + [Fact] + public unsafe void ValidateArrayElementsAndOutParameterNotFreed() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, values, values.Length, out int _); + + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsByRefFreed() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, ref values, values.Length, out int _); + + Assert.Equal(freeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + sealed class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject { private IntWrapper _data; @@ -135,6 +201,14 @@ public ManagedObjectImplementation(int value) public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x); public IntWrapper GetData() => _data; public void SetData(IntWrapper x) => _data = x; + public void SumAndSetData(ref IntWrapper[] values, int numValues, out IntWrapper oldValue) + { + int value = values.Sum(value => value.i); + oldValue = _data; + _data = new() { i = value }; + } + + public void SumAndSetData([MarshalUsing(CountElementName = "numValues"), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values123, int numValues, [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue) => SumAndSetData(values123, numValues, out oldValue); } @@ -154,7 +228,7 @@ public static IntWrapper ConvertToManaged(int unmanaged) return new IntWrapper { i = unmanaged }; } - public static void Free(int unmanaged) + public static void Free(int _) { NumCallsToFree++; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index f091146b7ed737..618f0c235f9613 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -50,6 +50,8 @@ public struct VirtualFunctionTable public delegate* unmanaged getData; public delegate* unmanaged setData; public delegate* unmanaged exchangeData; + public delegate* unmanaged sumAndSetData; + public delegate* unmanaged sumAndSetDataWithRef; } public readonly VirtualFunctionTable* VTable; @@ -69,6 +71,8 @@ public struct VirtualFunctionTable public delegate* unmanaged getData; public delegate* unmanaged setData; public delegate* unmanaged exchangeData; + public delegate* unmanaged sumAndSetData; + public delegate* unmanaged sumAndSetDataWithRef; } static NativeObject() { @@ -76,6 +80,8 @@ static NativeObject() VTablePointer->getData = &GetData; VTablePointer->setData = &SetData; VTablePointer->exchangeData = &ExchangeData; + VTablePointer->sumAndSetData = &SumAndSetData; + VTablePointer->sumAndSetDataWithRef = &SumAndSetData; } private static readonly VirtualFunctionTable* VTablePointer; @@ -107,6 +113,34 @@ private static void ExchangeData(NativeObject* obj, int* value) obj->Data = *value; *value = temp; } + + [UnmanagedCallersOnly] + private static void SumAndSetData(NativeObject* obj, int** values, int numValues, int* oldValue) + { + *oldValue = obj->Data; + + Span arr = new(*values, numValues); + int sum = 0; + foreach (int value in arr) + { + sum += value; + } + obj->Data = sum; + } + + [UnmanagedCallersOnly] + private static void SumAndSetData(NativeObject* obj, int* values, int numValues, int* oldValue) + { + *oldValue = obj->Data; + + Span arr = new(values, numValues); + int sum = 0; + foreach (int value in arr) + { + sum += value; + } + obj->Data = sum; + } } [UnmanagedCallersOnly(EntryPoint = "new_native_object")] @@ -146,5 +180,19 @@ public static void ExchangeNativeObjectData([DNNE.C99Type("struct INativeObject* { obj->VTable->exchangeData(obj, x); } + + [UnmanagedCallersOnly(EntryPoint = "sum_and_set_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void SumAndSetData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* values, int numValues, int* oldValue) + { + obj->VTable->sumAndSetData(obj, values, numValues, oldValue); + } + + [UnmanagedCallersOnly(EntryPoint = "sum_and_set_native_object_data_with_ref")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void SumAndSetData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int** values, int numValues, int* oldValue) + { + obj->VTable->sumAndSetDataWithRef(obj, values, numValues, oldValue); + } } } From e6c5b22827fd1932dc5dba75aa399ee402a40d40 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 19 May 2023 13:45:12 -0700 Subject: [PATCH 5/9] Fix by-ref collection ownership. --- ...ributedMarshallingModelGeneratorFactory.cs | 110 ++++++++++++------ .../StatelessMarshallingStrategy.cs | 105 ++++++++++++----- .../ImplicitThisTests.cs | 3 +- ...nmanagedToManagedCustomMarshallingTests.cs | 10 +- .../NativeExports/VirtualMethodTables.cs | 2 +- 5 files changed, 157 insertions(+), 73 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 79d6d50e5203f7..994515b87f8291 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -244,23 +244,21 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false); - if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) + StatelessFreeStrategy freeStrategy = GetStatelessFreeStrategy(info, marshallerData.Shape, context); + + if (freeStrategy == StatelessFreeStrategy.FreeOriginal) { - if (context.Direction == MarshalDirection.ManagedToUnmanaged) - { - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); - } - else if (info.RefKind == RefKind.Ref) - { - if (!context.AdditionalTemporaryStateLivesAcrossStages) - { - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); - } - else - { - marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); - } - } + marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); + } + + if (freeStrategy != StatelessFreeStrategy.NoFree) + { + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } + + if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + { + marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); } } @@ -336,10 +334,17 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( { marshallingStrategy = new StatelessLinearCollectionSpaceAllocator(marshallerTypeSyntax, nativeType, marshallerData.Shape, numElementsExpression); + StatelessFreeStrategy freeStrategy = GetStatelessFreeStrategy(info, marshallerData.Shape, context); + IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax); + if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + { + marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); + } + IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource); - marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape); + marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != StatelessFreeStrategy.NoFree); if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) { @@ -350,23 +355,14 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax, isLinearCollectionMarshalling: true); } - if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) + if (freeStrategy != StatelessFreeStrategy.NoFree) { - if (context.Direction == MarshalDirection.ManagedToUnmanaged) - { - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); - } - else if (info.RefKind == RefKind.Ref) - { - if (!context.AdditionalTemporaryStateLivesAcrossStages) - { - marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); - } - else - { - marshallingStrategy = new StatelessByRefFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); - } - } + marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); + } + + if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + { + marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); } } @@ -383,6 +379,54 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( return marshallingGenerator; } + private enum StatelessFreeStrategy + { + /// + /// Free the unmanaged value stored in the native identifier. + /// + FreeNative, + /// + /// Free the unmanaged value originally passed into the stub. + /// + FreeOriginal, + /// + /// Do not free the unmanaged value, we don't own it. + /// + NoFree + } + + private static StatelessFreeStrategy GetStatelessFreeStrategy(TypePositionInfo info, MarshallerShape shape, StubCodeContext context) + { + // If the marshaller doesn't have the Free method, then we don't need to free anything. + if (!shape.HasFlag(MarshallerShape.Free)) + { + return StatelessFreeStrategy.NoFree; + } + + // When marshalling from managed to unmanaged, we always own the value in the native identifier. + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + { + return StatelessFreeStrategy.FreeNative; + } + + // When we're in a case where we don't have state across stages, the parent stub context that can track the state + // will only call our Cleanup stage when we own the value in the native identifier. + if (!context.AdditionalTemporaryStateLivesAcrossStages) + { + return StatelessFreeStrategy.FreeNative; + } + + // In an unmanaged-to-managed stub where a value is passed by 'ref', + // we own the original value once we replace it with the new value we're passing out to the caller. + if (info.RefKind == RefKind.Ref) + { + return StatelessFreeStrategy.FreeOriginal; + } + + // In an unmanaged-to-managed stub, we don't take ownership of the value when it isn't passed by 'ref'. + return StatelessFreeStrategy.NoFree; + } + private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource) { IElementsMarshalling elementsMarshalling; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 3b322d7e311f6c..73dfc941ec6f77 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -227,7 +227,7 @@ IEnumerable GenerateCallerAllocatedBufferMarshalStatements() } else { - // = .ConvertToUnmanaged(, __buffer); + // = .ConvertToUnmanaged(, __buffer); yield return ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, @@ -295,40 +295,21 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); } - internal sealed class StatelessByRefFreeMarshalling : ICustomTypeMarshallingStrategy + internal sealed class StatelessUnmanagedToManagedOwnershipTracking : ICustomTypeMarshallingStrategy { - private const string FreeUnmanagedIdentifier = "freeUnmanaged"; - private const string OriginalValueIdentifier = "original"; + internal const string OwnOriginalValueIdentifier = "ownOriginal"; + internal const string OriginalValueIdentifier = "original"; private readonly ICustomTypeMarshallingStrategy _innerMarshaller; - private readonly TypeSyntax _marshallerType; - public StatelessByRefFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, TypeSyntax marshallerType) + public StatelessUnmanagedToManagedOwnershipTracking(ICustomTypeMarshallingStrategy innerMarshaller) { _innerMarshaller = innerMarshaller; - _marshallerType = marshallerType; } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) - { - foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context)) - { - yield return statement; - } - // if () - // .Free(); - yield return IfStatement( - IdentifierName(context.GetAdditionalIdentifier(info, FreeUnmanagedIdentifier)), - ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - _marshallerType, - IdentifierName(ShapeMemberNames.Free)), - ArgumentList(SingletonSeparatedList( - Argument(IdentifierName(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)))))))); - } + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context); public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) @@ -341,10 +322,10 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i // Now that we've set the new value to pass to the caller on the identifier, we need to make sure that we free the old one. // The caller will not see the old one any more, so it won't be able to free it. - // = true; + // = true; yield return ExpressionStatement( AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(context.GetAdditionalIdentifier(info, FreeUnmanagedIdentifier)), + IdentifierName(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)), LiteralExpression(SyntaxKind.TrueLiteralExpression))); } @@ -359,18 +340,18 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf yield return statement; } - // bool = false; + // bool = false; yield return LocalDeclarationStatement( VariableDeclaration( PredefinedType(Token(SyntaxKind.BoolKeyword)), SingletonSeparatedList( VariableDeclarator( - Identifier(context.GetAdditionalIdentifier(info, FreeUnmanagedIdentifier)), + Identifier(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)), null, EqualsValueClause( LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); - // = ; + // = ; yield return LocalDeclarationStatement( VariableDeclaration( AsNativeType(info).Syntax, @@ -388,6 +369,60 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); } + internal sealed class FreeOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + + public FreeOwnedOriginalValueMarshalling(ICustomTypeMarshallingStrategy innerMarshaller) + { + _innerMarshaller = innerMarshaller; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + // if () + // { + // + // } + yield return IfStatement( + IdentifierName(context.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OwnOriginalValueIdentifier)), + Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context)))); + } + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context); + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + + private sealed record OwnedValueCodeContext(StubCodeContext InnerContext) : StubCodeContext + { + public override bool SingleFrameSpansNativeContext => InnerContext.SingleFrameSpansNativeContext; + + public override bool AdditionalTemporaryStateLivesAcrossStages => InnerContext.AdditionalTemporaryStateLivesAcrossStages; + + public override (TargetFramework framework, Version version) GetTargetFramework() => InnerContext.GetTargetFramework(); + + public override (string managed, string native) GetIdentifiers(TypePositionInfo info) + { + var (managed, _) = InnerContext.GetIdentifiers(info); + return (managed, InnerContext.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OriginalValueIdentifier)); + } + + public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => InnerContext.GetAdditionalIdentifier(info, name); + } + } + /// /// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec. /// @@ -628,23 +663,31 @@ internal sealed class StatelessLinearCollectionMarshalling : ICustomTypeMarshall private readonly IElementsMarshalling _elementsMarshalling; private readonly ManagedTypeInfo _unmanagedType; private readonly MarshallerShape _shape; + private readonly bool _cleanupElementsAndSpace; public StatelessLinearCollectionMarshalling( ICustomTypeMarshallingStrategy spaceMarshallingStrategy, IElementsMarshalling elementsMarshalling, ManagedTypeInfo unmanagedType, - MarshallerShape shape) + MarshallerShape shape, + bool cleanupElementsAndSpace) { _spaceMarshallingStrategy = spaceMarshallingStrategy; _elementsMarshalling = elementsMarshalling; _unmanagedType = unmanagedType; _shape = shape; + _cleanupElementsAndSpace = cleanupElementsAndSpace; } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _unmanagedType; public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) { + if (!_cleanupElementsAndSpace) + { + yield break; + } + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index e4b68e0e407e65..445dd51a252082 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -18,8 +18,7 @@ internal partial class ImplicitThis [UnmanagedObjectUnwrapperAttribute>] internal partial interface INativeObject : IUnmanagedInterfaceType { - - private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2); + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 5); static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation { get diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs index 6d9f30cdbbd822..4088dcd6db904e 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -23,8 +23,7 @@ internal partial class UnmanagedToManagedCustomMarshalling [UnmanagedObjectUnwrapper>] internal partial interface INativeObject : IUnmanagedInterfaceType { - - private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2); + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 5); static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation { get @@ -102,7 +101,7 @@ static class NativeObjectMarshaller [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data")] public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues, out int oldValue); - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data_wth_ref")] + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data_with_ref")] public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] ref int[] arr, int numValues, out int oldValue); } } @@ -201,14 +200,13 @@ public ManagedObjectImplementation(int value) public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x); public IntWrapper GetData() => _data; public void SetData(IntWrapper x) => _data = x; - public void SumAndSetData(ref IntWrapper[] values, int numValues, out IntWrapper oldValue) + public void SumAndSetData(ref IntWrapper[] values, int numValues, out IntWrapper oldValue) => SumAndSetData(values, numValues, out oldValue); + public void SumAndSetData(IntWrapper[] values, int numValues, out IntWrapper oldValue) { int value = values.Sum(value => value.i); oldValue = _data; _data = new() { i = value }; } - - public void SumAndSetData([MarshalUsing(CountElementName = "numValues"), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values123, int numValues, [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue) => SumAndSetData(values123, numValues, out oldValue); } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index 618f0c235f9613..dbcea3f908589f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -190,7 +190,7 @@ public static void SumAndSetData([DNNE.C99Type("struct INativeObject*")] NativeO [UnmanagedCallersOnly(EntryPoint = "sum_and_set_native_object_data_with_ref")] [DNNE.C99DeclCode("struct INativeObject;")] - public static void SumAndSetData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int** values, int numValues, int* oldValue) + public static void SumAndSetDataWithRef([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int** values, int numValues, int* oldValue) { obj->VTable->sumAndSetDataWithRef(obj, values, numValues, oldValue); } From fcefc5777f80ede4da6e044e731343196d1b37c1 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 19 May 2023 15:27:02 -0700 Subject: [PATCH 6/9] Fix compilation of stateless by-value out marshalling for unmanaged-to-managed. It still needs additional work to function, but this at least fixes the compilation problems, enabling usage in managed->unmanaged COM scenarios. --- .../CustomTypeMarshallingGenerator.cs | 2 +- .../Marshalling/ElementsMarshalling.cs | 16 ++-- .../StatefulMarshallingStrategy.cs | 9 +- .../StatelessMarshallingStrategy.cs | 10 +-- .../ImplicitThisTests.cs | 11 +++ ...nmanagedToManagedCustomMarshallingTests.cs | 83 +++++++++++-------- .../NativeExports/VirtualMethodTables.cs | 20 +++++ 7 files changed, 98 insertions(+), 53 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs index 2b54ccc5c72b70..d5b2da82134908 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs @@ -107,7 +107,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) { - return _enableByValueContentsMarshalling && context.Direction == MarshalDirection.ManagedToUnmanaged; + return _enableByValueContentsMarshalling; } public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs index 02d0b6bf70a54c..3c1468fbfe58ca 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs @@ -22,9 +22,9 @@ internal interface IElementsMarshallingCollectionSource internal interface IElementsMarshalling { - StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context); + StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context); StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context); - StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context); + StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context); StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context); StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context); } @@ -45,7 +45,7 @@ public BlittableElementsMarshalling(TypeSyntax managedElementType, TypeSyntax un _collectionSource = collectionSource; } - public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) { // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content. @@ -73,7 +73,7 @@ public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeC Argument(destination))); } - public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) { ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context)); @@ -175,7 +175,7 @@ public NonBlittableElementsMarshalling( _collectionSource = collectionSource; } - public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) { // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content. @@ -259,7 +259,7 @@ public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCod StubCodeContext.Stage.Unmarshal)); } - public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) { // Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents, // not the array itself. @@ -356,7 +356,9 @@ public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, St VariableDeclarator( Identifier(nativeSpanIdentifier)) .WithInitializer(EqualsValueClause( - _collectionSource.GetUnmanagedValuesDestination(info, context)))))), + context.Direction == MarshalDirection.ManagedToUnmanaged + ? _collectionSource.GetUnmanagedValuesDestination(info, context) + : _collectionSource.GetUnmanagedValuesSource(info, context)))))), contentsCleanupStatements); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs index 8c8aa0db66ccb7..f5fce73a953578 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs @@ -419,9 +419,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i yield return statement; } - if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) { - yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context); yield break; } @@ -437,9 +437,10 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo { string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); - if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { - yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context); + yield break; } if (!_shape.HasFlag(MarshallerShape.ToManaged)) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 73dfc941ec6f77..fd9459ca794591 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -532,7 +532,7 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { - if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { yield break; } @@ -712,9 +712,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) yield break; - if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) { - yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context); } else { @@ -732,9 +732,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { - if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { - yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context); yield break; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index 445dd51a252082..9cd3ecb3bd1b1b 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -47,6 +47,10 @@ void SumAndSetData( [MarshalUsing(CountElementName = nameof(numValues))] ref int[] values, int numValues, out int oldValue); + [VirtualMethodIndex(5, ImplicitThisParameter = true)] + void MultiplyWithData( + [MarshalUsing(CountElementName = nameof(numValues))] int[] values, + int numValues); } [NativeMarshalling(typeof(NativeObjectMarshaller))] @@ -142,6 +146,13 @@ public ManagedObjectImplementation(int value) public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x); public int GetData() => _data; + public void MultiplyWithData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues) + { + for (int i = 0; i < values.Length; i++) + { + values[i] *= _data; + } + } public void SetData(int x) => _data = x; public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues, out int oldValue) { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs index 4088dcd6db904e..d4eb80c7663630 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -45,42 +45,18 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data); [VirtualMethodIndex(3, ImplicitThisParameter = true)] void SumAndSetData( - [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values123, + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values, int numValues, [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); [VirtualMethodIndex(4, ImplicitThisParameter = true)] void SumAndSetData( - [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values123, + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values, int numValues, [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); - } - - [NativeMarshalling(typeof(NativeObjectMarshaller))] - public class NativeObject : INativeObject.Native, IUnmanagedVirtualMethodTableProvider, IDisposable - { - private readonly void* _pointer; - - public NativeObject(void* pointer) - { - _pointer = pointer; - } - - public VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type) - { - Assert.Equal(typeof(INativeObject), type); - return new VirtualMethodTableInfo(_pointer, *(void***)_pointer); - } - - public void Dispose() - { - DeleteNativeObject(_pointer); - } - } - - [CustomMarshaller(typeof(NativeObject), MarshalMode.ManagedToUnmanagedOut, typeof(NativeObjectMarshaller))] - static class NativeObjectMarshaller - { - public static NativeObject ConvertToManaged(void* value) => new NativeObject(value); + [VirtualMethodIndex(5, ImplicitThisParameter = true)] + void MultiplyWithData( + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1), In, Out] IntWrapper[] values123, + int numValues); } [LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_native_object")] @@ -103,12 +79,15 @@ static class NativeObjectMarshaller [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data_with_ref")] public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] ref int[] arr, int numValues, out int oldValue); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "multiply_with_native_object_data")] + public static partial int MultiplyWithNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues); } } public class UnmanagedToManagedCustomMarshallingTests { [Fact] - public unsafe void ValidateOnlyByRefStatelessFreed() + public unsafe void ValidateOnlyByRefFreed_Stateless() { const int startingValue = 13; const int newValue = 42; @@ -134,12 +113,12 @@ public unsafe void ValidateOnlyByRefStatelessFreed() } finally { - VTableGCHandlePair.Free(wrapper); + VTableGCHandlePair.Free(wrapper); } } [Fact] - public unsafe void ValidateArrayElementsAndOutParameterNotFreed() + public unsafe void ValidateArrayElementsAndOutParameterNotFreed_Stateless() { const int startingValue = 13; @@ -159,12 +138,12 @@ public unsafe void ValidateArrayElementsAndOutParameterNotFreed() } finally { - VTableGCHandlePair.Free(wrapper); + VTableGCHandlePair.Free(wrapper); } } [Fact] - public unsafe void ValidateArrayElementsByRefFreed() + public unsafe void ValidateArrayElementsByRefFreed_Stateless() { const int startingValue = 13; @@ -184,7 +163,32 @@ public unsafe void ValidateArrayElementsByRefFreed() } finally { - VTableGCHandlePair.Free(wrapper); + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsByValueOutFreed_Stateless() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length); + + Assert.Equal(freeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); } } @@ -199,6 +203,13 @@ public ManagedObjectImplementation(int value) public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x); public IntWrapper GetData() => _data; + public void MultiplyWithData(IntWrapper[] values, int numValues) + { + for (int i = 0; i < values.Length; i++) + { + values[i].i *= _data.i; + } + } public void SetData(IntWrapper x) => _data = x; public void SumAndSetData(ref IntWrapper[] values, int numValues, out IntWrapper oldValue) => SumAndSetData(values, numValues, out oldValue); public void SumAndSetData(IntWrapper[] values, int numValues, out IntWrapper oldValue) diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index dbcea3f908589f..b435fa7951a40e 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -52,6 +52,7 @@ public struct VirtualFunctionTable public delegate* unmanaged exchangeData; public delegate* unmanaged sumAndSetData; public delegate* unmanaged sumAndSetDataWithRef; + public delegate* unmanaged multiplyWithData; } public readonly VirtualFunctionTable* VTable; @@ -73,6 +74,7 @@ public struct VirtualFunctionTable public delegate* unmanaged exchangeData; public delegate* unmanaged sumAndSetData; public delegate* unmanaged sumAndSetDataWithRef; + public delegate* unmanaged multiplyWithData; } static NativeObject() { @@ -82,6 +84,7 @@ static NativeObject() VTablePointer->exchangeData = &ExchangeData; VTablePointer->sumAndSetData = &SumAndSetData; VTablePointer->sumAndSetDataWithRef = &SumAndSetData; + VTablePointer->multiplyWithData = &MultiplyWithData; } private static readonly VirtualFunctionTable* VTablePointer; @@ -141,6 +144,16 @@ private static void SumAndSetData(NativeObject* obj, int* values, int numValues, } obj->Data = sum; } + + [UnmanagedCallersOnly] + private static void MultiplyWithData(NativeObject* obj, int* values, int numValues) + { + Span arr = new(values, numValues); + foreach (ref int value in arr) + { + value *= obj->Data; + } + } } [UnmanagedCallersOnly(EntryPoint = "new_native_object")] @@ -194,5 +207,12 @@ public static void SumAndSetDataWithRef([DNNE.C99Type("struct INativeObject*")] { obj->VTable->sumAndSetDataWithRef(obj, values, numValues, oldValue); } + + [UnmanagedCallersOnly(EntryPoint = "multiply_with_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void SumAndSetDataWithRef([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* values, int numValues) + { + obj->VTable->multiplyWithData(obj, values, numValues); + } } } From 6bb51e1bc656f0e2ff4d62b2f961867fdab2122a Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 22 May 2023 13:43:32 -0700 Subject: [PATCH 7/9] Add tests for stateful marshaller shapes and update behavior to match the expectation (always free stateful marshaller state, only free elements when we're supposed to). --- ...ributedMarshallingModelGeneratorFactory.cs | 51 ++-- .../StatefulMarshallingStrategy.cs | 68 ++++- .../ImplicitThisTests.cs | 2 +- ...nmanagedToManagedCustomMarshallingTests.cs | 247 +++++++++++++++++- .../NativeExports/VirtualMethodTables.cs | 2 +- 5 files changed, 322 insertions(+), 48 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 994515b87f8291..3a90412e039231 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -244,19 +244,19 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false); - StatelessFreeStrategy freeStrategy = GetStatelessFreeStrategy(info, marshallerData.Shape, context); + FreeStrategy freeStrategy = GetFreeStrategy(info, context); - if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + if (freeStrategy == FreeStrategy.FreeOriginal) { marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); } - if (freeStrategy != StatelessFreeStrategy.NoFree) + if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free)) { marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); } - if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + if (freeStrategy == FreeStrategy.FreeOriginal) { marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); } @@ -325,26 +325,39 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax); } + FreeStrategy freeStrategy = GetFreeStrategy(info, context); IElementsMarshallingCollectionSource collectionSource = new StatefulLinearCollectionSource(); IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource); - marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling); + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); + } + + marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling, freeStrategy != FreeStrategy.NoFree); + + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); + } + + marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy); } else { marshallingStrategy = new StatelessLinearCollectionSpaceAllocator(marshallerTypeSyntax, nativeType, marshallerData.Shape, numElementsExpression); - StatelessFreeStrategy freeStrategy = GetStatelessFreeStrategy(info, marshallerData.Shape, context); + FreeStrategy freeStrategy = GetFreeStrategy(info, context); IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax); - if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + if (freeStrategy == FreeStrategy.FreeOriginal) { marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); } IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource); - marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != StatelessFreeStrategy.NoFree); + marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != FreeStrategy.NoFree); if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) { @@ -355,12 +368,12 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax, isLinearCollectionMarshalling: true); } - if (freeStrategy != StatelessFreeStrategy.NoFree) + if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free)) { marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); } - if (freeStrategy == StatelessFreeStrategy.FreeOriginal) + if (freeStrategy == FreeStrategy.FreeOriginal) { marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); } @@ -379,7 +392,7 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( return marshallingGenerator; } - private enum StatelessFreeStrategy + private enum FreeStrategy { /// /// Free the unmanaged value stored in the native identifier. @@ -395,36 +408,30 @@ private enum StatelessFreeStrategy NoFree } - private static StatelessFreeStrategy GetStatelessFreeStrategy(TypePositionInfo info, MarshallerShape shape, StubCodeContext context) + private static FreeStrategy GetFreeStrategy(TypePositionInfo info, StubCodeContext context) { - // If the marshaller doesn't have the Free method, then we don't need to free anything. - if (!shape.HasFlag(MarshallerShape.Free)) - { - return StatelessFreeStrategy.NoFree; - } - // When marshalling from managed to unmanaged, we always own the value in the native identifier. if (context.Direction == MarshalDirection.ManagedToUnmanaged) { - return StatelessFreeStrategy.FreeNative; + return FreeStrategy.FreeNative; } // When we're in a case where we don't have state across stages, the parent stub context that can track the state // will only call our Cleanup stage when we own the value in the native identifier. if (!context.AdditionalTemporaryStateLivesAcrossStages) { - return StatelessFreeStrategy.FreeNative; + return FreeStrategy.FreeNative; } // In an unmanaged-to-managed stub where a value is passed by 'ref', // we own the original value once we replace it with the new value we're passing out to the caller. if (info.RefKind == RefKind.Ref) { - return StatelessFreeStrategy.FreeOriginal; + return FreeStrategy.FreeOriginal; } // In an unmanaged-to-managed stub, we don't take ownership of the value when it isn't passed by 'ref'. - return StatelessFreeStrategy.NoFree; + return FreeStrategy.NoFree; } private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs index f5fce73a953578..2163c71047f875 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs @@ -372,40 +372,36 @@ internal sealed class StatefulLinearCollectionMarshalling : ICustomTypeMarshalli private readonly MarshallerShape _shape; private readonly ExpressionSyntax _numElementsExpression; private readonly IElementsMarshalling _elementsMarshalling; + private readonly bool _cleanupElements; public StatefulLinearCollectionMarshalling( ICustomTypeMarshallingStrategy innerMarshaller, MarshallerShape shape, ExpressionSyntax numElementsExpression, - IElementsMarshalling elementsMarshalling) + IElementsMarshalling elementsMarshalling, + bool cleanupElements) { _innerMarshaller = innerMarshaller; _shape = shape; _numElementsExpression = numElementsExpression; _elementsMarshalling = elementsMarshalling; + _cleanupElements = cleanupElements; } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) { + if (!_cleanupElements) + { + yield break; + } + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) { yield return elementCleanup; } - - if (!_shape.HasFlag(MarshallerShape.Free)) - yield break; - - string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); - // .Free(); - yield return ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshaller), - IdentifierName(ShapeMemberNames.Free)), - ArgumentList())); } public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); @@ -470,4 +466,50 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; } + + /// + /// Marshaller that enables calling the Free method on a stateful marshaller. + /// + internal sealed class StatefulFreeMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + + public StatefulFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller) + { + _innerMarshaller = innerMarshaller; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } + + string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshaller), + IdentifierName(ShapeMemberNames.Free)), + ArgumentList())); + } + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context); + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index 9cd3ecb3bd1b1b..20ee9b4a27668f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -18,7 +18,7 @@ internal partial class ImplicitThis [UnmanagedObjectUnwrapperAttribute>] internal partial interface INativeObject : IUnmanagedInterfaceType { - private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 5); + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 6); static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation { get diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs index d4eb80c7663630..e1e8617b890577 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -23,7 +23,7 @@ internal partial class UnmanagedToManagedCustomMarshalling [UnmanagedObjectUnwrapper>] internal partial interface INativeObject : IUnmanagedInterfaceType { - private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 5); + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 6); static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation { get @@ -59,11 +59,41 @@ void MultiplyWithData( int numValues); } - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_native_object")] - public static partial NativeObject NewNativeObject(); + [UnmanagedObjectUnwrapper>] + internal partial interface INativeObjectStateful : IUnmanagedInterfaceType + { + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObjectStateful), sizeof(void*) * 6); + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + if (s_vtable[0] == null) + { + Native.PopulateUnmanagedVirtualMethodTable(s_vtable); + } + return s_vtable; + } + } - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "delete_native_object")] - public static partial void DeleteNativeObject(void* obj); + [VirtualMethodIndex(3, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)] + void SumAndSetData( + [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))] + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + [VirtualMethodIndex(4, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)] + void SumAndSetData( + [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))] + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + + [VirtualMethodIndex(5, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)] + void MultiplyWithData( + [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))] + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1), In, Out] IntWrapper[] values123, + int numValues); + } [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_native_object_data")] public static partial void SetNativeObjectData(void* obj, int data); @@ -168,31 +198,121 @@ public unsafe void ValidateArrayElementsByRefFreed_Stateless() } [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")] public unsafe void ValidateArrayElementsByValueOutFreed_Stateless() { const int startingValue = 13; ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); - void* wrapper = VTableGCHandlePair.Allocate(impl); + void* wrapper = VTableGCHandlePair.Allocate(impl); try { var values = new int[] { 1, 32, 63, 124, 255 }; + var expected = values.Select(x => x * startingValue).ToArray(); - int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length); - Assert.Equal(freeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + //Assert.Equal(expected, values); + + Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); } finally { - VTableGCHandlePair.Free(wrapper); + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsAndOutParameterNotFreed_Stateful() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, values, values.Length, out int _); + + // We shouldn't free the elements, but we always free the stateful marshaller. + Assert.Equal(elementFreeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsByRefFreed_Stateful() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.Ref.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, ref values, values.Length, out int _); + + Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.Ref.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")] + public unsafe void ValidateArrayElementsByValueOutFreed_Stateful() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + var expected = values.Select(x => x * startingValue).ToArray(); + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length); + + Assert.Equal(expected, values); + + Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); } } - sealed class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject + sealed unsafe class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject, NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObjectStateful { private IntWrapper _data; @@ -218,8 +338,16 @@ public void SumAndSetData(IntWrapper[] values, int numValues, out IntWrapper old oldValue = _data; _data = new() { i = value }; } - } + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + Assert.Fail("The VirtualMethodTableManagedImplementation property should not be called on implementing class types"); + return null; + } + } + } [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerToIntWithFreeCounts))] public static unsafe class IntWrapperMarshallerToIntWithFreeCounts @@ -242,5 +370,102 @@ public static void Free(int _) NumCallsToFree++; } } + + [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), MarshalMode.UnmanagedToManagedIn, typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>.In))] + [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), MarshalMode.UnmanagedToManagedRef, typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>.Ref))] + [ContiguousCollectionMarshaller] + public unsafe static class StatefulUnmanagedToManagedCollectionMarshaller + where TUnmanaged : unmanaged + { + public struct In + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + private TUnmanaged* _unmanaged; + private TManaged[] _managed; + + public void FromUnmanaged(TUnmanaged* unmanaged) + { + _unmanaged = unmanaged; + } + + public Span GetManagedValuesDestination(int numElements) + { + return _managed = new TManaged[numElements]; + } + + public ReadOnlySpan GetUnmanagedValuesSource(int numElements) + { + return new(_unmanaged, numElements); + } + + public TManaged[] ToManaged() + { + return _managed; + } + + public void Free() + { + NumCallsToFree++; + } + } + + public struct Ref + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + private TUnmanaged* _originalUnmanaged; + private TUnmanaged* _unmanaged; + private TManaged[] _managed; + + public void FromUnmanaged(TUnmanaged* unmanaged) + { + _originalUnmanaged = unmanaged; + } + + public Span GetManagedValuesDestination(int numElements) + { + return _managed = new TManaged[numElements]; + } + + public ReadOnlySpan GetUnmanagedValuesSource(int numElements) + { + return new(_originalUnmanaged, numElements); + } + + public TManaged[] ToManaged() + { + return _managed; + } + + public void Free() + { + Marshal.FreeCoTaskMem((nint)_originalUnmanaged); + NumCallsToFree++; + } + + public void FromManaged(TManaged[] managed) + { + _managed = managed; + } + + public TUnmanaged* ToUnmanaged() + { + return _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length); + } + + public ReadOnlySpan GetManagedValuesSource() + { + return _managed; + } + + public Span GetUnmanagedValuesDestination() + { + return new(_unmanaged, _managed.Length); + } + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index b435fa7951a40e..a8e6458fb0ab5a 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -210,7 +210,7 @@ public static void SumAndSetDataWithRef([DNNE.C99Type("struct INativeObject*")] [UnmanagedCallersOnly(EntryPoint = "multiply_with_native_object_data")] [DNNE.C99DeclCode("struct INativeObject;")] - public static void SumAndSetDataWithRef([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* values, int numValues) + public static void MultiplyWithData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* values, int numValues) { obj->VTable->multiplyWithData(obj, values, numValues); } From 32ff4473c230c1c7355c0e302cde490ad7281605 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 22 May 2023 14:50:45 -0700 Subject: [PATCH 8/9] Fix test failures and remove duplicate test runs. --- .../AttributedMarshallingModelGeneratorFactory.cs | 5 ++++- .../tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 3a90412e039231..ee9809770583e3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -341,7 +341,10 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); } - marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy); + if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) + { + marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy); + } } else { diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs index 89c1e84fd5ab80..628db426187198 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs @@ -327,7 +327,6 @@ public static IEnumerable CustomCollectionsManagedToUnmanaged(Generato [MemberData(nameof(UnmanagedToManagedCodeSnippetsToCompile), GeneratorKind.VTableIndexStubGenerator)] [MemberData(nameof(CustomCollectionsManagedToUnmanaged), GeneratorKind.VTableIndexStubGenerator)] [MemberData(nameof(CustomCollections), GeneratorKind.VTableIndexStubGenerator)] - [MemberData(nameof(CustomCollections), GeneratorKind.VTableIndexStubGenerator)] public async Task ValidateVTableIndexSnippets(string id, string source) { _ = id; From 14c023b4a6ff053f4460165af8fdafd58f7e62c1 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 22 May 2023 14:53:18 -0700 Subject: [PATCH 9/9] Uncomment code --- .../UnmanagedToManagedCustomMarshallingTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs index e1e8617b890577..224ea68c41477f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -216,7 +216,7 @@ public unsafe void ValidateArrayElementsByValueOutFreed_Stateless() NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length); - //Assert.Equal(expected, values); + Assert.Equal(expected, values); Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); }