diff --git a/src/libraries/System.Private.CoreLib/src/System/Diagnostics/CodeAnalysis/RequiresUnsafeAttribute.cs b/src/libraries/System.Private.CoreLib/src/System/Diagnostics/CodeAnalysis/RequiresUnsafeAttribute.cs index 6ecf050781400a..f4641cc0e0eb0f 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Diagnostics/CodeAnalysis/RequiresUnsafeAttribute.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Diagnostics/CodeAnalysis/RequiresUnsafeAttribute.cs @@ -13,7 +13,7 @@ namespace System.Diagnostics.CodeAnalysis /// [AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Property, Inherited = false)] [Conditional("DEBUG")] - internal sealed class RequiresUnsafeAttribute : Attribute + public sealed class RequiresUnsafeAttribute : Attribute { } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs index aee232c0d9d7b0..7642bc9d49b5a1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/LibraryImportGenerator.cs @@ -267,6 +267,15 @@ private static IncrementalStubGenerationContext CalculateStubInformation( var methodSyntaxTemplate = new ContainingSyntax(originalSyntax.Modifiers, SyntaxKind.MethodDeclaration, originalSyntax.Identifier, originalSyntax.TypeParameterList); + // If [RequiresUnsafe] is available, set the flag so it can be added to the stub later. + // Don't add if the user's declaration already has it (to avoid duplicate attribute error). + var environmentFlags = environment.EnvironmentFlags; + if (environment.RequiresUnsafeAttrType is not null + && !symbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, environment.RequiresUnsafeAttrType))) + { + environmentFlags |= EnvironmentFlags.RequiresUnsafeAvailable; + } + List additionalAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute, defaultDllImportSearchPathsAttribute, wasmImportLinkageAttribute, stackTraceHiddenAttribute); return new IncrementalStubGenerationContext( signatureContext, @@ -276,7 +285,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation( new SequenceEqualImmutableArray(additionalAttributes.ToImmutableArray(), SyntaxEquivalentComparer.Instance), LibraryImportData.From(libraryImportData), options, - environment.EnvironmentFlags); + environmentFlags); } private static MemberDeclarationSyntax GenerateSource( @@ -330,7 +339,17 @@ private static MemberDeclarationSyntax GenerateSource( dllImport = dllImport.WithLeadingTrivia(Comment("// Local P/Invoke")); code = code.AddStatements(dllImport); - return pinvokeStub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(PrintGeneratedSource(pinvokeStub.StubMethodSyntaxTemplate, pinvokeStub.SignatureContext, code)); + var signatureContext = pinvokeStub.SignatureContext; + if (pinvokeStub.EnvironmentFlags.HasFlag(EnvironmentFlags.RequiresUnsafeAvailable)) + { + signatureContext = signatureContext with + { + AdditionalAttributes = signatureContext.AdditionalAttributes.Add( + AttributeList(SingletonSeparatedList(Attribute(NameSyntaxes.System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute)))) + }; + } + + return pinvokeStub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(PrintGeneratedSource(pinvokeStub.StubMethodSyntaxTemplate, signatureContext, code)); } private static MemberDeclarationSyntax PrintForwarderStub(ContainingSyntax userDeclaredMethod, IncrementalStubGenerationContext stub) @@ -361,6 +380,12 @@ private static MemberDeclarationSyntax PrintForwarderStub(ContainingSyntax userD SingletonSeparatedList( CreateForwarderDllImport(pinvokeData)))); + if (stub.EnvironmentFlags.HasFlag(EnvironmentFlags.RequiresUnsafeAvailable)) + { + stubMethod = stubMethod.AddAttributeLists( + AttributeList(SingletonSeparatedList(Attribute(NameSyntaxes.System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute)))); + } + MemberDeclarationSyntax toPrint = stub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(stubMethod); return toPrint; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubEnvironment.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubEnvironment.cs index 5f87229e9e5752..9c9ae1b5afcb67 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubEnvironment.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubEnvironment.cs @@ -12,6 +12,7 @@ public enum EnvironmentFlags None = 0, SkipLocalsInit = 0x1, DisableRuntimeMarshalling = 0x2, + RequiresUnsafeAvailable = 0x4, } public sealed record StubEnvironment( @@ -101,5 +102,19 @@ public INamedTypeSymbol? StackTraceHiddenAttrType return _stackTraceHiddenAttrType.Value; } } + + private Optional _requiresUnsafeAttrType; + public INamedTypeSymbol? RequiresUnsafeAttrType + { + get + { + if (_requiresUnsafeAttrType.HasValue) + { + return _requiresUnsafeAttrType.Value; + } + _requiresUnsafeAttrType = new Optional(Compilation.GetTypeByMetadataName(TypeNames.System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute)); + return _requiresUnsafeAttrType.Value; + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index f906e506bdec31..7d8db5ddc2b7eb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -44,6 +44,9 @@ public static class NameSyntaxes public static NameSyntax System_Runtime_InteropServices_StructLayoutAttribute => _System_Runtime_InteropServices_StructLayoutAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_StructLayoutAttribute); private static NameSyntax? _System_Diagnostics_StackTraceHiddenAttribute; public static NameSyntax System_Diagnostics_StackTraceHiddenAttribute => _System_Diagnostics_StackTraceHiddenAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Diagnostics_StackTraceHiddenAttribute); + + private static NameSyntax? _System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute; + public static NameSyntax System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute => _System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute); } public static class TypeSyntaxes @@ -202,6 +205,8 @@ public static class TypeNames public const string System_Diagnostics_StackTraceHiddenAttribute = "System.Diagnostics.StackTraceHiddenAttribute"; + public const string System_Diagnostics_CodeAnalysis_RequiresUnsafeAttribute = "System.Diagnostics.CodeAnalysis.RequiresUnsafeAttribute"; + public static string MarshalEx(InteropGenerationOptions options) { return options.UseMarshalType ? System_Runtime_InteropServices_Marshal : System_Runtime_InteropServices_MarshalEx; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/AdditionalAttributesOnStub.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/AdditionalAttributesOnStub.cs index ebb295f6d3fe79..c73e76dff33985 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/AdditionalAttributesOnStub.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/AdditionalAttributesOnStub.cs @@ -115,6 +115,54 @@ partial class C await VerifySourceGeneratorAsync(source, "C", "Method", typeof(System.CodeDom.Compiler.GeneratedCodeAttribute).FullName, attributeAdded: false); } + [Fact] + public async Task RequiresUnsafeAdded() + { + string source = """ + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + [assembly:DisableRuntimeMarshalling] + partial class C + { + [LibraryImportAttribute("DoesNotExist")] + public static partial S Method(); + } + + [NativeMarshalling(typeof(Marshaller))] + struct S + { + } + + struct Native + { + } + + [CustomMarshaller(typeof(S), MarshalMode.Default, typeof(Marshaller))] + static class Marshaller + { + public static Native ConvertToUnmanaged(S s) => default; + + public static S ConvertToManaged(Native n) => default; + } + """; + await VerifySourceGeneratorAsync(source, "C", "Method", "System.Diagnostics.CodeAnalysis.RequiresUnsafeAttribute", attributeAdded: true); + } + + [Fact] + public async Task RequiresUnsafeAddedOnForwardingStub() + { + string source = """ + using System.Runtime.InteropServices; + partial class C + { + [LibraryImportAttribute("DoesNotExist")] + public static partial void Method(); + } + """; + await VerifySourceGeneratorAsync(source, "C", "Method", "System.Diagnostics.CodeAnalysis.RequiresUnsafeAttribute", attributeAdded: true); + } + public static IEnumerable GetDownlevelTargetFrameworks() { yield return new object[] { TestTargetFramework.Standard2_0, false }; diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index ccc217672ebad4..88d019f7a86b96 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -9110,6 +9110,12 @@ public RequiresUnreferencedCodeAttribute(string message) { } public string Message { get { throw null; } } public string? Url { get { throw null; } set { } } } + [System.AttributeUsageAttribute(System.AttributeTargets.Constructor | System.AttributeTargets.Method | System.AttributeTargets.Property, Inherited=false)] + [System.Diagnostics.ConditionalAttribute("DEBUG")] + public sealed partial class RequiresUnsafeAttribute : System.Attribute + { + public RequiresUnsafeAttribute() { } + } [System.AttributeUsageAttribute(System.AttributeTargets.Constructor, AllowMultiple=false, Inherited=false)] public sealed partial class SetsRequiredMembersAttribute : System.Attribute {