diff --git a/src/Core/gen/Eventuous.Shared.Generators/Constants.cs b/src/Core/gen/Eventuous.Shared.Generators/Constants.cs index 42b3fd4d..2c8db157 100644 --- a/src/Core/gen/Eventuous.Shared.Generators/Constants.cs +++ b/src/Core/gen/Eventuous.Shared.Generators/Constants.cs @@ -3,8 +3,20 @@ namespace Eventuous.Shared.Generators; +/// +/// Constants used for type and member lookups. +/// These are primarily used for symbol resolution via Compilation.GetTypeByMetadataName() +/// and as fallback when symbol-based comparison is not available. +/// The generators now prefer symbol-based comparisons using SymbolEqualityComparer, +/// which are refactoring-safe and won't break when types are renamed. +/// internal static class Constants { - public const string BaseNamespace = "Eventuous"; + /// Base namespace for Eventuous types. + public const string BaseNamespace = "Eventuous"; + + /// Name of the EventType attribute class (without namespace). public const string EventTypeAttribute = "EventTypeAttribute"; - public const string EventTypeAttrFqcn = $"{BaseNamespace}.{EventTypeAttribute}"; + + /// Fully qualified name of the EventType attribute for GetTypeByMetadataName(). + public const string EventTypeAttrFqcn = $"{BaseNamespace}.{EventTypeAttribute}"; } diff --git a/src/Core/gen/Eventuous.Shared.Generators/EventUsageAnalyzer.cs b/src/Core/gen/Eventuous.Shared.Generators/EventUsageAnalyzer.cs index 02c22dea..ae5e16b3 100644 --- a/src/Core/gen/Eventuous.Shared.Generators/EventUsageAnalyzer.cs +++ b/src/Core/gen/Eventuous.Shared.Generators/EventUsageAnalyzer.cs @@ -32,11 +32,43 @@ public override void Initialize(AnalysisContext context) { context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); context.EnableConcurrentExecution(); - context.RegisterOperationAction(AnalyzeInvocation, OperationKind.Invocation); - context.RegisterOperationAction(AnalyzeObjectCreation, OperationKind.ObjectCreation); + context.RegisterCompilationStartAction(compilationContext => { + // Resolve well-known type symbols once per compilation + var compilation = compilationContext.Compilation; + var knownTypes = new KnownTypeSymbols(compilation); + + compilationContext.RegisterOperationAction(ctx => AnalyzeInvocation(ctx, knownTypes), OperationKind.Invocation); + compilationContext.RegisterOperationAction(ctx => AnalyzeObjectCreation(ctx, knownTypes), OperationKind.ObjectCreation); + }); } - static ImmutableHashSet GetExplicitRegistrations(OperationAnalysisContext ctx) { + /// + /// Cache of well-known type symbols resolved from the compilation. + /// This makes the analyzer refactoring-safe by using symbol comparison instead of string matching. + /// + sealed class KnownTypeSymbols { + public INamedTypeSymbol? EventTypeAttribute { get; } + public INamedTypeSymbol? TypeMapper { get; } + public INamedTypeSymbol? Aggregate { get; } + public INamedTypeSymbol? State { get; } + public INamedTypeSymbol? CommandHandlerBuilder { get; } + public INamedTypeSymbol? IDefineExecution { get; } + public INamedTypeSymbol? ICommandHandlerBuilder { get; } + public INamedTypeSymbol? IDefineStoreOrExecution { get; } + + public KnownTypeSymbols(Compilation compilation) { + EventTypeAttribute = compilation.GetTypeByMetadataName(EventTypeAttrFqcn); + TypeMapper = compilation.GetTypeByMetadataName($"{BaseNamespace}.TypeMapper"); + Aggregate = compilation.GetTypeByMetadataName($"{BaseNamespace}.Aggregate`1"); + State = compilation.GetTypeByMetadataName($"{BaseNamespace}.State`1"); + CommandHandlerBuilder = compilation.GetTypeByMetadataName($"{BaseNamespace}.CommandHandlerBuilder"); + IDefineExecution = compilation.GetTypeByMetadataName($"{BaseNamespace}.IDefineExecution"); + ICommandHandlerBuilder = compilation.GetTypeByMetadataName($"{BaseNamespace}.ICommandHandlerBuilder"); + IDefineStoreOrExecution = compilation.GetTypeByMetadataName($"{BaseNamespace}.IDefineStoreOrExecution"); + } + } + + static ImmutableHashSet GetExplicitRegistrations(OperationAnalysisContext ctx, KnownTypeSymbols knownTypes) { var model = ctx.Operation.SemanticModel; if (model == null) return ImmutableHashSet.Empty; var root = ctx.Operation.Syntax.SyntaxTree.GetRoot(); @@ -45,12 +77,18 @@ static ImmutableHashSet GetExplicitRegistrations(OperationAnalysisC foreach (var invSyntax in root.DescendantNodes().OfType()) { if (model.GetOperation(invSyntax) is not IInvocationOperation op) continue; var m = op.TargetMethod; + + // Use symbol comparison when available, fall back to string comparison if (m.Name != "AddType") continue; var ct = m.ContainingType; if (ct == null) continue; - if (ct.Name != "TypeMapper") continue; - var ns = ct.ContainingNamespace?.ToDisplayString(); - if (ns != BaseNamespace) continue; + + // Prefer symbol comparison (refactoring-safe) + var isTypeMapper = knownTypes.TypeMapper != null + ? SymbolEqualityComparer.Default.Equals(ct, knownTypes.TypeMapper) + : ct.Name == "TypeMapper" && ct.ContainingNamespace?.ToDisplayString() == BaseNamespace; + + if (!isTypeMapper) continue; if (m.TypeArguments.Length == 1) { set.Add(m.TypeArguments[0]); @@ -65,12 +103,12 @@ static ImmutableHashSet GetExplicitRegistrations(OperationAnalysisC return set.ToImmutable(); } - static bool IsExplicitlyRegistered(ITypeSymbol type, OperationAnalysisContext ctx) { - var set = GetExplicitRegistrations(ctx); + static bool IsExplicitlyRegistered(ITypeSymbol type, OperationAnalysisContext ctx, KnownTypeSymbols knownTypes) { + var set = GetExplicitRegistrations(ctx, knownTypes); return set.Contains(type); } - static void AnalyzeInvocation(OperationAnalysisContext ctx) { + static void AnalyzeInvocation(OperationAnalysisContext ctx, KnownTypeSymbols knownTypes) { if (ctx.Operation is not IInvocationOperation inv) return; var method = inv.TargetMethod; @@ -80,10 +118,10 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { case { Name: "Apply", TypeArguments.Length: 1, Parameters.Length: 1 }: { var containing = method.ContainingType; - if (IsAggregate(containing)) { + if (IsAggregate(containing, knownTypes)) { var eventType = method.TypeArguments[0]; - if (IsConcreteEvent(eventType) && !HasEventTypeAttribute(eventType) && !IsExplicitlyRegistered(eventType, ctx)) { + if (IsConcreteEvent(eventType) && !HasEventTypeAttribute(eventType, knownTypes) && !IsExplicitlyRegistered(eventType, ctx, knownTypes)) { ctx.ReportDiagnostic(Diagnostic.Create(MissingEventTypeAttribute, inv.Syntax.GetLocation(), eventType.ToDisplayString())); } } @@ -91,7 +129,7 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { return; } // Case 1b: State.When(...) invocations where an event instance is passed - case { Name: "When", Parameters.Length: 1 } when IsState(method.ContainingType): { + case { Name: "When", Parameters.Length: 1 } when IsState(method.ContainingType, knownTypes): { var arg = inv.Arguments.Length > 0 ? inv.Arguments[0].Value : null; ITypeSymbol? eventType = null; @@ -105,7 +143,7 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { _ => arg?.Type }; - if (eventType != null && IsConcreteEvent(eventType) && !HasEventTypeAttribute(eventType) && !IsExplicitlyRegistered(eventType, ctx)) { + if (eventType != null && IsConcreteEvent(eventType) && !HasEventTypeAttribute(eventType, knownTypes) && !IsExplicitlyRegistered(eventType, ctx, knownTypes)) { var location = arg?.Syntax.GetLocation() ?? inv.Syntax.GetLocation(); ctx.ReportDiagnostic(Diagnostic.Create(MissingEventTypeAttribute, location, eventType.ToDisplayString())); } @@ -113,10 +151,10 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { return; } // Case 1c: State.On(...) handler registrations - case { Name: "On", TypeArguments.Length: 1 } when IsState(method.ContainingType): { + case { Name: "On", TypeArguments.Length: 1 } when IsState(method.ContainingType, knownTypes): { var eventType = method.TypeArguments[0]; - if (IsConcreteEvent(eventType) && !HasEventTypeAttribute(eventType) && !IsExplicitlyRegistered(eventType, ctx)) { + if (IsConcreteEvent(eventType) && !HasEventTypeAttribute(eventType, knownTypes) && !IsExplicitlyRegistered(eventType, ctx, knownTypes)) { ctx.ReportDiagnostic(Diagnostic.Create(MissingEventTypeAttribute, inv.Syntax.GetLocation(), eventType.ToDisplayString())); } @@ -127,7 +165,7 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { // Case 2: Functional service: Act/ActAsync handlers if (method.Name is "Act" or "ActAsync") { // Heuristic: only consider the overloads that accept a delegate and are defined in CommandHandlerBuilder interfaces/classes - if (!IsFunctionalServiceAct(method)) return; + if (!IsFunctionalServiceAct(method, knownTypes)) return; foreach (var value in inv.Arguments.Select(arg => arg.Value)) { switch (value) { @@ -135,11 +173,11 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { continue; // If the argument is a lambda, analyze its body for created event instances case IAnonymousFunctionOperation lambda: - AnalyzeDelegateBodyForEventCreations(ctx, lambda.Body); + AnalyzeDelegateBodyForEventCreations(ctx, lambda.Body, knownTypes); break; case IConversionOperation { Operand: IAnonymousFunctionOperation lambdaConv }: - AnalyzeDelegateBodyForEventCreations(ctx, lambdaConv.Body); + AnalyzeDelegateBodyForEventCreations(ctx, lambdaConv.Body, knownTypes); break; } @@ -147,21 +185,21 @@ static void AnalyzeInvocation(OperationAnalysisContext ctx) { } } - static void AnalyzeDelegateBodyForEventCreations(OperationAnalysisContext ctx, IBlockOperation? body) { + static void AnalyzeDelegateBodyForEventCreations(OperationAnalysisContext ctx, IBlockOperation? body, KnownTypeSymbols knownTypes) { if (body is null) return; foreach (var op in body.Descendants()) { if (op is IObjectCreationOperation create) { var created = create.Type; - if (created != null && IsConcreteEvent(created) && !HasEventTypeAttribute(created) && !IsExplicitlyRegistered(created, ctx)) { + if (created != null && IsConcreteEvent(created) && !HasEventTypeAttribute(created, knownTypes) && !IsExplicitlyRegistered(created, ctx, knownTypes)) { ctx.ReportDiagnostic(Diagnostic.Create(MissingEventTypeAttribute, create.Syntax.GetLocation(), created.ToDisplayString())); } } } } - static void AnalyzeObjectCreation(OperationAnalysisContext ctx) { + static void AnalyzeObjectCreation(OperationAnalysisContext ctx, KnownTypeSymbols knownTypes) { // Global safety net for method groups passed into Act where we couldn't traverse the body via the invocation site. // If the object creation is within a method that appears to be an Act handler (returns NewEvents/ IEnumerable), warn. if (ctx.Operation is not IObjectCreationOperation create) return; @@ -175,7 +213,7 @@ static void AnalyzeObjectCreation(OperationAnalysisContext ctx) { if (method == null) return; if (ReturnsNewEvents(method)) { - if (!HasEventTypeAttribute(created) && !IsExplicitlyRegistered(created, ctx)) { + if (!HasEventTypeAttribute(created, knownTypes) && !IsExplicitlyRegistered(created, ctx, knownTypes)) { ctx.ReportDiagnostic(Diagnostic.Create(MissingEventTypeAttribute, create.Syntax.GetLocation(), created.ToDisplayString())); } } @@ -218,29 +256,51 @@ static bool IsIEnumerableOfObject(INamedTypeSymbol type) { return false; } - static bool IsAggregate(INamedTypeSymbol? type) { + static bool IsAggregate(INamedTypeSymbol? type, KnownTypeSymbols knownTypes) { if (type == null) return false; // Walk base types to check if it derives from Eventuous.Aggregate<> for (var t = type; t != null; t = t.BaseType) { - if (t is { Name: "Aggregate", Arity: 1 } && t.ContainingNamespace.ToDisplayString() == BaseNamespace) return true; + // Prefer symbol comparison (refactoring-safe) + if (knownTypes.Aggregate != null) { + if (SymbolEqualityComparer.Default.Equals(t.OriginalDefinition, knownTypes.Aggregate)) { + return true; + } + } + else { + // Fallback to string comparison + if (t is { Name: "Aggregate", Arity: 1 } && t.ContainingNamespace.ToDisplayString() == BaseNamespace) { + return true; + } + } } return false; } - static bool IsState(INamedTypeSymbol? type) { + static bool IsState(INamedTypeSymbol? type, KnownTypeSymbols knownTypes) { if (type == null) return false; // Walk base types to check if it derives from Eventuous.State<> for (var t = type; t != null; t = t.BaseType) { - if (t is { Name: "State", Arity: 1 } && t.ContainingNamespace.ToDisplayString() == BaseNamespace) return true; + // Prefer symbol comparison (refactoring-safe) + if (knownTypes.State != null) { + if (SymbolEqualityComparer.Default.Equals(t.OriginalDefinition, knownTypes.State)) { + return true; + } + } + else { + // Fallback to string comparison + if (t is { Name: "State", Arity: 1 } && t.ContainingNamespace.ToDisplayString() == BaseNamespace) { + return true; + } + } } return false; } - static bool IsFunctionalServiceAct(IMethodSymbol method) { + static bool IsFunctionalServiceAct(IMethodSymbol method, KnownTypeSymbols knownTypes) { // We only care about the Act methods from CommandHandlerBuilder and the related interfaces in Eventuous namespace if (method.Name is not ("Act" or "ActAsync")) return false; @@ -248,19 +308,34 @@ static bool IsFunctionalServiceAct(IMethodSymbol method) { if (containing == null) return false; - var ns = containing.ContainingNamespace?.ToDisplayString(); + // Prefer symbol comparison (refactoring-safe) + if (knownTypes.CommandHandlerBuilder != null || knownTypes.IDefineExecution != null || + knownTypes.ICommandHandlerBuilder != null || knownTypes.IDefineStoreOrExecution != null) { + return SymbolEqualityComparer.Default.Equals(containing, knownTypes.CommandHandlerBuilder) || + SymbolEqualityComparer.Default.Equals(containing, knownTypes.IDefineExecution) || + SymbolEqualityComparer.Default.Equals(containing, knownTypes.ICommandHandlerBuilder) || + SymbolEqualityComparer.Default.Equals(containing, knownTypes.IDefineStoreOrExecution); + } + // Fallback to string comparison + var ns = containing.ContainingNamespace?.ToDisplayString(); if (ns != BaseNamespace) return false; - // Simple name checks return containing.Name is "CommandHandlerBuilder" or "IDefineExecution" or "ICommandHandlerBuilder" or "IDefineStoreOrExecution"; } static bool IsConcreteEvent(ITypeSymbol type) => type.TypeKind is TypeKind.Class or TypeKind.Struct; - static bool HasEventTypeAttribute(ITypeSymbol type) - => (from attrClass in type.GetAttributes().Select(a => a.AttributeClass).OfType() - let name = attrClass.ToDisplayString() - where name == EventTypeAttrFqcn || attrClass.Name is EventTypeAttribute - select attrClass).Any(); + static bool HasEventTypeAttribute(ITypeSymbol type, KnownTypeSymbols knownTypes) { + // Prefer symbol comparison (refactoring-safe) + if (knownTypes.EventTypeAttribute != null) { + return type.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, knownTypes.EventTypeAttribute)); + } + + // Fallback to string comparison + return (from attrClass in type.GetAttributes().Select(a => a.AttributeClass).OfType() + let name = attrClass.ToDisplayString() + where name == EventTypeAttrFqcn || attrClass.Name is EventTypeAttribute + select attrClass).Any(); + } } diff --git a/src/Core/gen/Eventuous.Shared.Generators/TypeMappingsGenerator.cs b/src/Core/gen/Eventuous.Shared.Generators/TypeMappingsGenerator.cs index d68b13e8..b9d52b4c 100644 --- a/src/Core/gen/Eventuous.Shared.Generators/TypeMappingsGenerator.cs +++ b/src/Core/gen/Eventuous.Shared.Generators/TypeMappingsGenerator.cs @@ -17,14 +17,24 @@ namespace Eventuous.Shared.Generators; [Generator(LanguageNames.CSharp)] public sealed class TypeMappingsGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { + // Resolve the EventTypeAttribute symbol from the compilation + var eventTypeAttributeSymbol = context.CompilationProvider + .Select(static (c, _) => c.GetTypeByMetadataName(EventTypeAttrFqcn)); + var syntaxCandidates = context.SyntaxProvider .CreateSyntaxProvider(IsCandidate, Transform) .Where(static t => t is not null) + .Combine(eventTypeAttributeSymbol) + .Select(static (pair, _) => pair.Left.HasValue ? TransformWithSymbol(pair.Left.Value, pair.Right) : null) + .Where(static t => t is not null) .Select(static (t, _) => t!) .Collect(); // Additionally, discover [EventType] on symbols from referenced assemblies (metadata) via the Compilation model - var symbolCandidates = context.CompilationProvider.Select(static (c, _) => DiscoverFromCompilation(c)); + var symbolCandidates = eventTypeAttributeSymbol + .Select(static (symbol, _) => (Symbol: symbol, Compilation: (Compilation?)null)) + .Combine(context.CompilationProvider) + .Select(static (pair, _) => DiscoverFromCompilation(pair.Right, pair.Left.Symbol)); var mergedCandidates = syntaxCandidates .Combine(symbolCandidates) @@ -46,7 +56,17 @@ sealed record Mapping { public string EventTypeName { get; set; } = null!; } - static Mapping? Transform(GeneratorSyntaxContext ctx, CancellationToken _) { + readonly struct TransformInput { + public GeneratorSyntaxContext Context { get; } + public INamedTypeSymbol? Symbol { get; } + + public TransformInput(GeneratorSyntaxContext context, INamedTypeSymbol? symbol) { + Context = context; + Symbol = symbol; + } + } + + static TransformInput? Transform(GeneratorSyntaxContext ctx, CancellationToken _) { // Get the declared symbol if (ctx.Node is not TypeDeclarationSyntax tds) return null; @@ -55,13 +75,20 @@ sealed record Mapping { // Only concrete classes/records if (symbol?.TypeKind is not (TypeKind.Class or TypeKind.Struct)) return null; - // Look for EventTypeAttribute - var attr = GetEventTypeAttribute(symbol); + return new TransformInput(ctx, symbol); + } + + static Mapping? TransformWithSymbol(TransformInput input, INamedTypeSymbol? eventTypeAttributeSymbol) { + var symbol = input.Symbol; + if (symbol is null) return null; + + // Look for EventTypeAttribute using symbol comparison + var attr = GetEventTypeAttribute(symbol, eventTypeAttributeSymbol); if (attr is null) return null; // Try to get the constructor argument (event type name) - var evtName = TryGetEventTypeName(attr) ?? string.Empty; + var evtName = TryGetEventTypeName(attr, eventTypeAttributeSymbol) ?? string.Empty; // Use fully-qualified global:: name for the type var typeName = MakeGlobal(symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); @@ -69,22 +96,33 @@ sealed record Mapping { return new() { FullyQualifiedType = typeName, EventTypeName = evtName }; } - static AttributeData? GetEventTypeAttribute(ISymbol symbol) { - // ReSharper disable once ForeachCanBeConvertedToQueryUsingAnotherGetEnumerator + static AttributeData? GetEventTypeAttribute(ISymbol symbol, INamedTypeSymbol? eventTypeAttributeSymbol) { + // If we have the resolved symbol, use symbol comparison (refactoring-safe) + if (eventTypeAttributeSymbol is not null) { + foreach (var a in symbol.GetAttributes()) { + if (SymbolEqualityComparer.Default.Equals(a.AttributeClass, eventTypeAttributeSymbol)) { + return a; + } + } + } + + // Fallback to string-based comparison if symbol resolution failed + // This handles cases where EventTypeAttribute is in a different assembly not yet compiled foreach (var a in symbol.GetAttributes()) { var attrClass = a.AttributeClass; - if (attrClass == null) continue; - var name = attrClass.ToDisplayString(); - - if (name == EventTypeAttrFqcn || attrClass.Name is EventTypeAttribute) return a; + var fullName = attrClass.ToDisplayString(); + if (fullName == EventTypeAttrFqcn || + (attrClass.Name == EventTypeAttribute && attrClass.ContainingNamespace?.ToDisplayString() == BaseNamespace)) { + return a; + } } return null; } - static string? TryGetEventTypeName(AttributeData attr) { + static string? TryGetEventTypeName(AttributeData attr, INamedTypeSymbol? eventTypeAttributeSymbol) { // Prefer the first constructor argument if it is a constant string if (attr.ConstructorArguments.Length > 0) { var arg = attr.ConstructorArguments[0]; @@ -93,6 +131,23 @@ sealed record Mapping { } // Also check named argument "EventType" + // Note: NamedArguments uses string keys, not symbols, so we still use string comparison here + // However, we verify the property exists on the attribute type when we have the symbol + if (eventTypeAttributeSymbol is not null) { + var hasProperty = eventTypeAttributeSymbol.GetMembers(EventTypeAttribute) + .OfType() + .Any(); + + if (hasProperty) { + foreach (var kv in attr.NamedArguments) { + if (kv.Key == EventTypeAttribute && kv.Value.Value is string s) { + return s; + } + } + } + } + + // Fallback to string-based comparison when symbol is not available foreach (var kv in attr.NamedArguments) { if (kv is { Key: EventTypeAttribute, Value.Value: string s }) return s; } @@ -100,7 +155,7 @@ sealed record Mapping { return null; } - static ImmutableArray DiscoverFromCompilation(Compilation compilation) { + static ImmutableArray DiscoverFromCompilation(Compilation compilation, INamedTypeSymbol? eventTypeAttributeSymbol) { var builder = ImmutableArray.CreateBuilder(); // Current assembly @@ -114,10 +169,10 @@ static ImmutableArray DiscoverFromCompilation(Compilation compilation) return builder.ToImmutable(); void ProcessType(INamedTypeSymbol type) { - var attr = GetEventTypeAttribute(type); + var attr = GetEventTypeAttribute(type, eventTypeAttributeSymbol); if (attr is not null) { - var evtName = TryGetEventTypeName(attr) ?? string.Empty; + var evtName = TryGetEventTypeName(attr, eventTypeAttributeSymbol) ?? string.Empty; var typeName = MakeGlobal(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); builder.Add(new() { FullyQualifiedType = typeName, EventTypeName = evtName }); } diff --git a/src/Core/gen/Eventuous.Subscriptions.Generators/ConsumeContextConverterGenerator.cs b/src/Core/gen/Eventuous.Subscriptions.Generators/ConsumeContextConverterGenerator.cs index 194d9a0f..1e7ce90c 100644 --- a/src/Core/gen/Eventuous.Subscriptions.Generators/ConsumeContextConverterGenerator.cs +++ b/src/Core/gen/Eventuous.Subscriptions.Generators/ConsumeContextConverterGenerator.cs @@ -12,11 +12,19 @@ namespace Eventuous.Subscriptions.Generators; public sealed class ConsumeContextConverterGenerator : IIncrementalGenerator { const string InterfaceNamespace = "Eventuous.Subscriptions.Context"; const string InterfaceName = "IMessageConsumeContext"; + const string InterfaceFqn = $"{InterfaceNamespace}.{InterfaceName}`1"; public void Initialize(IncrementalGeneratorInitializationContext context) { + // Resolve the IMessageConsumeContext<> symbol from the compilation + var messageConsumeContextSymbol = context.CompilationProvider + .Select(static (c, _) => c.GetTypeByMetadataName(InterfaceFqn)); + var candidateTypes = context.SyntaxProvider .CreateSyntaxProvider(IsPotentialUsage, Transform) .Where(static t => t is not null) + .Combine(messageConsumeContextSymbol) + .Select(static (pair, _) => TransformWithSymbol(pair.Left, pair.Right)) + .Where(static t => t is not null) .Select(static (t, _) => t!) .Collect(); @@ -35,16 +43,23 @@ static bool IsPotentialUsage(SyntaxNode node, CancellationToken _) { }; } - static string? Transform(GeneratorSyntaxContext ctx, CancellationToken _) { + static GeneratorSyntaxContext? Transform(GeneratorSyntaxContext ctx, CancellationToken _) { + // Just return the context for further processing + return ctx; + } + + static string? TransformWithSymbol(GeneratorSyntaxContext? ctx, INamedTypeSymbol? messageConsumeContextSymbol) { + if (ctx is not { } context) return null; + // Explicit generic type usage: IMessageConsumeContext - if (ctx.Node is GenericNameSyntax g) { + if (context.Node is GenericNameSyntax g) { // Case 1: explicit IMessageConsumeContext - var symbol = ctx.SemanticModel.GetSymbolInfo(g).Symbol as INamedTypeSymbol - ?? ctx.SemanticModel.GetTypeInfo(g).Type as INamedTypeSymbol; + var symbol = context.SemanticModel.GetSymbolInfo(g).Symbol as INamedTypeSymbol + ?? context.SemanticModel.GetTypeInfo(g).Type as INamedTypeSymbol; if (symbol != null) { var def = symbol.OriginalDefinition; - if (IsTargetInterface(def) && symbol.TypeArguments.Length == 1) { + if (IsTargetInterface(def, messageConsumeContextSymbol) && symbol.TypeArguments.Length == 1) { var arg = symbol.TypeArguments[0]; return GetTypeSyntax(arg); } @@ -55,7 +70,7 @@ static bool IsPotentialUsage(SyntaxNode node, CancellationToken _) { // Try to get T from the generic method symbol On(...) var inv = g.Parent as InvocationExpressionSyntax ?? g.Parent?.Parent as InvocationExpressionSyntax; if (inv != null) { - var symbolInfo = ctx.SemanticModel.GetSymbolInfo(inv).Symbol; + var symbolInfo = context.SemanticModel.GetSymbolInfo(inv).Symbol; var method = symbolInfo as IMethodSymbol; if (method?.TypeArguments.Length == 1 && ShouldTreatGenericOnAsEvent(method)) { var tArg = method.TypeArguments[0]; @@ -67,13 +82,13 @@ static bool IsPotentialUsage(SyntaxNode node, CancellationToken _) { } // Qualified explicit usage: Namespace.IMessageConsumeContext - if (ctx.Node is QualifiedNameSyntax { Right: GenericNameSyntax g2 }) { - var symbol = ctx.SemanticModel.GetSymbolInfo(g2).Symbol as INamedTypeSymbol - ?? ctx.SemanticModel.GetTypeInfo(g2).Type as INamedTypeSymbol; + if (context.Node is QualifiedNameSyntax { Right: GenericNameSyntax g2 }) { + var symbol = context.SemanticModel.GetSymbolInfo(g2).Symbol as INamedTypeSymbol + ?? context.SemanticModel.GetTypeInfo(g2).Type as INamedTypeSymbol; if (symbol != null) { var def = symbol.OriginalDefinition; - if (IsTargetInterface(def) && symbol.TypeArguments.Length == 1) { + if (IsTargetInterface(def, messageConsumeContextSymbol) && symbol.TypeArguments.Length == 1) { var arg = symbol.TypeArguments[0]; return GetTypeSyntax(arg); } @@ -81,13 +96,13 @@ static bool IsPotentialUsage(SyntaxNode node, CancellationToken _) { } // Implicit usage via lambda parameter type inference - if (ctx.Node is LambdaExpressionSyntax lambda) { - var typeInfo = ctx.SemanticModel.GetTypeInfo(lambda); + if (context.Node is LambdaExpressionSyntax lambda) { + var typeInfo = context.SemanticModel.GetTypeInfo(lambda); var delegateType = typeInfo.ConvertedType as INamedTypeSymbol; var invoke = delegateType?.DelegateInvokeMethod; if (invoke is not null) { foreach (var p in invoke.Parameters) { - if (TryExtractTypeArgFromIMessageConsumeContext(p.Type, out var typeArg)) { + if (TryExtractTypeArgFromIMessageConsumeContext(p.Type, messageConsumeContextSymbol, out var typeArg)) { return GetTypeSyntax(typeArg); } } @@ -103,8 +118,16 @@ static string GetTypeSyntax(ITypeSymbol symbol) { return name.StartsWith("global::", StringComparison.Ordinal) ? name : $"global::{name}"; } - static bool IsTargetInterface(INamedTypeSymbol def) => - def is { Arity: 1, Name: InterfaceName } && def.ContainingNamespace?.ToDisplayString() == InterfaceNamespace; + static bool IsTargetInterface(INamedTypeSymbol def, INamedTypeSymbol? messageConsumeContextSymbol) { + // Prefer symbol comparison (refactoring-safe) + if (messageConsumeContextSymbol is not null) { + return SymbolEqualityComparer.Default.Equals(def, messageConsumeContextSymbol); + } + + // Fallback to string-based comparison + return def is { Arity: 1, Name: InterfaceName } && + def.ContainingNamespace?.ToDisplayString() == InterfaceNamespace; + } static bool ShouldTreatGenericOnAsEvent(IMethodSymbol method) { if (method is not { Name: "On" }) return false; @@ -116,10 +139,13 @@ static bool ShouldTreatGenericOnAsEvent(IMethodSymbol method) { return paramName.IndexOf("Event", StringComparison.OrdinalIgnoreCase) >= 0; } - static bool TryExtractTypeArgFromIMessageConsumeContext(ITypeSymbol type, out ITypeSymbol typeArg) { + static bool TryExtractTypeArgFromIMessageConsumeContext( + ITypeSymbol type, + INamedTypeSymbol? messageConsumeContextSymbol, + out ITypeSymbol typeArg) { if (type is INamedTypeSymbol named) { var def = named.OriginalDefinition; - if (IsTargetInterface(def) && named.TypeArguments.Length == 1) { + if (IsTargetInterface(def, messageConsumeContextSymbol) && named.TypeArguments.Length == 1) { typeArg = named.TypeArguments[0]; return true; }