diff --git a/src/MSBuildLocator/MSBuildLocator.cs b/src/MSBuildLocator/MSBuildLocator.cs index 0ed30ef6..dc35aa89 100644 --- a/src/MSBuildLocator/MSBuildLocator.cs +++ b/src/MSBuildLocator/MSBuildLocator.cs @@ -18,6 +18,7 @@ namespace Microsoft.Build.Locator public static class MSBuildLocator { private const string MSBuildPublicKeyToken = "b03f5f7f11d50a3a"; + private const string NuGetPublicKeyToken = "31bf3856ad364e35"; private static readonly string[] s_msBuildAssemblies = { @@ -26,6 +27,13 @@ public static class MSBuildLocator "Microsoft.Build.Tasks.Core", "Microsoft.Build.Utilities.Core" }; + private static readonly string[] s_nuGetAssemblyNames = { + "NuGet.Common", + "NuGet.Frameworks", + "NuGet.Packaging", + "NuGet.ProjectModel", + "NuGet.Versioning" + }; #if NET46 private static ResolveEventHandler s_registeredHandler; @@ -209,7 +217,7 @@ Assembly TryLoadAssembly(AssemblyName assemblyName) if (File.Exists(targetAssembly)) { // Automatically unregister the handler once all supported assemblies have been loaded. - if (Interlocked.Increment(ref numResolvedAssemblies) == s_msBuildAssemblies.Length) + if (Interlocked.Increment(ref numResolvedAssemblies) == (s_msBuildAssemblies.Length + s_nuGetAssemblyNames.Length)) { Unregister(); } @@ -287,11 +295,20 @@ private static void ApplyDotNetSdkEnvironmentVariables(string dotNetSdkPath) private static bool IsMSBuildAssembly(AssemblyName assemblyName) { - if (!s_msBuildAssemblies.Contains(assemblyName.Name, StringComparer.OrdinalIgnoreCase)) + if (s_msBuildAssemblies.Contains(assemblyName.Name, StringComparer.OrdinalIgnoreCase)) { - return false; + return HasPublicKeyToken(assemblyName, MSBuildPublicKeyToken); + } + if (s_nuGetAssemblyNames.Contains(assemblyName.Name, StringComparer.OrdinalIgnoreCase)) + { + return HasPublicKeyToken(assemblyName, NuGetPublicKeyToken); } + return false; + } + + private static bool HasPublicKeyToken(AssemblyName assemblyName, string expectedPublicKeyToken) + { var publicKeyToken = assemblyName.GetPublicKeyToken(); if (publicKeyToken == null || publicKeyToken.Length == 0) { @@ -304,7 +321,7 @@ private static bool IsMSBuildAssembly(AssemblyName assemblyName) sb.Append($"{b:x2}"); } - return sb.ToString().Equals(MSBuildPublicKeyToken, StringComparison.OrdinalIgnoreCase); + return sb.ToString().Equals(expectedPublicKeyToken, StringComparison.OrdinalIgnoreCase); } private static IEnumerable GetInstances(VisualStudioInstanceQueryOptions options)