diff --git a/Source/Mockolate.Analyzers/Mockolate.Analyzers.csproj b/Source/Mockolate.Analyzers/Mockolate.Analyzers.csproj
index 7165d003..ae7c3f2c 100644
--- a/Source/Mockolate.Analyzers/Mockolate.Analyzers.csproj
+++ b/Source/Mockolate.Analyzers/Mockolate.Analyzers.csproj
@@ -14,6 +14,7 @@
+
@@ -40,4 +41,8 @@
+
+
+
+
diff --git a/Source/Mockolate.Analyzers/Polyfills/NotNullWhenAttribute.cs b/Source/Mockolate.Analyzers/Polyfills/NotNullWhenAttribute.cs
new file mode 100644
index 00000000..a39569f5
--- /dev/null
+++ b/Source/Mockolate.Analyzers/Polyfills/NotNullWhenAttribute.cs
@@ -0,0 +1,71 @@
+#region License
+// MIT License
+//
+// Copyright (c) Manuel Römer
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+#endregion
+
+#if !NULLABLE_ATTRIBUTES_DISABLE
+#nullable enable
+#pragma warning disable
+
+namespace System.Diagnostics.CodeAnalysis
+{
+ using global::System;
+
+#if DEBUG
+ ///
+ /// Specifies that when a method returns ,
+ /// the parameter will not be even if the corresponding type allows it.
+ ///
+#endif
+ [AttributeUsage(AttributeTargets.Parameter, Inherited = false)]
+#if !NULLABLE_ATTRIBUTES_INCLUDE_IN_CODE_COVERAGE
+ [ExcludeFromCodeCoverage, DebuggerNonUserCode]
+#endif
+ internal sealed class NotNullWhenAttribute : Attribute
+ {
+#if DEBUG
+ ///
+ /// Gets the return value condition.
+ /// If the method returns this value, the associated parameter will not be .
+ ///
+#endif
+ public bool ReturnValue { get; }
+
+#if DEBUG
+ ///
+ /// Initializes the attribute with the specified return value condition.
+ ///
+ ///
+ /// The return value condition.
+ /// If the method returns this value, the associated parameter will not be .
+ ///
+#endif
+ public NotNullWhenAttribute(bool returnValue)
+ {
+ ReturnValue = returnValue;
+ }
+ }
+}
+
+#pragma warning restore
+#nullable restore
+#endif // NULLABLE_ATTRIBUTES_DISABLE
diff --git a/Source/Mockolate.SourceGenerators/Entities/Method.cs b/Source/Mockolate.SourceGenerators/Entities/Method.cs
index a400f278..6b169e38 100644
--- a/Source/Mockolate.SourceGenerators/Entities/Method.cs
+++ b/Source/Mockolate.SourceGenerators/Entities/Method.cs
@@ -131,13 +131,8 @@ public bool Equals(Method? x, Method? y)
}
// Compare parameters ignoring nullability annotations
- MethodParameter[]? xParams = x.Parameters.AsArray();
- MethodParameter[]? yParams = y.Parameters.AsArray();
-
- if (xParams is null || yParams is null)
- {
- return xParams is null && yParams is null;
- }
+ MethodParameter[] xParams = x.Parameters.AsArray()!;
+ MethodParameter[] yParams = y.Parameters.AsArray()!;
for (int i = 0; i < xParams.Length; i++)
{
diff --git a/Source/Mockolate.SourceGenerators/Mockolate.SourceGenerators.csproj b/Source/Mockolate.SourceGenerators/Mockolate.SourceGenerators.csproj
index b2d403c3..454c4e27 100644
--- a/Source/Mockolate.SourceGenerators/Mockolate.SourceGenerators.csproj
+++ b/Source/Mockolate.SourceGenerators/Mockolate.SourceGenerators.csproj
@@ -11,6 +11,10 @@
S3776
+
+
+
+
@@ -20,4 +24,8 @@
+
+
+
+
diff --git a/Source/Mockolate.SourceGenerators/Polyfills/NotNullWhenAttribute.cs b/Source/Mockolate.SourceGenerators/Polyfills/NotNullWhenAttribute.cs
new file mode 100644
index 00000000..a39569f5
--- /dev/null
+++ b/Source/Mockolate.SourceGenerators/Polyfills/NotNullWhenAttribute.cs
@@ -0,0 +1,71 @@
+#region License
+// MIT License
+//
+// Copyright (c) Manuel Römer
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+#endregion
+
+#if !NULLABLE_ATTRIBUTES_DISABLE
+#nullable enable
+#pragma warning disable
+
+namespace System.Diagnostics.CodeAnalysis
+{
+ using global::System;
+
+#if DEBUG
+ ///
+ /// Specifies that when a method returns ,
+ /// the parameter will not be even if the corresponding type allows it.
+ ///
+#endif
+ [AttributeUsage(AttributeTargets.Parameter, Inherited = false)]
+#if !NULLABLE_ATTRIBUTES_INCLUDE_IN_CODE_COVERAGE
+ [ExcludeFromCodeCoverage, DebuggerNonUserCode]
+#endif
+ internal sealed class NotNullWhenAttribute : Attribute
+ {
+#if DEBUG
+ ///
+ /// Gets the return value condition.
+ /// If the method returns this value, the associated parameter will not be .
+ ///
+#endif
+ public bool ReturnValue { get; }
+
+#if DEBUG
+ ///
+ /// Initializes the attribute with the specified return value condition.
+ ///
+ ///
+ /// The return value condition.
+ /// If the method returns this value, the associated parameter will not be .
+ ///
+#endif
+ public NotNullWhenAttribute(bool returnValue)
+ {
+ ReturnValue = returnValue;
+ }
+ }
+}
+
+#pragma warning restore
+#nullable restore
+#endif // NULLABLE_ATTRIBUTES_DISABLE
diff --git a/Tests/Mockolate.Analyzers.Tests/AnalyzerHelpersTests.cs b/Tests/Mockolate.Analyzers.Tests/AnalyzerHelpersTests.cs
new file mode 100644
index 00000000..48a8ed7d
--- /dev/null
+++ b/Tests/Mockolate.Analyzers.Tests/AnalyzerHelpersTests.cs
@@ -0,0 +1,123 @@
+using System.Linq;
+using System.Threading.Tasks;
+using aweXpect;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Xunit;
+using static aweXpect.Expect;
+
+namespace Mockolate.Analyzers.Tests;
+
+public class AnalyzerHelpersTests
+{
+ [Fact]
+ public async Task WhenInvokedMethodIsNotGeneric_ShouldNotReturnAnyTypeArgument()
+ {
+ const string source = """
+ public class C
+ {
+ public void Foo() { }
+ public void Bar() { Foo(); }
+ }
+ """;
+ IMethodSymbol method = GetInvokedMethod(source, "Foo");
+
+ ITypeSymbol? result = AnalyzerHelpers.GetSingleInvocationTypeArgumentOrNull(method);
+
+ await That(result).IsNull();
+ }
+
+ [Fact]
+ public async Task WhenInvokedMethodIsGeneric_ShouldReturnFirstTypeArgument()
+ {
+ const string source = """
+ public class C
+ {
+ public T Foo() => default!;
+ public void Bar() { Foo(); }
+ }
+ """;
+ IMethodSymbol method = GetInvokedMethod(source, "Foo");
+
+ ITypeSymbol? result = AnalyzerHelpers.GetSingleInvocationTypeArgumentOrNull(method);
+
+ await That(result).IsNotNull();
+ await That(result!.SpecialType).IsEqualTo(SpecialType.System_Int32);
+ }
+
+ [Fact]
+ public async Task WhenSyntaxIsNotInvocationExpression_ShouldNotReturnAnyLocation()
+ {
+ const string source = """
+ public class C
+ {
+ public int Foo() => 0;
+ }
+ """;
+ SyntaxTree tree = CSharpSyntaxTree.ParseText(source);
+ CSharpCompilation compilation = CreateCompilation(tree);
+ SemanticModel model = compilation.GetSemanticModel(tree);
+ MethodDeclarationSyntax declaration = tree.GetRoot().DescendantNodes()
+ .OfType()
+ .Single();
+ IMethodSymbol symbol = (IMethodSymbol)model.GetDeclaredSymbol(declaration)!;
+
+ Location? result = AnalyzerHelpers.GetTypeArgumentLocation(declaration, symbol.ReturnType);
+
+ await That(result).IsNull();
+ }
+
+ [Fact]
+ public async Task WhenInvocationHasGenericNameSyntax_ShouldReturnTypeArgumentLocation()
+ {
+ const string source = """
+ public static class S
+ {
+ public static T Make() => default!;
+ }
+
+ public class C
+ {
+ public void Foo() { S.Make(); }
+ }
+ """;
+ SyntaxTree tree = CSharpSyntaxTree.ParseText(source);
+ CSharpCompilation compilation = CreateCompilation(tree);
+ SemanticModel model = compilation.GetSemanticModel(tree);
+ InvocationExpressionSyntax invocation = tree.GetRoot().DescendantNodes()
+ .OfType()
+ .Single(i => i.Expression is MemberAccessExpressionSyntax { Name: GenericNameSyntax, });
+ IMethodSymbol method = (IMethodSymbol)model.GetSymbolInfo(invocation).Symbol!;
+ ITypeSymbol typeArgument = method.TypeArguments[0];
+
+ Location? result = AnalyzerHelpers.GetTypeArgumentLocation(invocation, typeArgument);
+
+ await That(result).IsNotNull();
+ }
+
+ private static IMethodSymbol GetInvokedMethod(string source, string methodName)
+ {
+ SyntaxTree tree = CSharpSyntaxTree.ParseText(source);
+ CSharpCompilation compilation = CreateCompilation(tree);
+ SemanticModel model = compilation.GetSemanticModel(tree);
+ InvocationExpressionSyntax invocation = tree.GetRoot().DescendantNodes()
+ .OfType()
+ .Single(i => InvocationName(i) == methodName);
+ return (IMethodSymbol)model.GetSymbolInfo(invocation).Symbol!;
+ }
+
+ private static string? InvocationName(InvocationExpressionSyntax invocation) => invocation.Expression switch
+ {
+ IdentifierNameSyntax id => id.Identifier.Text,
+ GenericNameSyntax generic => generic.Identifier.Text,
+ MemberAccessExpressionSyntax member => member.Name.Identifier.Text,
+ _ => null,
+ };
+
+ private static CSharpCompilation CreateCompilation(SyntaxTree tree) => CSharpCompilation.Create(
+ "TestAssembly",
+ [tree,],
+ [MetadataReference.CreateFromFile(typeof(object).Assembly.Location),],
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+}
diff --git a/Tests/Mockolate.SourceGenerators.Tests/Entities/MethodEqualityComparerTests.cs b/Tests/Mockolate.SourceGenerators.Tests/Entities/MethodEqualityComparerTests.cs
new file mode 100644
index 00000000..b7bbb438
--- /dev/null
+++ b/Tests/Mockolate.SourceGenerators.Tests/Entities/MethodEqualityComparerTests.cs
@@ -0,0 +1,83 @@
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Mockolate.SourceGenerators.Entities;
+
+namespace Mockolate.SourceGenerators.Tests.Entities;
+
+public class MethodEqualityComparerTests
+{
+ [Fact]
+ public async Task WhenBothMethodsAreNull_ShouldReturnTrue()
+ {
+ IEqualityComparer comparer = Method.ContainingTypeIndependentEqualityComparer;
+
+ bool result = comparer.Equals(null, null);
+
+ await That(result).IsTrue();
+ }
+
+ [Fact]
+ public async Task WhenLeftMethodIsNull_ShouldReturnFalse()
+ {
+ IEqualityComparer comparer = Method.ContainingTypeIndependentEqualityComparer;
+ Method right = CreateMethod("public class C { public void Foo() {} }", "Foo");
+
+ bool result = comparer.Equals(null, right);
+
+ await That(result).IsFalse();
+ }
+
+ [Fact]
+ public async Task WhenRightMethodIsNull_ShouldReturnFalse()
+ {
+ IEqualityComparer comparer = Method.ContainingTypeIndependentEqualityComparer;
+ Method left = CreateMethod("public class C { public void Foo() {} }", "Foo");
+
+ bool result = comparer.Equals(left, null);
+
+ await That(result).IsFalse();
+ }
+
+ [Fact]
+ public async Task WhenMethodsHaveDifferentNames_ShouldReturnFalse()
+ {
+ IEqualityComparer comparer = Method.ContainingTypeIndependentEqualityComparer;
+ Method left = CreateMethod("public class C { public void Foo() {} }", "Foo");
+ Method right = CreateMethod("public class C { public void Bar() {} }", "Bar");
+
+ bool result = comparer.Equals(left, right);
+
+ await That(result).IsFalse();
+ }
+
+ [Fact]
+ public async Task WhenMethodsHaveDifferentParameterCount_ShouldReturnFalse()
+ {
+ IEqualityComparer comparer = Method.ContainingTypeIndependentEqualityComparer;
+ Method left = CreateMethod("public class C { public void Foo() {} }", "Foo");
+ Method right = CreateMethod("public class C { public void Foo(int x) {} }", "Foo");
+
+ bool result = comparer.Equals(left, right);
+
+ await That(result).IsFalse();
+ }
+
+ private static Method CreateMethod(string source, string methodName)
+ {
+ SyntaxTree tree = CSharpSyntaxTree.ParseText(source);
+ CSharpCompilation compilation = CSharpCompilation.Create(
+ "TestAssembly",
+ [tree,],
+ [MetadataReference.CreateFromFile(typeof(object).Assembly.Location),],
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+ SemanticModel model = compilation.GetSemanticModel(tree);
+ MethodDeclarationSyntax declaration = tree.GetRoot().DescendantNodes()
+ .OfType()
+ .First(m => m.Identifier.Text == methodName);
+ IMethodSymbol symbol = (IMethodSymbol)model.GetDeclaredSymbol(declaration)!;
+ return new Method(symbol, null);
+ }
+}