From c2248ffe75f142fd04a1169dbd19918874ade8b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 18 Jan 2026 16:29:18 +0000 Subject: [PATCH 1/5] Initial plan From 920456f0786943fd2dffac82a26c950cd6ec134a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 18 Jan 2026 16:36:22 +0000 Subject: [PATCH 2/5] Fix SSL certificate validation for integration tests - Updated HttpClientHandler to bypass SSL certificate validation in test environment - Added ServerCertificateCustomValidationCallback to both GetHttpClient methods - Fixes SSL errors on Ubuntu when using self-signed certificates in tests Co-authored-by: Keboo <952248+Keboo@users.noreply.github.com> --- .../Helpers/TestHelper.cs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs b/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs index e6e9686..25c1e0f 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs @@ -11,14 +11,23 @@ public static HttpClient GetHttpClient(int localProxyPort, { var proxy = new TestProxy($"http://localhost:{localProxyPort}", enableBasicProxyAuthorization); - var handler = new HttpClientHandler { Proxy = proxy, UseProxy = true }; + var handler = new HttpClientHandler + { + Proxy = proxy, + UseProxy = true, + ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator + }; return new HttpClient(handler); } public static HttpClient GetHttpClient() { - return new HttpClient(new HttpClientHandler()); + var handler = new HttpClientHandler + { + ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator + }; + return new HttpClient(handler); } public class TestProxy : IWebProxy From e6b942066791de6d16dff4a0964e36eea88947b3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 18 Jan 2026 16:37:21 +0000 Subject: [PATCH 3/5] Add comments explaining SSL certificate bypass in tests - Added explanatory comments for ServerCertificateCustomValidationCallback usage - Clarifies this is safe for integration tests but not production code - Addresses code review feedback Co-authored-by: Keboo <952248+Keboo@users.noreply.github.com> --- Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs b/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs index 25c1e0f..30d8c8c 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs @@ -15,6 +15,8 @@ public static HttpClient GetHttpClient(int localProxyPort, { Proxy = proxy, UseProxy = true, + // Accept self-signed certificates generated for testing. This is safe in integration tests + // where we control both the client and server, but should not be used in production code. ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator }; @@ -25,6 +27,8 @@ public static HttpClient GetHttpClient() { var handler = new HttpClientHandler { + // Accept self-signed certificates generated for testing. This is safe in integration tests + // where we control both the client and server, but should not be used in production code. ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator }; return new HttpClient(handler); From 929d3b82c9ed962e0d2e0a8d2274715ec7c064d4 Mon Sep 17 00:00:00 2001 From: Kevin Bost Date: Mon, 19 Jan 2026 08:46:43 -0800 Subject: [PATCH 4/5] Refactors certificate management and creation Adopts consistent private field naming conventions across `CertificateManager` and certificate maker implementations. Removes `ExceptionHandler` dependency from certificate maker constructors, simplifying their design and shifting error handling. Updates default root certificate names to explicitly indicate their non-trustworthy nature, enhancing security clarity. Cleans up unit tests by leveraging simplified assertion methods and optimizing `Task.WhenAll` usage. --- Directory.Packages.props | 8 +- .../CertificateManagerTests.cs | 163 ++++++------- .../Certificates/CertificateManager.cs | 107 ++++---- .../Certificates/Makers/BCCertificateMaker.cs | 14 +- .../Makers/BCCertificateMakerFast.cs | 19 +- .../Makers/WinCertificateMaker.cs | 228 +++++++++--------- 6 files changed, 261 insertions(+), 278 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 4f011a5..cf3517e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -29,7 +29,7 @@ all - + @@ -45,9 +45,9 @@ - - - + + + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/Keboo.Web.Proxy.UnitTests/CertificateManagerTests.cs b/Keboo.Web.Proxy.UnitTests/CertificateManagerTests.cs index 69db816..d2941e8 100644 --- a/Keboo.Web.Proxy.UnitTests/CertificateManagerTests.cs +++ b/Keboo.Web.Proxy.UnitTests/CertificateManagerTests.cs @@ -1,100 +1,93 @@ -using System; -using System.Collections.Generic; using System.Diagnostics; -using System.Linq; -using System.Threading.Tasks; -using TUnit.Core; -using TUnit.Assertions; -using TUnit.Assertions.Extensions; + using Keboo.Web.Proxy.Network; -namespace Keboo.Web.Proxy.UnitTests +namespace Keboo.Web.Proxy.UnitTests; + +public class CertificateManagerTests { - public class CertificateManagerTests - { - private static readonly string[] hostNames - = { "facebook.com", "youtube.com", "google.com", "bing.com", "yahoo.com" }; + private static readonly string[] hostNames + = { "facebook.com", "youtube.com", "google.com", "bing.com", "yahoo.com" }; - [Test] - public async Task Simple_BC_Create_Certificate_Test() + [Test] + public async Task Simple_BC_Create_Certificate_Test() + { + var tasks = new List(); + + CertificateManager mgr = new (null, null, false, false, false, new Lazy(() => e => + { + Debug.WriteLine(e.ToString()); + Debug.WriteLine(e.InnerException?.ToString()); + }).Value) { - var tasks = new List(); + CertificateEngine = CertificateEngine.BouncyCastle + }; + mgr.ClearIdleCertificates(); + for (var i = 0; i < 5; i++) + tasks.AddRange(hostNames.Select(host => Task.Run(async () => + { + // get the connection + var certificate = mgr.CreateCertificate(host, false); + Assert.NotNull(certificate); + }))); + + await Task.WhenAll(tasks); + + mgr.StopClearIdleCertificates(); + } + + // uncomment this to compare WinCert maker performance with BC (BC takes more time for same test above) + //[TestMethod] + public static async Task Simple_Create_Win_Certificate_Test() + { + var tasks = new List(); + + CertificateManager mgr = new(null, null, false, false, false, new Lazy(() => e => + { + Debug.WriteLine(e.ToString()); + Debug.WriteLine(e.InnerException?.ToString()); + }).Value) + { CertificateEngine = CertificateEngine.DefaultWindows }; + + mgr.CreateRootCertificate(); + mgr.TrustRootCertificate(true); + mgr.ClearIdleCertificates(); + + for (var i = 0; i < 5; i++) + tasks.AddRange(hostNames.Select(host => Task.Run(async () => + { + // get the connection + var certificate = mgr.CreateCertificate(host, false); + Assert.NotNull(certificate); + }))); + + await Task.WhenAll(tasks.ToArray()); + mgr.RemoveTrustedRootCertificate(true); + mgr.StopClearIdleCertificates(); + } - var mgr = new CertificateManager(null, null, false, false, false, new Lazy(() => e => + [Test] + public async Task Create_Server_Certificate_Test() + { + var tasks = new List(); + + CertificateManager mgr = new(null, null, false, false, false, new Lazy(() => e => { Debug.WriteLine(e.ToString()); Debug.WriteLine(e.InnerException?.ToString()); }).Value) + { CertificateEngine = CertificateEngine.BouncyCastleFast }; + + mgr.SaveFakeCertificates = true; + + for (var i = 0; i < 500; i++) + tasks.AddRange(hostNames.Select(host => Task.Run(async () => { - CertificateEngine = CertificateEngine.BouncyCastle - }; - mgr.ClearIdleCertificates(); - for (var i = 0; i < 5; i++) - tasks.AddRange(hostNames.Select(host => Task.Run(async () => - { - // get the connection - var certificate = mgr.CreateCertificate(host, false); - await Assert.That(certificate is not null).IsTrue(); - }))); - - await Task.WhenAll(tasks.ToArray()); - - mgr.StopClearIdleCertificates(); - } - - // uncomment this to compare WinCert maker performance with BC (BC takes more time for same test above) - //[TestMethod] - public static async Task Simple_Create_Win_Certificate_Test() - { - var tasks = new List(); - - var mgr = new CertificateManager(null, null, false, false, false, new Lazy(() => e => - { - Debug.WriteLine(e.ToString()); - Debug.WriteLine(e.InnerException?.ToString()); - }).Value) - { CertificateEngine = CertificateEngine.DefaultWindows }; - - mgr.CreateRootCertificate(); - mgr.TrustRootCertificate(true); - mgr.ClearIdleCertificates(); - - for (var i = 0; i < 5; i++) - tasks.AddRange(hostNames.Select(host => Task.Run(async () => - { - // get the connection - var certificate = mgr.CreateCertificate(host, false); - await Assert.That(certificate is not null).IsTrue(); - }))); - - await Task.WhenAll(tasks.ToArray()); - mgr.RemoveTrustedRootCertificate(true); - mgr.StopClearIdleCertificates(); - } - - [Test] - public async Task Create_Server_Certificate_Test() - { - var tasks = new List(); - - var mgr = new CertificateManager(null, null, false, false, false, new Lazy(() => e => - { - Debug.WriteLine(e.ToString()); - Debug.WriteLine(e.InnerException?.ToString()); - }).Value) - { CertificateEngine = CertificateEngine.BouncyCastleFast }; - - mgr.SaveFakeCertificates = true; - - for (var i = 0; i < 500; i++) - tasks.AddRange(hostNames.Select(host => Task.Run(async () => - { - var certificate = mgr.CreateServerCertificate(host); - await Assert.That(certificate is not null).IsTrue(); - }))); - - await Task.WhenAll(tasks.ToArray()); - } + var certificate = mgr.CreateServerCertificate(host); + Assert.NotNull(certificate); + }))); + + await Task.WhenAll(tasks); } } \ No newline at end of file diff --git a/Keboo.Web.Proxy/Certificates/CertificateManager.cs b/Keboo.Web.Proxy/Certificates/CertificateManager.cs index 0d12786..cb054e9 100644 --- a/Keboo.Web.Proxy/Certificates/CertificateManager.cs +++ b/Keboo.Web.Proxy/Certificates/CertificateManager.cs @@ -1,12 +1,7 @@ -using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Diagnostics; -using System.IO; -using System.Linq; using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Tasks; + using Keboo.Web.Proxy.Helpers; using Keboo.Web.Proxy.Network.Certificate; using Keboo.Web.Proxy.Shared; @@ -39,45 +34,45 @@ public enum CertificateEngine /// public sealed class CertificateManager : IDisposable { - private const string DefaultRootCertificateIssuer = "Keboo"; + private const string DefaultRootCertificateIssuer = "DO_NOT_TRUST_Keboo.Web.Proxy"; - private const string DefaultRootRootCertificateName = "Keboo Root Certificate Authority"; + private const string DefaultRootRootCertificateName = "DO_NOT_TRUST_Keboo.Web.Proxy"; private static readonly ConcurrentDictionary _saveCertificateLocks = new(); /// /// Cache dictionary /// - private readonly ConcurrentDictionary cachedCertificates = new(); + private readonly ConcurrentDictionary _cachedCertificates = []; - private readonly CancellationTokenSource clearCertificatesTokenSource = new(); + private readonly CancellationTokenSource _clearCertificatesTokenSource = new(); /// /// Used to prevent multiple threads working on same certificate generation /// when burst certificate generation requests happen for same certificate. /// - private readonly SemaphoreSlim pendingCertificateCreationTaskLock = new(1); + private readonly SemaphoreSlim _pendingCertificateCreationTaskLock = new(1); /// /// A list of pending certificate creation tasks. /// - private readonly Dictionary> pendingCertificateCreationTasks = new(); + private readonly Dictionary> _pendingCertificateCreationTasks = []; - private readonly object rootCertCreationLock = new(); + private readonly Lock _rootCertCreationLock = new(); private ICertificateMaker? certEngineValue; - private ICertificateCache certificateCache = new DefaultCertificateDiskCache(); + private ICertificateCache _certificateCache = new DefaultCertificateDiskCache(); - private bool disposed; + private bool _disposed; private CertificateEngine engine; - private string? issuer; + private string? _issuer; - private X509Certificate2? rootCertificate; + private X509Certificate2? _rootCertificate; - private string? rootCertificateName; + private string? _rootCertificateName; /// /// Initializes a new instance of the class. @@ -120,14 +115,14 @@ private ICertificateMaker CertEngine switch (engine) { case CertificateEngine.BouncyCastle: - certEngineValue = new BcCertificateMaker(ExceptionFunc, CertificateValidDays); + certEngineValue = new BcCertificateMaker(CertificateValidDays); break; case CertificateEngine.BouncyCastleFast: - certEngineValue = new BcCertificateMakerFast(ExceptionFunc, CertificateValidDays); + certEngineValue = new BcCertificateMakerFast(CertificateValidDays); break; case CertificateEngine.DefaultWindows: default: - certEngineValue = new WinCertificateMaker(ExceptionFunc, CertificateValidDays); + certEngineValue = new WinCertificateMaker(CertificateValidDays); break; } @@ -210,8 +205,8 @@ public CertificateEngine CertificateEngine /// public string RootCertificateIssuerName { - get => issuer ?? DefaultRootCertificateIssuer; - set => issuer = value; + get => _issuer ?? DefaultRootCertificateIssuer; + set => _issuer = value; } /// @@ -223,8 +218,8 @@ public string RootCertificateIssuerName /// public string RootCertificateName { - get => rootCertificateName ?? DefaultRootRootCertificateName; - set => rootCertificateName = value; + get => _rootCertificateName ?? DefaultRootRootCertificateName; + set => _rootCertificateName = value; } /// @@ -232,11 +227,11 @@ public string RootCertificateName /// public X509Certificate2? RootCertificate { - get => rootCertificate; + get => _rootCertificate; set { ClearRootCertificate(); - rootCertificate = value; + _rootCertificate = value; } } @@ -254,8 +249,8 @@ public X509Certificate2? RootCertificate /// public ICertificateCache CertificateStorage { - get => certificateCache; - set => certificateCache = value ?? new DefaultCertificateDiskCache(); + get => _certificateCache; + set => _certificateCache = value ?? new DefaultCertificateDiskCache(); } /// @@ -422,7 +417,7 @@ private void OnException(Exception exception) try { - certificate = certificateCache.LoadCertificate(subjectName, StorageFlag); + certificate = _certificateCache.LoadCertificate(subjectName, StorageFlag); if (certificate != null && certificate.NotAfter <= DateTime.Now) { @@ -453,7 +448,7 @@ private void OnException(Exception exception) try { //no two tasks with same subject name should together enter here - certificateCache.SaveCertificate(subjectName, certificate); + _certificateCache.SaveCertificate(subjectName, certificate); } finally { @@ -491,7 +486,7 @@ private void OnException(Exception exception) public async Task CreateServerCertificate(string certificateName) { // check in cache first - if (cachedCertificates.TryGetValue(certificateName, out var cached)) + if (_cachedCertificates.TryGetValue(certificateName, out var cached)) { cached.LastAccess = DateTime.UtcNow; return cached.Certificate; @@ -499,11 +494,11 @@ private void OnException(Exception exception) var createdTask = false; Task? createCertificateTask; - await pendingCertificateCreationTaskLock.WaitAsync(); + await _pendingCertificateCreationTaskLock.WaitAsync(); try { // check in cache first - if (cachedCertificates.TryGetValue(certificateName, out cached)) + if (_cachedCertificates.TryGetValue(certificateName, out cached)) { cached.LastAccess = DateTime.UtcNow; return cached.Certificate; @@ -511,24 +506,24 @@ private void OnException(Exception exception) // handle burst requests with same certificate name // by checking for existing task for same certificate name - if (!pendingCertificateCreationTasks.TryGetValue(certificateName, out createCertificateTask)) + if (!_pendingCertificateCreationTasks.TryGetValue(certificateName, out createCertificateTask)) { // run certificate creation task & add it to pending tasks createCertificateTask = Task.Run(() => { var result = CreateCertificate(certificateName, false); - if (result != null) cachedCertificates.TryAdd(certificateName, new CachedCertificate(result)); + if (result != null) _cachedCertificates.TryAdd(certificateName, new CachedCertificate(result)); return result; }); - pendingCertificateCreationTasks[certificateName] = createCertificateTask; + _pendingCertificateCreationTasks[certificateName] = createCertificateTask; createdTask = true; } } finally { - pendingCertificateCreationTaskLock.Release(); + _pendingCertificateCreationTaskLock.Release(); } var certificate = await createCertificateTask; @@ -536,14 +531,14 @@ private void OnException(Exception exception) if (createdTask) { // cleanup pending task - await pendingCertificateCreationTaskLock.WaitAsync(); + await _pendingCertificateCreationTaskLock.WaitAsync(); try { - pendingCertificateCreationTasks.Remove(certificateName); + _pendingCertificateCreationTasks.Remove(certificateName); } finally { - pendingCertificateCreationTaskLock.Release(); + _pendingCertificateCreationTaskLock.Release(); } } @@ -555,14 +550,14 @@ private void OnException(Exception exception) /// internal async void ClearIdleCertificates() { - var cancellationToken = clearCertificatesTokenSource.Token; + var cancellationToken = _clearCertificatesTokenSource.Token; while (!cancellationToken.IsCancellationRequested) { var cutOff = DateTime.UtcNow.AddMinutes(-CertificateCacheTimeOutMinutes); - var outdated = cachedCertificates.Where(x => x.Value.LastAccess < cutOff).ToList(); + var outdated = _cachedCertificates.Where(x => x.Value.LastAccess < cutOff).ToList(); - foreach (var cache in outdated) cachedCertificates.TryRemove(cache.Key, out _); + foreach (var cache in outdated) _cachedCertificates.TryRemove(cache.Key, out _); // after a minute come back to check for outdated certificates in cache try @@ -581,7 +576,7 @@ internal async void ClearIdleCertificates() /// internal void StopClearIdleCertificates() { - clearCertificatesTokenSource.Cancel(); + _clearCertificatesTokenSource.Cancel(); } /// @@ -593,7 +588,7 @@ internal void StopClearIdleCertificates() /// public bool CreateRootCertificate(bool persistToFile = true) { - lock (rootCertCreationLock) + lock (_rootCertCreationLock) { if (persistToFile && RootCertificate == null) RootCertificate = LoadRootCertificate(); @@ -602,7 +597,7 @@ public bool CreateRootCertificate(bool persistToFile = true) if (!OverwritePfxFile) try { - var rootCert = certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, + var rootCert = _certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, X509KeyStorageFlags.Exportable); if (rootCert != null && rootCert.NotAfter <= DateTime.Now) @@ -637,14 +632,14 @@ public bool CreateRootCertificate(bool persistToFile = true) { try { - certificateCache.Clear(); + _certificateCache.Clear(); } catch (Exception e) { OnException(new Exception("An error happened when clearing certificate cache.", e)); } - certificateCache.SaveRootCertificate(PfxFilePath, PfxPassword, RootCertificate); + _certificateCache.SaveRootCertificate(PfxFilePath, PfxPassword, RootCertificate); } catch (Exception e) { @@ -664,7 +659,7 @@ public bool CreateRootCertificate(bool persistToFile = true) try { var rootCert = - certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, X509KeyStorageFlags.Exportable); + _certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, X509KeyStorageFlags.Exportable); if (rootCert != null && rootCert.NotAfter <= DateTime.Now) { @@ -937,18 +932,18 @@ public bool RemoveTrustedRootCertificateAsAdmin(bool machineTrusted = false) /// public void ClearRootCertificate() { - certificateCache.Clear(); - cachedCertificates.Clear(); - rootCertificate = null; + _certificateCache.Clear(); + _cachedCertificates.Clear(); + _rootCertificate = null; } private void Dispose(bool disposing) { - if (disposed) return; + if (_disposed) return; - if (disposing) clearCertificatesTokenSource.Dispose(); + if (disposing) _clearCertificatesTokenSource.Dispose(); - disposed = true; + _disposed = true; } ~CertificateManager() diff --git a/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMaker.cs b/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMaker.cs index 634124e..ef0ed65 100644 --- a/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMaker.cs +++ b/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMaker.cs @@ -1,7 +1,9 @@ -using System; -using System.IO; using System.Net; using System.Security.Cryptography.X509Certificates; + +using Keboo.Web.Proxy.Helpers; +using Keboo.Web.Proxy.Shared; + using Org.BouncyCastle.Asn1; using Org.BouncyCastle.Asn1.Pkcs; using Org.BouncyCastle.Asn1.X509; @@ -16,8 +18,7 @@ using Org.BouncyCastle.Security; using Org.BouncyCastle.Utilities; using Org.BouncyCastle.X509; -using Keboo.Web.Proxy.Helpers; -using Keboo.Web.Proxy.Shared; + using X509Certificate = Org.BouncyCastle.X509.X509Certificate; namespace Keboo.Web.Proxy.Network.Certificate; @@ -34,12 +35,9 @@ internal class BcCertificateMaker : ICertificateMaker private static bool _doNotSetFriendlyName; private readonly int certificateValidDays; - private readonly ExceptionHandler? exceptionFunc; - - internal BcCertificateMaker(ExceptionHandler? exceptionFunc, int certificateValidDays) + internal BcCertificateMaker(int certificateValidDays) { this.certificateValidDays = certificateValidDays; - this.exceptionFunc = exceptionFunc; } /// diff --git a/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMakerFast.cs b/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMakerFast.cs index 7140132..123c0bb 100644 --- a/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMakerFast.cs +++ b/Keboo.Web.Proxy/Certificates/Makers/BCCertificateMakerFast.cs @@ -1,7 +1,9 @@ -using System; -using System.IO; using System.Net; using System.Security.Cryptography.X509Certificates; + +using Keboo.Web.Proxy.Helpers; +using Keboo.Web.Proxy.Shared; + using Org.BouncyCastle.Asn1; using Org.BouncyCastle.Asn1.Pkcs; using Org.BouncyCastle.Asn1.X509; @@ -16,8 +18,7 @@ using Org.BouncyCastle.Security; using Org.BouncyCastle.Utilities; using Org.BouncyCastle.X509; -using Keboo.Web.Proxy.Helpers; -using Keboo.Web.Proxy.Shared; + using X509Certificate = Org.BouncyCastle.X509.X509Certificate; namespace Keboo.Web.Proxy.Network.Certificate; @@ -33,13 +34,11 @@ internal class BcCertificateMakerFast : ICertificateMaker // Set this flag to true when exception detected to avoid further exceptions private static bool _doNotSetFriendlyName; - private readonly ExceptionHandler? exceptionFunc; - private readonly int certificateValidDays; + private readonly int _certificateValidDays; - internal BcCertificateMakerFast(ExceptionHandler? exceptionFunc, int certificateValidDays) + internal BcCertificateMakerFast(int certificateValidDays) { - this.certificateValidDays = certificateValidDays; - this.exceptionFunc = exceptionFunc; + _certificateValidDays = certificateValidDays; KeyPair = GenerateKeyPair(); } @@ -222,7 +221,7 @@ private X509Certificate2 MakeCertificateInternal(string subject, bool switchToMtaIfNeeded, X509Certificate2? signingCert = null) { return MakeCertificateInternal(subject, $"CN={subject}", - DateTime.UtcNow.AddDays(-CertificateGraceDays), DateTime.UtcNow.AddDays(certificateValidDays), + DateTime.UtcNow.AddDays(-CertificateGraceDays), DateTime.UtcNow.AddDays(_certificateValidDays), signingCert); } } \ No newline at end of file diff --git a/Keboo.Web.Proxy/Certificates/Makers/WinCertificateMaker.cs b/Keboo.Web.Proxy/Certificates/Makers/WinCertificateMaker.cs index 87f586e..c685780 100644 --- a/Keboo.Web.Proxy/Certificates/Makers/WinCertificateMaker.cs +++ b/Keboo.Web.Proxy/Certificates/Makers/WinCertificateMaker.cs @@ -1,9 +1,7 @@ -using System; using System.Net; using System.Reflection; +using System.Runtime.Versioning; using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Tasks; namespace Keboo.Web.Proxy.Network.Certificate; @@ -12,68 +10,70 @@ namespace Keboo.Web.Proxy.Network.Certificate; /// Certificate Maker - uses MakeCert /// Calls COM objects using reflection /// +[SupportedOSPlatform("windows")] internal class WinCertificateMaker : ICertificateMaker { - private readonly ExceptionHandler? exceptionFunc; - private readonly string sProviderName = "Microsoft Enhanced Cryptographic Provider v1.0"; - private readonly Type? typeAltNamesCollection; + private readonly Type _typeAltNamesCollection; - private readonly Type? typeBasicConstraints; + private readonly Type _typeBasicConstraints; - private readonly Type? typeCAlternativeName; + private readonly Type _typeCAlternativeName; - private readonly Type? typeEkuExt; + private readonly Type _typeEkuExt; - private readonly Type? typeExtNames; + private readonly Type _typeExtNames; - private readonly Type? typeKuExt; + private readonly Type _typeKuExt; - private readonly Type? typeOid; + private readonly Type _typeOid; - private readonly Type? typeOids; + private readonly Type _typeOids; - private readonly Type? typeRequestCert; + private readonly Type _typeRequestCert; - private readonly Type? typeSignerCertificate; - private readonly Type? typeX500Dn; + private readonly Type _typeSignerCertificate; + private readonly Type _typeX500Dn; - private readonly Type? typeX509Enrollment; + private readonly Type _typeX509Enrollment; - private readonly Type? typeX509Extensions; + private readonly Type _typeX509Extensions; - private readonly Type? typeX509PrivateKey; + private readonly Type _typeX509PrivateKey; // Validity Days for Root Certificates Generated. - private readonly int certificateValidDays; + private readonly int _certificateValidDays; - private object? sharedPrivateKey; + private object? _sharedPrivateKey; /// /// Constructor. /// - internal WinCertificateMaker(ExceptionHandler? exceptionFunc, int certificateValidDays) + internal WinCertificateMaker(int certificateValidDays) { - this.certificateValidDays = certificateValidDays; - this.exceptionFunc = exceptionFunc; - - typeX500Dn = Type.GetTypeFromProgID("X509Enrollment.CX500DistinguishedName", true); - typeX509PrivateKey = Type.GetTypeFromProgID("X509Enrollment.CX509PrivateKey", true); - typeOid = Type.GetTypeFromProgID("X509Enrollment.CObjectId", true); - typeOids = Type.GetTypeFromProgID("X509Enrollment.CObjectIds.1", true); - typeEkuExt = Type.GetTypeFromProgID("X509Enrollment.CX509ExtensionEnhancedKeyUsage"); - typeKuExt = Type.GetTypeFromProgID("X509Enrollment.CX509ExtensionKeyUsage"); - typeRequestCert = Type.GetTypeFromProgID("X509Enrollment.CX509CertificateRequestCertificate"); - typeX509Extensions = Type.GetTypeFromProgID("X509Enrollment.CX509Extensions"); - typeBasicConstraints = Type.GetTypeFromProgID("X509Enrollment.CX509ExtensionBasicConstraints"); - typeSignerCertificate = Type.GetTypeFromProgID("X509Enrollment.CSignerCertificate"); - typeX509Enrollment = Type.GetTypeFromProgID("X509Enrollment.CX509Enrollment"); + _certificateValidDays = certificateValidDays; + + _typeX500Dn = GetType("X509Enrollment.CX500DistinguishedName", true); + _typeX509PrivateKey = GetType("X509Enrollment.CX509PrivateKey", true); + _typeOid = GetType("X509Enrollment.CObjectId", true); + _typeOids = GetType("X509Enrollment.CObjectIds.1", true); + _typeEkuExt = GetType("X509Enrollment.CX509ExtensionEnhancedKeyUsage"); + _typeKuExt = GetType("X509Enrollment.CX509ExtensionKeyUsage"); + _typeRequestCert = GetType("X509Enrollment.CX509CertificateRequestCertificate"); + _typeX509Extensions = GetType("X509Enrollment.CX509Extensions"); + _typeBasicConstraints = GetType("X509Enrollment.CX509ExtensionBasicConstraints"); + _typeSignerCertificate = GetType("X509Enrollment.CSignerCertificate"); + _typeX509Enrollment = GetType("X509Enrollment.CX509Enrollment"); // for alternative names - typeAltNamesCollection = Type.GetTypeFromProgID("X509Enrollment.CAlternativeNames"); - typeExtNames = Type.GetTypeFromProgID("X509Enrollment.CX509ExtensionAlternativeNames"); - typeCAlternativeName = Type.GetTypeFromProgID("X509Enrollment.CAlternativeName"); + _typeAltNamesCollection = GetType("X509Enrollment.CAlternativeNames"); + _typeExtNames = GetType("X509Enrollment.CX509ExtensionAlternativeNames"); + _typeCAlternativeName = GetType("X509Enrollment.CAlternativeName"); + + static Type GetType(string programId, bool throwOnError = false) + => Type.GetTypeFromProgID(programId, throwOnError) ?? + throw new InvalidOperationException($"Could not retrieve {programId}"); } @@ -108,7 +108,7 @@ private X509Certificate2 MakeCertificate(string sSubjectCn, var now = DateTime.UtcNow; var graceTime = now.AddDays(graceDays); var certificate = MakeCertificate(sSubjectCn, fullSubject, keyLength, hashAlgo, graceTime, - now.AddDays(certificateValidDays), signingCertificate); + now.AddDays(_certificateValidDays), signingCertificate); return certificate; } @@ -116,186 +116,184 @@ private X509Certificate2 MakeCertificate(string subject, string fullSubject, int privateKeyLength, string hashAlg, DateTime validFrom, DateTime validTo, X509Certificate2? signingCertificate) { - var x500CertDn = Activator.CreateInstance(typeX500Dn); - var typeValue = new object[] { fullSubject, 0 }; - typeX500Dn.InvokeMember("Encode", BindingFlags.InvokeMethod, null, x500CertDn, typeValue); + var x500CertDn = Activator.CreateInstance(_typeX500Dn); + object?[] typeValue = [fullSubject, 0]; + _typeX500Dn.InvokeMember("Encode", BindingFlags.InvokeMethod, null, x500CertDn, typeValue); - var x500RootCertDn = Activator.CreateInstance(typeX500Dn); + var x500RootCertDn = Activator.CreateInstance(_typeX500Dn); if (signingCertificate != null) typeValue[0] = signingCertificate.Subject; - typeX500Dn.InvokeMember("Encode", BindingFlags.InvokeMethod, null, x500RootCertDn, typeValue); + _typeX500Dn.InvokeMember("Encode", BindingFlags.InvokeMethod, null, x500RootCertDn, typeValue); object? sharedPrivateKey = null; - if (signingCertificate != null) sharedPrivateKey = this.sharedPrivateKey; + if (signingCertificate != null) sharedPrivateKey = this._sharedPrivateKey; if (sharedPrivateKey == null) { - sharedPrivateKey = Activator.CreateInstance(typeX509PrivateKey); - typeValue = new object[] { sProviderName }; - typeX509PrivateKey.InvokeMember("ProviderName", BindingFlags.PutDispProperty, null, sharedPrivateKey, + sharedPrivateKey = Activator.CreateInstance(_typeX509PrivateKey); + typeValue = [sProviderName]; + _typeX509PrivateKey.InvokeMember("ProviderName", BindingFlags.PutDispProperty, null, sharedPrivateKey, typeValue); typeValue[0] = 2; - typeX509PrivateKey.InvokeMember("ExportPolicy", BindingFlags.PutDispProperty, null, sharedPrivateKey, + _typeX509PrivateKey.InvokeMember("ExportPolicy", BindingFlags.PutDispProperty, null, sharedPrivateKey, typeValue); - typeValue = new object[] { signingCertificate == null ? 2 : 1 }; - typeX509PrivateKey.InvokeMember("KeySpec", BindingFlags.PutDispProperty, null, sharedPrivateKey, + typeValue = [signingCertificate == null ? 2 : 1]; + _typeX509PrivateKey.InvokeMember("KeySpec", BindingFlags.PutDispProperty, null, sharedPrivateKey, typeValue); if (signingCertificate != null) { - typeValue = new object[] { 176 }; - typeX509PrivateKey.InvokeMember("KeyUsage", BindingFlags.PutDispProperty, null, sharedPrivateKey, + typeValue = [176]; + _typeX509PrivateKey.InvokeMember("KeyUsage", BindingFlags.PutDispProperty, null, sharedPrivateKey, typeValue); } typeValue[0] = privateKeyLength; - typeX509PrivateKey.InvokeMember("Length", BindingFlags.PutDispProperty, null, sharedPrivateKey, + _typeX509PrivateKey.InvokeMember("Length", BindingFlags.PutDispProperty, null, sharedPrivateKey, typeValue); - typeX509PrivateKey.InvokeMember("Create", BindingFlags.InvokeMethod, null, sharedPrivateKey, null); + _typeX509PrivateKey.InvokeMember("Create", BindingFlags.InvokeMethod, null, sharedPrivateKey, null); - if (signingCertificate != null) this.sharedPrivateKey = sharedPrivateKey; + if (signingCertificate != null) this._sharedPrivateKey = sharedPrivateKey; } typeValue = new object[1]; - var oid = Activator.CreateInstance(typeOid); + var oid = Activator.CreateInstance(_typeOid); typeValue[0] = "1.3.6.1.5.5.7.3.1"; - typeOid.InvokeMember("InitializeFromValue", BindingFlags.InvokeMethod, null, oid, typeValue); + _typeOid.InvokeMember("InitializeFromValue", BindingFlags.InvokeMethod, null, oid, typeValue); - var oids = Activator.CreateInstance(typeOids); + var oids = Activator.CreateInstance(_typeOids); typeValue[0] = oid; - typeOids.InvokeMember("Add", BindingFlags.InvokeMethod, null, oids, typeValue); + _typeOids.InvokeMember("Add", BindingFlags.InvokeMethod, null, oids, typeValue); - var ekuExt = Activator.CreateInstance(typeEkuExt); + var ekuExt = Activator.CreateInstance(_typeEkuExt); typeValue[0] = oids; - typeEkuExt.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, ekuExt, typeValue); + _typeEkuExt.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, ekuExt, typeValue); - var requestCert = Activator.CreateInstance(typeRequestCert); + var requestCert = Activator.CreateInstance(_typeRequestCert); typeValue = new[] { 1, sharedPrivateKey, string.Empty }; - typeRequestCert.InvokeMember("InitializeFromPrivateKey", BindingFlags.InvokeMethod, null, requestCert, + _typeRequestCert.InvokeMember("InitializeFromPrivateKey", BindingFlags.InvokeMethod, null, requestCert, typeValue); typeValue = new[] { x500CertDn }; - typeRequestCert.InvokeMember("Subject", BindingFlags.PutDispProperty, null, requestCert, typeValue); + _typeRequestCert.InvokeMember("Subject", BindingFlags.PutDispProperty, null, requestCert, typeValue); typeValue[0] = x500RootCertDn; - typeRequestCert.InvokeMember("Issuer", BindingFlags.PutDispProperty, null, requestCert, typeValue); + _typeRequestCert.InvokeMember("Issuer", BindingFlags.PutDispProperty, null, requestCert, typeValue); typeValue[0] = validFrom; - typeRequestCert.InvokeMember("NotBefore", BindingFlags.PutDispProperty, null, requestCert, typeValue); + _typeRequestCert.InvokeMember("NotBefore", BindingFlags.PutDispProperty, null, requestCert, typeValue); typeValue[0] = validTo; - typeRequestCert.InvokeMember("NotAfter", BindingFlags.PutDispProperty, null, requestCert, typeValue); + _typeRequestCert.InvokeMember("NotAfter", BindingFlags.PutDispProperty, null, requestCert, typeValue); - var kuExt = Activator.CreateInstance(typeKuExt); + var kuExt = Activator.CreateInstance(_typeKuExt); typeValue[0] = 176; - typeKuExt.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, kuExt, typeValue); + _typeKuExt.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, kuExt, typeValue); var certificate = - typeRequestCert.InvokeMember("X509Extensions", BindingFlags.GetProperty, null, requestCert, null); + _typeRequestCert.InvokeMember("X509Extensions", BindingFlags.GetProperty, null, requestCert, null); typeValue = new object[1]; if (signingCertificate != null) { typeValue[0] = kuExt; - typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); + _typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); } typeValue[0] = ekuExt; - typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); + _typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); if (signingCertificate != null) { // add alternative names // https://forums.iis.net/t/1180823.aspx - var altNameCollection = Activator.CreateInstance(typeAltNamesCollection); - var extNames = Activator.CreateInstance(typeExtNames); - var altDnsNames = Activator.CreateInstance(typeCAlternativeName); + var altNameCollection = Activator.CreateInstance(_typeAltNamesCollection); + var extNames = Activator.CreateInstance(_typeExtNames); + var altDnsNames = Activator.CreateInstance(_typeCAlternativeName); - IPAddress ip; - if (IPAddress.TryParse(subject, out ip)) + if (IPAddress.TryParse(subject, out IPAddress? ip)) { var ipBase64 = Convert.ToBase64String(ip.GetAddressBytes()); - typeValue = new object[] - { AlternativeNameType.XcnCertAltNameIpAddress, EncodingType.XcnCryptStringBase64, ipBase64 }; - typeCAlternativeName.InvokeMember("InitializeFromRawData", BindingFlags.InvokeMethod, null, altDnsNames, + typeValue = [AlternativeNameType.XcnCertAltNameIpAddress, EncodingType.XcnCryptStringBase64, ipBase64]; + _typeCAlternativeName.InvokeMember("InitializeFromRawData", BindingFlags.InvokeMethod, null, altDnsNames, typeValue); } else { - typeValue = new object[] { 3, subject }; //3==DNS, 8==IP ADDR - typeCAlternativeName.InvokeMember("InitializeFromString", BindingFlags.InvokeMethod, null, altDnsNames, + typeValue = [3, subject]; //3==DNS, 8==IP ADDR + _typeCAlternativeName.InvokeMember("InitializeFromString", BindingFlags.InvokeMethod, null, altDnsNames, typeValue); } - typeValue = new[] { altDnsNames }; - typeAltNamesCollection.InvokeMember("Add", BindingFlags.InvokeMethod, null, altNameCollection, + typeValue = [altDnsNames]; + _typeAltNamesCollection.InvokeMember("Add", BindingFlags.InvokeMethod, null, altNameCollection, typeValue); - typeValue = new[] { altNameCollection }; - typeExtNames.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, extNames, typeValue); + typeValue = [altNameCollection]; + _typeExtNames.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, extNames, typeValue); typeValue[0] = extNames; - typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); + _typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); } if (signingCertificate != null) { - var signerCertificate = Activator.CreateInstance(typeSignerCertificate); + var signerCertificate = Activator.CreateInstance(_typeSignerCertificate); - typeValue = new object[] { 0, 0, 12, signingCertificate.Thumbprint }; - typeSignerCertificate.InvokeMember("Initialize", BindingFlags.InvokeMethod, null, signerCertificate, + typeValue = [0, 0, 12, signingCertificate.Thumbprint]; + _typeSignerCertificate.InvokeMember("Initialize", BindingFlags.InvokeMethod, null, signerCertificate, typeValue); - typeValue = new[] { signerCertificate }; - typeRequestCert.InvokeMember("SignerCertificate", BindingFlags.PutDispProperty, null, requestCert, + typeValue = [signerCertificate]; + _typeRequestCert.InvokeMember("SignerCertificate", BindingFlags.PutDispProperty, null, requestCert, typeValue); } else { - var basicConstraints = Activator.CreateInstance(typeBasicConstraints); + var basicConstraints = Activator.CreateInstance(_typeBasicConstraints); - typeValue = new object[] { "true", "0" }; - typeBasicConstraints.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, basicConstraints, + typeValue = ["true", "0"]; + _typeBasicConstraints.InvokeMember("InitializeEncode", BindingFlags.InvokeMethod, null, basicConstraints, typeValue); - typeValue = new[] { basicConstraints }; - typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); + typeValue = [basicConstraints]; + _typeX509Extensions.InvokeMember("Add", BindingFlags.InvokeMethod, null, certificate, typeValue); } - oid = Activator.CreateInstance(typeOid); + oid = Activator.CreateInstance(_typeOid); - typeValue = new object[] { 1, 0, 0, hashAlg }; - typeOid.InvokeMember("InitializeFromAlgorithmName", BindingFlags.InvokeMethod, null, oid, typeValue); + typeValue = [1, 0, 0, hashAlg]; + _typeOid.InvokeMember("InitializeFromAlgorithmName", BindingFlags.InvokeMethod, null, oid, typeValue); - typeValue = new[] { oid }; - typeRequestCert.InvokeMember("HashAlgorithm", BindingFlags.PutDispProperty, null, requestCert, typeValue); - typeRequestCert.InvokeMember("Encode", BindingFlags.InvokeMethod, null, requestCert, null); + typeValue = [oid]; + _typeRequestCert.InvokeMember("HashAlgorithm", BindingFlags.PutDispProperty, null, requestCert, typeValue); + _typeRequestCert.InvokeMember("Encode", BindingFlags.InvokeMethod, null, requestCert, null); - var x509Enrollment = Activator.CreateInstance(typeX509Enrollment); + var x509Enrollment = Activator.CreateInstance(_typeX509Enrollment); typeValue[0] = requestCert; - typeX509Enrollment.InvokeMember("InitializeFromRequest", BindingFlags.InvokeMethod, null, x509Enrollment, + _typeX509Enrollment.InvokeMember("InitializeFromRequest", BindingFlags.InvokeMethod, null, x509Enrollment, typeValue); - if (signingCertificate == null) + if (signingCertificate is null) { typeValue[0] = fullSubject; - typeX509Enrollment.InvokeMember("CertificateFriendlyName", BindingFlags.PutDispProperty, null, + _typeX509Enrollment.InvokeMember("CertificateFriendlyName", BindingFlags.PutDispProperty, null, x509Enrollment, typeValue); } typeValue[0] = 0; - var createCertRequest = typeX509Enrollment.InvokeMember("CreateRequest", BindingFlags.InvokeMethod, null, + var createCertRequest = _typeX509Enrollment.InvokeMember("CreateRequest", BindingFlags.InvokeMethod, null, x509Enrollment, typeValue); - typeValue = new[] { 2, createCertRequest, 0, string.Empty }; + typeValue = [2, createCertRequest, 0, string.Empty]; - typeX509Enrollment.InvokeMember("InstallResponse", BindingFlags.InvokeMethod, null, x509Enrollment, + _typeX509Enrollment.InvokeMember("InstallResponse", BindingFlags.InvokeMethod, null, x509Enrollment, typeValue); - typeValue = new object[] { null!, 0, 1 }; + typeValue = [null, 0, 1]; - var empty = (string)typeX509Enrollment.InvokeMember("CreatePFX", BindingFlags.InvokeMethod, null, - x509Enrollment, typeValue); + var empty = _typeX509Enrollment.InvokeMember("CreatePFX", BindingFlags.InvokeMethod, null, + x509Enrollment, typeValue) as string ?? throw new InvalidOperationException("Could not create PFX"); return X509CertificateLoader.LoadPkcs12(Convert.FromBase64String(empty), string.Empty, X509KeyStorageFlags.Exportable); } From 3ae842b0b598471fa3aef6239b9d52b1743f8e1b Mon Sep 17 00:00:00 2001 From: Kevin Bost Date: Tue, 3 Feb 2026 08:45:50 -0800 Subject: [PATCH 5/5] Enhances code quality with nullable reference types and updates dependencies Introduces nullable reference type annotations for improved compile-time safety, especially within integration test helpers. Updates the TUnit test framework to a newer version, bringing potential improvements to the testing infrastructure. Modernizes C# syntax in unit tests with file-scoped namespaces. Adds platform-specific attributes for Windows-dependent system proxy helper methods and refactors related code for clarity and robustness. Updates the Google.Protobuf dependency. --- Directory.Packages.props | 8 +- .../Helpers/HttpContinueClient.cs | 4 +- .../Helpers/HttpContinueServer.cs | 19 +- .../Helpers/HttpMessageParsing.cs | 6 +- .../Helpers/TestHelper.cs | 4 +- .../NestedProxyTests.cs | 2 +- .../ReverseProxyTests.cs | 2 +- .../Setup/TestProxyServer.cs | 4 +- .../Setup/TestServer.cs | 26 +- .../Setup/TestSuite.cs | 7 +- Keboo.Web.Proxy.UnitTests/ProxyServerTests.cs | 125 +++++----- Keboo.Web.Proxy.UnitTests/SystemProxyTest.cs | 225 +++++++++--------- Keboo.Web.Proxy.UnitTests/WinAuthTests.cs | 15 +- .../Helpers/NativeMethods.SystemProxy.cs | 2 + Keboo.Web.Proxy/Helpers/SystemProxy.cs | 206 ++++++++-------- 15 files changed, 325 insertions(+), 330 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index cf3517e..4b68250 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -21,7 +21,7 @@ - + @@ -45,9 +45,9 @@ - - - + + + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueClient.cs b/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueClient.cs index 718604f..01f4212 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueClient.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueClient.cs @@ -13,7 +13,7 @@ internal class HttpContinueClient private static readonly Encoding _msgEncoding = HttpHelper.GetEncodingFromContentType(null); - public static async Task Post(string server, int port, string content) + public static async Task Post(string server, int port, string content) { var message = _msgEncoding.GetBytes(content); var client = new TcpClient(server, port); @@ -29,7 +29,7 @@ public static async Task Post(string server, int port, string content) var buffer = new byte[1024]; var responseMsg = string.Empty; - Response response; + Response? response; while ((response = HttpMessageParsing.ParseResponse(responseMsg)) == null) { diff --git a/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueServer.cs b/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueServer.cs index 1772c0e..bc39ceb 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueServer.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpContinueServer.cs @@ -12,12 +12,12 @@ namespace Keboo.Web.Proxy.IntegrationTests.Helpers; internal class HttpContinueServer { private static readonly Encoding _msgEncoding = HttpHelper.GetEncodingFromContentType(null); - public HttpStatusCode ExpectationResponse; - public string ResponseBody; + public HttpStatusCode ExpectationResponse { get; set; } + public string? ResponseBody { get; set; } public async Task HandleRequest(ConnectionContext context) { - var request = await ReadHeaders(context.Transport.Input); + var request = await ReadHeaders(context.Transport.Input) ?? throw new Exception("Failed to read headers"); if (request.ExpectContinue) { @@ -37,7 +37,7 @@ public async Task HandleRequest(ConnectionContext context) request = await ReadBody(request, context.Transport.Input); - var responseMsg = _msgEncoding.GetBytes(ResponseBody); + var responseMsg = _msgEncoding.GetBytes(ResponseBody ?? ""); var respondOk = new Response(responseMsg) { HttpVersion = new Version(1, 1), @@ -49,9 +49,9 @@ public async Task HandleRequest(ConnectionContext context) context.Transport.Output.Complete(); } - private static async Task ReadHeaders(PipeReader input) + private static async Task ReadHeaders(PipeReader input) { - Request request = null; + Request? request = null; try { var requestMsg = string.Empty; @@ -74,12 +74,13 @@ private static async Task ReadHeaders(PipeReader input) return request; } - private static async Task ReadBody(Request request, PipeReader input) + private static async Task ReadBody(Request request, PipeReader input) { var msg = request.HeaderText; + Request? parsedRequest = request; try { - while ((request = HttpMessageParsing.ParseRequest(msg, true)) == null) + while ((parsedRequest = HttpMessageParsing.ParseRequest(msg, true)) is null) { var result = await input.ReadAsync(); foreach (var seg in result.Buffer) @@ -95,6 +96,6 @@ private static async Task ReadBody(Request request, PipeReader input) Console.WriteLine($"{ex.GetType()}: {ex.Message}"); } - return request; + return parsedRequest; } } diff --git a/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpMessageParsing.cs b/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpMessageParsing.cs index 3f3ec98..f5ff24c 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpMessageParsing.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Helpers/HttpMessageParsing.cs @@ -1,4 +1,4 @@ -using System.IO; +using System.IO; using System.Text; using Keboo.Web.Proxy.Http; @@ -15,7 +15,7 @@ internal static class HttpMessageParsing /// The request message /// /// Request object if message complete, null otherwise - internal static Request ParseRequest(string messageText, bool requireBody) + internal static Request? ParseRequest(string messageText, bool requireBody) { var reader = new StringReader(messageText); var line = reader.ReadLine(); @@ -65,7 +65,7 @@ internal static Request ParseRequest(string messageText, bool requireBody) /// /// The response message /// Response object if message complete, null otherwise - internal static Response ParseResponse(string messageText) + internal static Response? ParseResponse(string messageText) { var reader = new StringReader(messageText); var line = reader.ReadLine(); diff --git a/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs b/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs index 30d8c8c..b8a2caf 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Helpers/TestHelper.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Net; using System.Net.Http; @@ -51,7 +51,7 @@ private TestProxy(Uri proxyUri) } public Uri ProxyUri { get; set; } - public ICredentials Credentials { get; set; } + public ICredentials? Credentials { get; set; } public Uri GetProxy(Uri destination) { diff --git a/Keboo.Web.Proxy.IntegrationTests/NestedProxyTests.cs b/Keboo.Web.Proxy.IntegrationTests/NestedProxyTests.cs index f418776..53b837b 100644 --- a/Keboo.Web.Proxy.IntegrationTests/NestedProxyTests.cs +++ b/Keboo.Web.Proxy.IntegrationTests/NestedProxyTests.cs @@ -53,7 +53,7 @@ public async Task Smoke_Test_Nested_Proxy_UserData() var proxy1 = TestSuite.GetProxy(); proxy1.ProxyBasicAuthenticateFunc = async (session, username, password) => { - session.UserData = "Test"; + session!.UserData = "Test"; return await Task.FromResult(true); }; diff --git a/Keboo.Web.Proxy.IntegrationTests/ReverseProxyTests.cs b/Keboo.Web.Proxy.IntegrationTests/ReverseProxyTests.cs index a7e93b4..9f40661 100644 --- a/Keboo.Web.Proxy.IntegrationTests/ReverseProxyTests.cs +++ b/Keboo.Web.Proxy.IntegrationTests/ReverseProxyTests.cs @@ -143,7 +143,7 @@ public async Task Smoke_Test_Https_To_Https_Reverse_Proxy_Tunnel_Without_Decrypt var proxy = TestSuite.GetReverseProxy(); var endpoint = - proxy.ProxyEndPoints.Where(x => x is TransparentProxyEndPoint).First() as TransparentProxyEndPoint; + proxy.ProxyEndPoints.OfType().First(); endpoint.BeforeSslAuthenticate += async (sender, e) => { diff --git a/Keboo.Web.Proxy.IntegrationTests/Setup/TestProxyServer.cs b/Keboo.Web.Proxy.IntegrationTests/Setup/TestProxyServer.cs index 33ff430..d640b98 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Setup/TestProxyServer.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Setup/TestProxyServer.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Net; using Keboo.Web.Proxy.Models; using Keboo.Web.Proxy.Network; @@ -7,7 +7,7 @@ namespace Keboo.Web.Proxy.IntegrationTests.Setup; public class TestProxyServer : IDisposable { - public TestProxyServer(bool isReverseProxy, ProxyServer upStreamProxy = null) + public TestProxyServer(bool isReverseProxy, ProxyServer? upStreamProxy = null) { ProxyServer = new ProxyServer(); diff --git a/Keboo.Web.Proxy.IntegrationTests/Setup/TestServer.cs b/Keboo.Web.Proxy.IntegrationTests/Setup/TestServer.cs index cfa312a..0f0a847 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Setup/TestServer.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Setup/TestServer.cs @@ -21,8 +21,8 @@ public class TestServer : IDisposable { private readonly IHost host; - private Func requestHandler; - private Func tcpRequestHandler; + private Func? _requestHandler; + private Func? _tcpRequestHandler; public TestServer(X509Certificate2 serverCertificate, bool requireMutualTls) { @@ -35,7 +35,7 @@ public TestServer(X509Certificate2 serverCertificate, bool requireMutualTls) }) .ConfigureWebHostDefaults(webBuilder => { - webBuilder.UseStartup(x => new Startup(() => requestHandler)); + webBuilder.UseStartup(x => new Startup(() => _requestHandler)); webBuilder.ConfigureKestrel(options => { options.Listen(IPAddress.Loopback, 0); @@ -59,12 +59,12 @@ public TestServer(X509Certificate2 serverCertificate, bool requireMutualTls) { listenOptions.Run(context => { - if (tcpRequestHandler == null) + if (_tcpRequestHandler == null) { throw new Exception("Test server not configured to handle tcp request."); } - return tcpRequestHandler(context); + return _tcpRequestHandler(context); }); }); }); @@ -75,7 +75,7 @@ public TestServer(X509Certificate2 serverCertificate, bool requireMutualTls) var addresses = host.Services.GetRequiredService() .Features.Get() - .Addresses.ToArray(); + ?.Addresses.ToArray() ?? []; HttpListeningPort = new Uri(addresses[0]).Port; HttpsListeningPort = new Uri(addresses[1]).Port; @@ -98,33 +98,33 @@ public void Dispose() public void HandleRequest(Func requestHandler) { - this.requestHandler = requestHandler; + this._requestHandler = requestHandler; } public void HandleTcpRequest(Func tcpRequestHandler) { - this.tcpRequestHandler = tcpRequestHandler; + this._tcpRequestHandler = tcpRequestHandler; } private class Startup { - private readonly Func> requestHandler; + private readonly Func?>? _requestHandler; - public Startup(Func> requestHandler) + public Startup(Func?>? requestHandler) { - this.requestHandler = requestHandler; + _requestHandler = requestHandler; } public void Configure(IApplicationBuilder app) { app.Run(context => { - if (requestHandler == null) + if (_requestHandler is null) { throw new Exception("Test server not configured to handle request."); } - return requestHandler()(context); + return _requestHandler()?.Invoke(context) ?? Task.CompletedTask; }); } diff --git a/Keboo.Web.Proxy.IntegrationTests/Setup/TestSuite.cs b/Keboo.Web.Proxy.IntegrationTests/Setup/TestSuite.cs index cfcc114..38dbbd2 100644 --- a/Keboo.Web.Proxy.IntegrationTests/Setup/TestSuite.cs +++ b/Keboo.Web.Proxy.IntegrationTests/Setup/TestSuite.cs @@ -11,7 +11,8 @@ public class TestSuite public TestSuite(bool requireMutualTls = false) { var dummyProxy = new ProxyServer(); - var serverCertificate = dummyProxy.CertificateManager.CreateServerCertificate("localhost").Result; + var serverCertificate = dummyProxy.CertificateManager.CreateServerCertificate("localhost").Result + ?? throw new InvalidOperationException("Failed to create certificate"); server = new TestServer(serverCertificate, requireMutualTls); } @@ -20,7 +21,7 @@ public TestServer GetServer() return server; } - public static ProxyServer GetProxy(ProxyServer upStreamProxy = null) + public static ProxyServer GetProxy(ProxyServer? upStreamProxy = null) { if (upStreamProxy != null) { @@ -30,7 +31,7 @@ public static ProxyServer GetProxy(ProxyServer upStreamProxy = null) return new TestProxyServer(false).ProxyServer; } - public static ProxyServer GetReverseProxy(ProxyServer upStreamProxy = null) + public static ProxyServer GetReverseProxy(ProxyServer? upStreamProxy = null) { if (upStreamProxy != null) { diff --git a/Keboo.Web.Proxy.UnitTests/ProxyServerTests.cs b/Keboo.Web.Proxy.UnitTests/ProxyServerTests.cs index 5fa96bf..0275cb7 100644 --- a/Keboo.Web.Proxy.UnitTests/ProxyServerTests.cs +++ b/Keboo.Web.Proxy.UnitTests/ProxyServerTests.cs @@ -5,83 +5,82 @@ using TUnit.Assertions.Extensions; using Keboo.Web.Proxy.Models; -namespace Keboo.Web.Proxy.UnitTests +namespace Keboo.Web.Proxy.UnitTests; + +public class ProxyServerTests { - public class ProxyServerTests + [Test] + public async Task + GivenOneEndpointIsAlreadyAddedToAddress_WhenAddingNewEndpointToExistingAddress_ThenExceptionIsThrown() { - [Test] - public async Task - GivenOneEndpointIsAlreadyAddedToAddress_WhenAddingNewEndpointToExistingAddress_ThenExceptionIsThrown() + // Arrange + var proxy = new ProxyServer(); + const int port = 9999; + var firstIpAddress = IPAddress.Parse("127.0.0.1"); + var secondIpAddress = IPAddress.Parse("127.0.0.1"); + proxy.AddEndPoint(new ExplicitProxyEndPoint(firstIpAddress, port, false)); + + // Act + Exception? exception = await Assert.ThrowsAsync(async () => { - // Arrange - var proxy = new ProxyServer(); - const int port = 9999; - var firstIpAddress = IPAddress.Parse("127.0.0.1"); - var secondIpAddress = IPAddress.Parse("127.0.0.1"); - proxy.AddEndPoint(new ExplicitProxyEndPoint(firstIpAddress, port, false)); + proxy.AddEndPoint(new ExplicitProxyEndPoint(secondIpAddress, port, false)); + await Task.CompletedTask; + }); - // Act - var exception = await Assert.ThrowsAsync(async () => - { - proxy.AddEndPoint(new ExplicitProxyEndPoint(secondIpAddress, port, false)); - await Task.CompletedTask; - }); + // Assert + await Assert.That(exception!.Message).Contains("Cannot add another endpoint to same port"); + } - // Assert - await Assert.That(exception.Message).Contains("Cannot add another endpoint to same port"); - } + [Test] + public async Task + GivenOneEndpointIsAlreadyAddedToAddress_WhenAddingNewEndpointToExistingAddress_ThenTwoEndpointsExists() + { + // Arrange + var proxy = new ProxyServer(); + const int port = 9999; + var firstIpAddress = IPAddress.Parse("127.0.0.1"); + var secondIpAddress = IPAddress.Parse("192.168.1.1"); + proxy.AddEndPoint(new ExplicitProxyEndPoint(firstIpAddress, port, false)); - [Test] - public async Task - GivenOneEndpointIsAlreadyAddedToAddress_WhenAddingNewEndpointToExistingAddress_ThenTwoEndpointsExists() - { - // Arrange - var proxy = new ProxyServer(); - const int port = 9999; - var firstIpAddress = IPAddress.Parse("127.0.0.1"); - var secondIpAddress = IPAddress.Parse("192.168.1.1"); - proxy.AddEndPoint(new ExplicitProxyEndPoint(firstIpAddress, port, false)); + // Act + proxy.AddEndPoint(new ExplicitProxyEndPoint(secondIpAddress, port, false)); - // Act - proxy.AddEndPoint(new ExplicitProxyEndPoint(secondIpAddress, port, false)); + // Assert + await Assert.That(proxy.ProxyEndPoints.Count).IsEqualTo(2); + } - // Assert - await Assert.That(proxy.ProxyEndPoints.Count).IsEqualTo(2); - } + [Test] + public async Task GivenOneEndpointIsAlreadyAddedToPort_WhenAddingNewEndpointToExistingPort_ThenExceptionIsThrown() + { + // Arrange + var proxy = new ProxyServer(); + const int port = 9999; + proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); - [Test] - public async Task GivenOneEndpointIsAlreadyAddedToPort_WhenAddingNewEndpointToExistingPort_ThenExceptionIsThrown() + // Act + Exception? exception = await Assert.ThrowsAsync(async () => { - // Arrange - var proxy = new ProxyServer(); - const int port = 9999; proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); + await Task.CompletedTask; + }); - // Act - var exception = await Assert.ThrowsAsync(async () => - { - proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); - await Task.CompletedTask; - }); - - // Assert - await Assert.That(exception.Message).Contains("Cannot add another endpoint to same port"); - } + // Assert + await Assert.That(exception!.Message).Contains("Cannot add another endpoint to same port"); + } - [Test] - public async Task - GivenOneEndpointIsAlreadyAddedToZeroPort_WhenAddingNewEndpointToExistingPort_ThenTwoEndpointsExists() - { - // Arrange - var proxy = new ProxyServer(); - const int port = 0; - proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); + [Test] + public async Task + GivenOneEndpointIsAlreadyAddedToZeroPort_WhenAddingNewEndpointToExistingPort_ThenTwoEndpointsExists() + { + // Arrange + var proxy = new ProxyServer(); + const int port = 0; + proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); - // Act - proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); + // Act + proxy.AddEndPoint(new ExplicitProxyEndPoint(IPAddress.Loopback, port, false)); - // Assert - await Assert.That(proxy.ProxyEndPoints.Count).IsEqualTo(2); - } + // Assert + await Assert.That(proxy.ProxyEndPoints.Count).IsEqualTo(2); } } \ No newline at end of file diff --git a/Keboo.Web.Proxy.UnitTests/SystemProxyTest.cs b/Keboo.Web.Proxy.UnitTests/SystemProxyTest.cs index b340519..eaa7a77 100644 --- a/Keboo.Web.Proxy.UnitTests/SystemProxyTest.cs +++ b/Keboo.Web.Proxy.UnitTests/SystemProxyTest.cs @@ -7,139 +7,144 @@ using Keboo.Web.Proxy.Helpers.WinHttp; using Keboo.Web.Proxy.Models; -namespace Keboo.Web.Proxy.UnitTests +namespace Keboo.Web.Proxy.UnitTests; + +public class SystemProxyTest { - public class SystemProxyTest + [Test] + [NotInParallel("SystemProxy")] + public async Task CompareProxyAddressReturnedByWebProxyAndWinHttpProxyResolver() { - [Test] - public async Task CompareProxyAddressReturnedByWebProxyAndWinHttpProxyResolver() + if (!OperatingSystem.IsWindows()) { - var proxyManager = new SystemProxyManager(); + await Assert.Skip("This test requires Windows"); + } - try - { - await CompareUrls(); + var proxyManager = new SystemProxyManager(); + + try + { + await CompareUrls(); - proxyManager.SetProxy("127.0.0.1", 8000, ProxyProtocolType.Http); - await CompareUrls(); + proxyManager.SetProxy("127.0.0.1", 8000, ProxyProtocolType.Http); + await CompareUrls(); - proxyManager.SetProxy("127.0.0.1", 8000, ProxyProtocolType.Https); - await CompareUrls(); + proxyManager.SetProxy("127.0.0.1", 8000, ProxyProtocolType.Https); + await CompareUrls(); - proxyManager.SetProxy("127.0.0.1", 8000, ProxyProtocolType.AllHttp); - await CompareUrls(); + proxyManager.SetProxy("127.0.0.1", 8000, ProxyProtocolType.AllHttp); + await CompareUrls(); - // for this test you need to add a proxy.pac file to a local webserver - //function FindProxyForURL(url, host) - //{ - // if (shExpMatch(host, "google.com")) - // { - // return "PROXY 127.0.0.1:8888"; - // } + // for this test you need to add a proxy.pac file to a local webserver + //function FindProxyForURL(url, host) + //{ + // if (shExpMatch(host, "google.com")) + // { + // return "PROXY 127.0.0.1:8888"; + // } - // return "DIRECT"; - //} + // return "DIRECT"; + //} - //proxyManager.SetAutoProxyUrl("http://localhost/proxy.pac"); - //CompareUrls(); + //proxyManager.SetAutoProxyUrl("http://localhost/proxy.pac"); + //CompareUrls(); - proxyManager.SetProxyOverride("<-loopback>"); - await CompareUrls(); + proxyManager.SetProxyOverride("<-loopback>"); + await CompareUrls(); - proxyManager.SetProxyOverride(""); - await CompareUrls(); + proxyManager.SetProxyOverride(""); + await CompareUrls(); - proxyManager.SetProxyOverride("yahoo.com"); - await CompareUrls(); + proxyManager.SetProxyOverride("yahoo.com"); + await CompareUrls(); - proxyManager.SetProxyOverride("*.local"); - await CompareUrls(); + proxyManager.SetProxyOverride("*.local"); + await CompareUrls(); - proxyManager.SetProxyOverride("http://*.local"); - await CompareUrls(); + proxyManager.SetProxyOverride("http://*.local"); + await CompareUrls(); - proxyManager.SetProxyOverride("<-loopback>;*.local"); - await CompareUrls(); + proxyManager.SetProxyOverride("<-loopback>;*.local"); + await CompareUrls(); - proxyManager.SetProxyOverride("<-loopback>;*.local;"); - await CompareUrls(); - } - finally - { - proxyManager.RestoreOriginalSettings(); - } + proxyManager.SetProxyOverride("<-loopback>;*.local;"); + await CompareUrls(); } + finally + { + proxyManager.RestoreOriginalSettings(); + } + } - private async Task CompareUrls() + private async Task CompareUrls() + { + var webProxy = WebRequest.GetSystemWebProxy(); + + var resolver = new WinHttpWebProxyFinder(); + resolver.LoadFromIe(); + + await CompareProxy(webProxy, resolver, "http://127.0.0.1"); + await CompareProxy(webProxy, resolver, "https://127.0.0.1"); + await CompareProxy(webProxy, resolver, "http://localhost"); + await CompareProxy(webProxy, resolver, "https://localhost"); + + string? hostName = null; + try + { + hostName = Dns.GetHostName(); + } + catch { - var webProxy = WebRequest.GetSystemWebProxy(); - - var resolver = new WinHttpWebProxyFinder(); - resolver.LoadFromIe(); - - await CompareProxy(webProxy, resolver, "http://127.0.0.1"); - await CompareProxy(webProxy, resolver, "https://127.0.0.1"); - await CompareProxy(webProxy, resolver, "http://localhost"); - await CompareProxy(webProxy, resolver, "https://localhost"); - - string? hostName = null; - try - { - hostName = Dns.GetHostName(); - } - catch - { - } - - if (hostName != null) - { - await CompareProxy(webProxy, resolver, "http://" + hostName); - await CompareProxy(webProxy, resolver, "https://" + hostName); - } - - await CompareProxy(webProxy, resolver, "http://google.com"); - await CompareProxy(webProxy, resolver, "https://google.com"); - await CompareProxy(webProxy, resolver, "http://bing.com"); - await CompareProxy(webProxy, resolver, "https://bing.com"); - await CompareProxy(webProxy, resolver, "http://yahoo.com"); - await CompareProxy(webProxy, resolver, "https://yahoo.com"); - await CompareProxy(webProxy, resolver, "http://test.local"); - await CompareProxy(webProxy, resolver, "https://test.local"); } - private static async Task CompareProxy(IWebProxy webProxy, WinHttpWebProxyFinder resolver, string url) + if (hostName != null) { - var uri = new Uri(url); - - var expectedProxyUri = webProxy.GetProxy(uri); - - var proxy = resolver.GetProxy(uri); - - // Handle cases where both agree there's no proxy - if ((expectedProxyUri == null || expectedProxyUri == uri) && proxy == null) - { - // Both agree: no proxy - return; - } - - // Handle cases where one finds a proxy and the other doesn't - if (expectedProxyUri == null || expectedProxyUri == uri) - { - // WebProxy says no proxy, but WinHttpWebProxyFinder found one - // This can happen due to different proxy detection methods - return; - } - - if (proxy == null) - { - // WinHttpWebProxyFinder couldn't determine proxy, but WebProxy found one - // This can happen when the proxy is not configured via IE settings - // Skip this comparison as it's an expected difference - return; - } - - // Both found a proxy, verify they match - await Assert.That(expectedProxyUri.ToString()).IsEqualTo($"http://{proxy.HostName}:{proxy.Port}/"); + await CompareProxy(webProxy, resolver, "http://" + hostName); + await CompareProxy(webProxy, resolver, "https://" + hostName); } + + await CompareProxy(webProxy, resolver, "http://google.com"); + await CompareProxy(webProxy, resolver, "https://google.com"); + await CompareProxy(webProxy, resolver, "http://bing.com"); + await CompareProxy(webProxy, resolver, "https://bing.com"); + await CompareProxy(webProxy, resolver, "http://yahoo.com"); + await CompareProxy(webProxy, resolver, "https://yahoo.com"); + await CompareProxy(webProxy, resolver, "http://test.local"); + await CompareProxy(webProxy, resolver, "https://test.local"); + } + + private static async Task CompareProxy(IWebProxy webProxy, WinHttpWebProxyFinder resolver, string url) + { + var uri = new Uri(url); + + var expectedProxyUri = webProxy.GetProxy(uri); + + var proxy = resolver.GetProxy(uri); + + // Handle cases where both agree there's no proxy + if ((expectedProxyUri == null || expectedProxyUri == uri) && proxy == null) + { + // Both agree: no proxy + return; + } + + // Handle cases where one finds a proxy and the other doesn't + if (expectedProxyUri == null || expectedProxyUri == uri) + { + // WebProxy says no proxy, but WinHttpWebProxyFinder found one + // This can happen due to different proxy detection methods + return; + } + + if (proxy == null) + { + // WinHttpWebProxyFinder couldn't determine proxy, but WebProxy found one + // This can happen when the proxy is not configured via IE settings + // Skip this comparison as it's an expected difference + return; + } + + // Both found a proxy, verify they match + await Assert.That(expectedProxyUri.ToString()).IsEqualTo($"http://{proxy.HostName}:{proxy.Port}/"); } } \ No newline at end of file diff --git a/Keboo.Web.Proxy.UnitTests/WinAuthTests.cs b/Keboo.Web.Proxy.UnitTests/WinAuthTests.cs index 98846cc..b8ced87 100644 --- a/Keboo.Web.Proxy.UnitTests/WinAuthTests.cs +++ b/Keboo.Web.Proxy.UnitTests/WinAuthTests.cs @@ -4,15 +4,14 @@ using Keboo.Web.Proxy.Http; using Keboo.Web.Proxy.Network.WinAuth; -namespace Keboo.Web.Proxy.UnitTests +namespace Keboo.Web.Proxy.UnitTests; + +public class WinAuthTests { - public class WinAuthTests + [Test] + public async Task Test_Acquire_Client_Token() { - [Test] - public async Task Test_Acquire_Client_Token() - { - var token = WinAuthHandler.GetInitialAuthToken("mylocalserver.com", "NTLM", new InternalDataStore()); - await Assert.That(token.Length).IsGreaterThan(1); - } + var token = WinAuthHandler.GetInitialAuthToken("mylocalserver.com", "NTLM", new InternalDataStore()); + await Assert.That(token.Length).IsGreaterThan(1); } } \ No newline at end of file diff --git a/Keboo.Web.Proxy/Helpers/NativeMethods.SystemProxy.cs b/Keboo.Web.Proxy/Helpers/NativeMethods.SystemProxy.cs index 6eb410d..bd79cc5 100644 --- a/Keboo.Web.Proxy/Helpers/NativeMethods.SystemProxy.cs +++ b/Keboo.Web.Proxy/Helpers/NativeMethods.SystemProxy.cs @@ -1,8 +1,10 @@ using System; using System.Runtime.InteropServices; +using System.Runtime.Versioning; namespace Keboo.Web.Proxy.Helpers; +[SupportedOSPlatform("windows")] internal partial class NativeMethods { // Keeps it from getting garbage collected diff --git a/Keboo.Web.Proxy/Helpers/SystemProxy.cs b/Keboo.Web.Proxy/Helpers/SystemProxy.cs index 353ef95..936bfff 100644 --- a/Keboo.Web.Proxy/Helpers/SystemProxy.cs +++ b/Keboo.Web.Proxy/Helpers/SystemProxy.cs @@ -3,6 +3,7 @@ using System.Linq; using Microsoft.Win32; using Keboo.Web.Proxy.Models; +using System.Runtime.Versioning; // Helper classes for setting system proxy settings namespace Keboo.Web.Proxy.Helpers; @@ -57,7 +58,7 @@ internal class SystemProxyManager internal const int InternetOptionSettingsChanged = 39; internal const int InternetOptionRefresh = 37; - private ProxyInfo? originalValues; + private ProxyInfo? _originalValues; public SystemProxyManager() { @@ -86,29 +87,27 @@ public SystemProxyManager() /// internal void SetProxy(string hostname, int port, ProxyProtocolType protocolType) { - using (var reg = OpenInternetSettingsKey()) - { - if (reg == null) return; + using RegistryKey? reg = OpenInternetSettingsKey(); + if (reg is null) return; - SaveOriginalProxyConfiguration(reg); - PrepareRegistry(reg); + SaveOriginalProxyConfiguration(reg); + PrepareRegistry(reg); - var existingContent = reg.GetValue(RegProxyServer) as string; - var existingSystemProxyValues = ProxyInfo.GetSystemProxyValues(existingContent); - existingSystemProxyValues.RemoveAll(x => (protocolType & x.ProtocolType) != 0); - if ((protocolType & ProxyProtocolType.Http) != 0) - existingSystemProxyValues.Add(new HttpSystemProxyValue(hostname, port, ProxyProtocolType.Http)); + var existingContent = reg.GetValue(RegProxyServer) as string; + var existingSystemProxyValues = ProxyInfo.GetSystemProxyValues(existingContent); + existingSystemProxyValues.RemoveAll(x => (protocolType & x.ProtocolType) != 0); + if ((protocolType & ProxyProtocolType.Http) != 0) + existingSystemProxyValues.Add(new HttpSystemProxyValue(hostname, port, ProxyProtocolType.Http)); - if ((protocolType & ProxyProtocolType.Https) != 0) - existingSystemProxyValues.Add(new HttpSystemProxyValue(hostname, port, ProxyProtocolType.Https)); + if ((protocolType & ProxyProtocolType.Https) != 0) + existingSystemProxyValues.Add(new HttpSystemProxyValue(hostname, port, ProxyProtocolType.Https)); - reg.DeleteValue(RegAutoConfigUrl, false); - reg.SetValue(RegProxyEnable, 1); - reg.SetValue(RegProxyServer, - string.Join(";", existingSystemProxyValues.Select(x => x.ToString()).ToArray())); + reg.DeleteValue(RegAutoConfigUrl, false); + reg.SetValue(RegProxyEnable, 1); + reg.SetValue(RegProxyServer, + string.Join(";", existingSystemProxyValues.Select(x => x.ToString()).ToArray())); - Refresh(); - } + Refresh(); } /// @@ -116,34 +115,32 @@ internal void SetProxy(string hostname, int port, ProxyProtocolType protocolType /// internal void RemoveProxy(ProxyProtocolType protocolType, bool saveOriginalConfig = true) { - using (var reg = OpenInternetSettingsKey()) + using RegistryKey? reg = OpenInternetSettingsKey(); + if (reg is null) return; + + if (saveOriginalConfig) SaveOriginalProxyConfiguration(reg); + + if (reg.GetValue(RegProxyServer) != null) { - if (reg == null) return; + var existingContent = reg.GetValue(RegProxyServer) as string; - if (saveOriginalConfig) SaveOriginalProxyConfiguration(reg); + var existingSystemProxyValues = ProxyInfo.GetSystemProxyValues(existingContent); + existingSystemProxyValues.RemoveAll(x => (protocolType & x.ProtocolType) != 0); - if (reg.GetValue(RegProxyServer) != null) + if (existingSystemProxyValues.Count != 0) { - var existingContent = reg.GetValue(RegProxyServer) as string; - - var existingSystemProxyValues = ProxyInfo.GetSystemProxyValues(existingContent); - existingSystemProxyValues.RemoveAll(x => (protocolType & x.ProtocolType) != 0); - - if (existingSystemProxyValues.Count != 0) - { - reg.SetValue(RegProxyEnable, 1); - reg.SetValue(RegProxyServer, - string.Join(";", existingSystemProxyValues.Select(x => x.ToString()).ToArray())); - } - else - { - reg.SetValue(RegProxyEnable, 0); - reg.SetValue(RegProxyServer, string.Empty); - } + reg.SetValue(RegProxyEnable, 1); + reg.SetValue(RegProxyServer, + string.Join(";", existingSystemProxyValues.Select(x => x.ToString()).ToArray())); + } + else + { + reg.SetValue(RegProxyEnable, 0); + reg.SetValue(RegProxyServer, string.Empty); } - - Refresh(); } + + Refresh(); } /// @@ -151,97 +148,88 @@ internal void RemoveProxy(ProxyProtocolType protocolType, bool saveOriginalConfi /// internal void DisableAllProxy() { - using (var reg = OpenInternetSettingsKey()) - { - if (reg == null) return; + using RegistryKey? reg = OpenInternetSettingsKey(); + if (reg is null) return; - SaveOriginalProxyConfiguration(reg); + SaveOriginalProxyConfiguration(reg); - reg.SetValue(RegProxyEnable, 0); - reg.SetValue(RegProxyServer, string.Empty); + reg.SetValue(RegProxyEnable, 0); + reg.SetValue(RegProxyServer, string.Empty); - Refresh(); - } + Refresh(); } internal void SetAutoProxyUrl(string url) { - using (var reg = OpenInternetSettingsKey()) - { - if (reg == null) return; + using RegistryKey? reg = OpenInternetSettingsKey(); + if (reg is null) return; - SaveOriginalProxyConfiguration(reg); - reg.SetValue(RegAutoConfigUrl, url); - Refresh(); - } + SaveOriginalProxyConfiguration(reg); + reg.SetValue(RegAutoConfigUrl, url); + Refresh(); } internal void SetProxyOverride(string proxyOverride) { - using (var reg = OpenInternetSettingsKey()) - { - if (reg == null) return; + using RegistryKey? reg = OpenInternetSettingsKey(); + if (reg is null) return; - SaveOriginalProxyConfiguration(reg); - reg.SetValue(RegProxyOverride, proxyOverride); - Refresh(); - } + SaveOriginalProxyConfiguration(reg); + reg.SetValue(RegProxyOverride, proxyOverride); + Refresh(); } + [SupportedOSPlatform("windows")] internal void RestoreOriginalSettings() { - if (originalValues == null) return; - - using (var reg = Registry.CurrentUser.OpenSubKey(RegKeyInternetSettings, true)) - { - if (reg == null) return; + if (_originalValues is null) return; - var ov = originalValues; - if (ov.AutoConfigUrl != null) - reg.SetValue(RegAutoConfigUrl, ov.AutoConfigUrl); - else - reg.DeleteValue(RegAutoConfigUrl, false); - - if (ov.ProxyEnable.HasValue) - reg.SetValue(RegProxyEnable, ov.ProxyEnable.Value); - else - reg.DeleteValue(RegProxyEnable, false); + using var reg = Registry.CurrentUser.OpenSubKey(RegKeyInternetSettings, true); + if (reg is null) return; - if (ov.ProxyServer != null) - reg.SetValue(RegProxyServer, ov.ProxyServer); - else - reg.DeleteValue(RegProxyServer, false); + var ov = _originalValues; + if (ov.AutoConfigUrl != null) + reg.SetValue(RegAutoConfigUrl, ov.AutoConfigUrl); + else + reg.DeleteValue(RegAutoConfigUrl, false); - if (ov.ProxyOverride != null) - reg.SetValue(RegProxyOverride, ov.ProxyOverride); - else - reg.DeleteValue(RegProxyOverride, false); - - // This should not be needed, but sometimes the values are not stored into the registry - // at system shutdown without flushing. - reg.Flush(); - - originalValues = null; - - const int smShuttingdown = 0x2000; - var windows7Version = new Version(6, 1); - if (Environment.OSVersion.Version > windows7Version || - NativeMethods.GetSystemMetrics(smShuttingdown) == 0) - // Do not call refresh() in Windows 7 or earlier at system shutdown. - // SetInternetOption in the refresh method re-enables ProxyEnable registry value - // in Windows 7 or earlier at system shutdown. - Refresh(); - } + if (ov.ProxyEnable.HasValue) + reg.SetValue(RegProxyEnable, ov.ProxyEnable.Value); + else + reg.DeleteValue(RegProxyEnable, false); + + if (ov.ProxyServer != null) + reg.SetValue(RegProxyServer, ov.ProxyServer); + else + reg.DeleteValue(RegProxyServer, false); + + if (ov.ProxyOverride != null) + reg.SetValue(RegProxyOverride, ov.ProxyOverride); + else + reg.DeleteValue(RegProxyOverride, false); + + // This should not be needed, but sometimes the values are not stored into the registry + // at system shutdown without flushing. + reg.Flush(); + + _originalValues = null; + + const int smShuttingdown = 0x2000; + var windows7Version = new Version(6, 1); + if (Environment.OSVersion.Version > windows7Version || + NativeMethods.GetSystemMetrics(smShuttingdown) == 0) + // Do not call refresh() in Windows 7 or earlier at system shutdown. + // SetInternetOption in the refresh method re-enables ProxyEnable registry value + // in Windows 7 or earlier at system shutdown. + Refresh(); } internal ProxyInfo? GetProxyInfoFromRegistry() { - using (var reg = OpenInternetSettingsKey()) - { - if (reg == null) return null; + using var reg = OpenInternetSettingsKey(); + if (reg == null) return null; - return GetProxyInfoFromRegistry(reg); - } + return GetProxyInfoFromRegistry(reg); } private static ProxyInfo GetProxyInfoFromRegistry(RegistryKey reg) @@ -257,9 +245,9 @@ private static ProxyInfo GetProxyInfoFromRegistry(RegistryKey reg) private void SaveOriginalProxyConfiguration(RegistryKey reg) { - if (originalValues != null) return; + if (_originalValues != null) return; - originalValues = GetProxyInfoFromRegistry(reg); + _originalValues = GetProxyInfoFromRegistry(reg); } ///