Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.CodeDom.Compiler;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
using Microsoft.Interop.Analyzers;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Microsoft.Interop.SyntaxFactoryExtensions;

namespace Microsoft.Interop
{
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
private const string ClassInfoTypeName = "ComClassInformation";

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all types with the [GeneratedComClassAttribute] attribute.
var attributedClasses = context.SyntaxProvider
IncrementalValuesProvider<ComClassInfo> attributedClasses = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComClassAttribute,
static (node, ct) => node is ClassDeclarationSyntax,
static (context, _) =>
{
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
var compilation = context.SemanticModel.Compilation;
var unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
Compilation compilation = context.SemanticModel.Compilation;
bool unsafeCodeIsEnabled = compilation.Options is CSharpCompilationOptions { AllowUnsafe: true };
INamedTypeSymbol? generatedComInterfaceAttributeType = compilation.GetBestTypeByMetadataName(TypeNames.GeneratedComInterfaceAttribute);

// Currently all reported diagnostics are fatal to the generator
Expand All @@ -43,169 +42,61 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
})
.Where(static info => info is not null);

var classInfoType = attributedClasses
.Select(static (info, ct) => new ItemAndSyntaxes<ComClassInfo>(info,
[
GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace(),
GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace()
]));

context.RegisterSourceOutput(classInfoType, static (context, data) =>
context.RegisterSourceOutput(attributedClasses, static (context, data) =>
{
var className = data.Context.ClassName;
var classInfoType = data[0];
var attribute = data[1];
string className = data.ClassName;
SequenceEqualImmutableArray<string> implementedInterfaces = data.ImplementedInterfacesNames;

StringWriter writer = new();
using StringWriter sw = new();
using IndentedTextWriter writer = new(sw);
writer.WriteLine("// <auto-generated />");
writer.WriteLine(classInfoType.ToFullString());
writer.WriteLine();
writer.WriteLine(attribute);
// Replace < and > with { and } to make valid hint names for generic types
string hintName = className.Replace('<', '{').Replace('>', '}');
context.AddSource(hintName, writer.ToString());
});
}

private const string ClassInfoTypeName = "ComClassInformation";

private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
Attribute(
GenericName(TypeNames.GlobalAlias + TypeNames.ComExposedClassAttribute)
.AddTypeArgumentListArguments(
IdentifierName(ClassInfoTypeName)));
private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
.WithModifiers(classSyntax.Modifiers)
.WithTypeParameterList(classSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
{
const string vtablesField = "s_vtables";
const string vtablesLocal = "vtables";
const string detailsTempLocal = "details";
const string countIdentifier = "count";
var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
.AddModifiers(
Token(SyntaxKind.FileKeyword),
Token(SyntaxKind.SealedKeyword),
Token(SyntaxKind.UnsafeKeyword))
.AddBaseListTypes(SimpleBaseType(TypeSyntaxes.IComExposedClass))
.AddMembers(
FieldDeclaration(
VariableDeclaration(
PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
SingletonSeparatedList(VariableDeclarator(vtablesField))))
.AddModifiers(
Token(SyntaxKind.PrivateKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.VolatileKeyword)));
List<StatementSyntax> vtableInitializationBlock = new()
{
// ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
Declare(
PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
vtablesLocal,
CastExpression(
PointerType(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
MethodInvocation(
TypeSyntaxes.System_Runtime_CompilerServices_RuntimeHelpers,
IdentifierName("AllocateTypeAssociatedMemory"),
Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
Argument(
BinaryExpression(
SyntaxKind.MultiplyExpression,
SizeOfExpression(TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(implementedInterfaces.Length))))))),

// IIUnknownDerivedDetails details;
Declare(TypeSyntaxes.IIUnknownDerivedDetails, detailsTempLocal, initializeToDefault: false)
};
for (int i = 0; i < implementedInterfaces.Length; i++)
{
string ifaceName = implementedInterfaces[i];
writer.WriteLine($"file sealed unsafe class {ClassInfoTypeName} : global::System.Runtime.InteropServices.Marshalling.IComExposedClass");
writer.WriteLine('{');
writer.Indent++;
writer.WriteLine("private static volatile global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables;");
sw.WriteLine();
writer.WriteLine("public static global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count)");
writer.WriteLine('{');
writer.Indent++;
writer.WriteLine($"count = {implementedInterfaces.Length};");
writer.WriteLine("if (s_vtables == null)");
writer.WriteLine('{');
writer.Indent++;
writer.WriteLine($"global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({ClassInfoTypeName}), sizeof(global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * {implementedInterfaces.Length});");
writer.WriteLine("global::System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details;");
sw.WriteLine();
for (int i = 0; i < implementedInterfaces.Length; i++)
{
string ifaceName = implementedInterfaces[i];

// details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
vtableInitializationBlock.Add(
AssignmentStatement(
IdentifierName(detailsTempLocal),
MethodInvocation(
TypeSyntaxes.StrategyBasedComWrappers
.Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
IdentifierName("GetIUnknownDerivedDetails"),
Argument(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
TypeOfExpression(ParseName(ifaceName)),
IdentifierName("TypeHandle"))))));
// vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
vtableInitializationBlock.Add(
AssignmentStatement(
IndexExpression(
IdentifierName(vtablesLocal),
Argument(IntLiteral(i))),
ImplicitObjectCreationExpression(
ArgumentList(),
InitializerExpression(SyntaxKind.ObjectInitializerExpression,
SeparatedList(
new ExpressionSyntax[]
{
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("IID"),
IdentifierName(detailsTempLocal)
.Dot(IdentifierName("Iid"))),
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("Vtable"),
CastExpression(
IdentifierName("nint"),
IdentifierName(detailsTempLocal)
.Dot(IdentifierName("ManagedVirtualMethodTable"))))
})))));
}
writer.WriteLine($"details = global::System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof({ifaceName}).TypeHandle);");
writer.WriteLine($"vtables[{i}] = new() {{ IID = details.Iid, Vtable = (nint)details.ManagedVirtualMethodTable }};");
sw.WriteLine();
}

// s_vtable = vtable;
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(vtablesField),
IdentifierName(vtablesLocal))));
writer.WriteLine("s_vtables = vtables;");
writer.Indent--;
writer.WriteLine('}');
sw.WriteLine();
writer.WriteLine("return s_vtables;");
writer.Indent--;
writer.WriteLine('}');
writer.Indent--;
writer.WriteLine('}');

BlockSyntax getComInterfaceEntriesMethodBody = Block(
// count = <count>;
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(countIdentifier),
LiteralExpression(SyntaxKind.NumericLiteralExpression,
Literal(implementedInterfaces.Length)))),
// if (s_vtable == null)
// { initializer block }
IfStatement(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(vtablesField),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(vtableInitializationBlock)),
// return s_vtable;
ReturnStatement(IdentifierName(vtablesField)));
sw.WriteLine();

typeDeclaration = typeDeclaration.AddMembers(
// public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
// { body }
MethodDeclaration(
PointerType(
TypeSyntaxes.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry),
"GetComInterfaceEntries")
.AddParameterListParameters(
Parameter(Identifier(countIdentifier))
.WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
.AddModifiers(Token(SyntaxKind.OutKeyword)))
.WithBody(getComInterfaceEntriesMethodBody)
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));
data.ContainingSyntaxContext.WriteToWithUnsafeModifier(writer, data.ClassSyntax, static (writer, classSyntax) =>
{
writer.WriteLine($"[global::System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute<{ClassInfoTypeName}>]");
writer.WriteLine($"{string.Join(" ", classSyntax.Modifiers)} class {classSyntax.Identifier}{classSyntax.TypeParameters} {{ }}");
});

return typeDeclaration;
// Replace < and > with { and } to make valid hint names for generic types
string hintName = className.Replace('<', '{').Replace('>', '}');
context.AddSource(hintName, sw.ToString());
});
}
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Microsoft.Interop
{
internal sealed record ComClassInfo
internal sealed class ComClassInfo : IEquatable<ComClassInfo>
{
public string ClassName { get; init; }
public ContainingSyntaxContext ContainingSyntaxContext { get; init; }
Expand Down Expand Up @@ -54,6 +55,11 @@ public bool Equals(ComClassInfo? other)
&& ImplementedInterfacesNames.SequenceEqual(other.ImplementedInterfacesNames);
}

public override bool Equals(object obj)
{
return Equals(obj as ComClassInfo);
}

public override int GetHashCode()
{
return HashCode.Combine(ClassName, ContainingSyntaxContext, ImplementedInterfacesNames);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.CodeDom.Compiler;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
Expand Down Expand Up @@ -149,5 +150,53 @@ public MemberDeclarationSyntax WrapMembersInContainingSyntaxWithUnsafeModifier(p
}
return wrappedMember;
}

public void WriteToWithUnsafeModifier<TState>(IndentedTextWriter writer, TState writeMembersState, Action<IndentedTextWriter, TState> writeMembers)
{
if (ContainingNamespace is not null)
{
writer.WriteLine($"namespace {ContainingNamespace}");
writer.WriteLine('{');
writer.Indent++;
}

// When creating syntax we walk from most nested type to least nested and then enclose this chain in a namespace.
// With string writing things are exactly opposite: we are starting with a namespace and then print headers of types
// from least nested to most nested one. Since syntax model was the original one we have containing syntaxes stored as
// most convenient for it, so for string writing we have to walk them in the reverse order. When we eventually port
// our source generation to string writing we should reverse the order of elements for the convenience of that model instead.
for (int i = ContainingSyntax.Length - 1; i >= 0; i--)
{
ContainingSyntax syntax = ContainingSyntax[i];

string declarationKeyword = syntax.TypeKind switch
{
SyntaxKind.ClassDeclaration => "class",
SyntaxKind.StructDeclaration => "struct",
SyntaxKind.InterfaceDeclaration => "interface",
SyntaxKind.RecordDeclaration => "record",
SyntaxKind.RecordStructDeclaration => "record struct",
_ => throw new UnreachableException(),
};

writer.WriteLine($"{string.Join(" ", syntax.Modifiers.AddToModifiers(SyntaxKind.UnsafeKeyword))} {declarationKeyword} {syntax.Identifier}{syntax.TypeParameters}");
writer.WriteLine('{');
writer.Indent++;
}

writeMembers(writer, writeMembersState);

for (int i = 0; i < ContainingSyntax.Length; i++)
{
writer.Indent--;
writer.WriteLine('}');
}

if (ContainingNamespace is not null)
{
writer.Indent--;
writer.WriteLine('}');
}
}
}
}
Loading
Loading