diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs index 7582ed8e79f025..edeb6330091ece 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.ExceptionServices; using System.Threading.Tasks; using Microsoft.Extensions.Internal; @@ -120,10 +121,15 @@ public object GetRequiredKeyedService(Type serviceType, object? serviceKey) public void Dispose() { List? toDispose = BeginDispose(); + if (toDispose is null) + { + return; + } - if (toDispose != null) + object? exceptionsCache = null; + for (var i = toDispose.Count - 1; i >= 0; i--) { - for (int i = toDispose.Count - 1; i >= 0; i--) + try { if (toDispose[i] is IDisposable disposable) { @@ -134,68 +140,126 @@ public void Dispose() throw new InvalidOperationException(SR.Format(SR.AsyncDisposableServiceDispose, TypeNameHelper.GetTypeDisplayName(toDispose[i]))); } } + catch (Exception exception) + { + AddExceptionToCache(ref exceptionsCache, exception); + } } + + CheckExceptionCache(exceptionsCache); } public ValueTask DisposeAsync() { List? toDispose = BeginDispose(); + if (toDispose is null) + { + return default; + } - if (toDispose != null) + object? exceptionsCache = null; + for (var i = toDispose.Count - 1; i >= 0; i--) { try { - for (int i = toDispose.Count - 1; i >= 0; i--) + object disposable = toDispose[i]; + if (disposable is IAsyncDisposable asyncDisposable) { - object disposable = toDispose[i]; - if (disposable is IAsyncDisposable asyncDisposable) - { - ValueTask vt = asyncDisposable.DisposeAsync(); - if (!vt.IsCompletedSuccessfully) - { - return Await(i, vt, toDispose); - } - - // If its a IValueTaskSource backed ValueTask, - // inform it its result has been read so it can reset - vt.GetAwaiter().GetResult(); - } - else + ValueTask vt = asyncDisposable.DisposeAsync(); + if (!vt.IsCompletedSuccessfully) { - ((IDisposable)disposable).Dispose(); + return Await(i, vt, toDispose, exceptionsCache); } + + // If its a IValueTaskSource backed ValueTask, + // inform it its result has been read so it can reset + vt.GetAwaiter().GetResult(); + } + else + { + ((IDisposable)disposable).Dispose(); } } - catch (Exception ex) + catch (Exception exception) { - return new ValueTask(Task.FromException(ex)); + AddExceptionToCache(ref exceptionsCache, exception); } } + CheckExceptionCache(exceptionsCache); + return default; - static async ValueTask Await(int i, ValueTask vt, List toDispose) + static async ValueTask Await(int i, ValueTask vt, List toDispose, object? exceptionsCache) { - await vt.ConfigureAwait(false); + try + { + await vt.ConfigureAwait(false); + } + catch (Exception exception) + { + AddExceptionToCache(ref exceptionsCache, exception); + } + // vt is acting on the disposable at index i, // decrement it and move to the next iteration i--; for (; i >= 0; i--) { - object disposable = toDispose[i]; - if (disposable is IAsyncDisposable asyncDisposable) + try { - await asyncDisposable.DisposeAsync().ConfigureAwait(false); + object disposable = toDispose[i]; + if (disposable is IAsyncDisposable asyncDisposable) + { + await asyncDisposable.DisposeAsync().ConfigureAwait(false); + } + else + { + ((IDisposable)disposable).Dispose(); + } } - else + catch (Exception exception) { - ((IDisposable)disposable).Dispose(); + AddExceptionToCache(ref exceptionsCache, exception); } } + + CheckExceptionCache(exceptionsCache); } } + private static void AddExceptionToCache(ref object? exceptionsCache, Exception exception) + { + if (exceptionsCache is null) + { + exceptionsCache = ExceptionDispatchInfo.Capture(exception); + } + else if (exceptionsCache is ExceptionDispatchInfo exceptionInfo) + { + exceptionsCache = new List { exceptionInfo.SourceException, exception }; + } + else + { + ((List)exceptionsCache).Add(exception); + } + } + + private static void CheckExceptionCache(object? exceptionsCache) + { + if (exceptionsCache is null) + { + return; + } + + if (exceptionsCache is ExceptionDispatchInfo exceptionInfo) + { + exceptionInfo.Throw(); + } + + throw new AggregateException((List)exceptionsCache); + } + private List? BeginDispose() { lock (Sync) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderEngineScopeTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderEngineScopeTests.cs index e25174cd011245..ccbedc43f1e16d 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderEngineScopeTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderEngineScopeTests.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; +using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; -using Xunit.Abstractions; namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { @@ -41,5 +40,138 @@ public void ServiceProviderEngineScope_ImplementsAllServiceProviderInterfaces() Assert.Contains(serviceProviderInterface, engineScopeInterfaces); } } + + [Fact] + public void Dispose_ServiceThrows_DisposesAllAndThrows() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("throws", (_, _) => new TestDisposable(true)); + services.AddKeyedTransient("doesnotthrow", (_, _) => new TestDisposable(false)); + + var scope = services.BuildServiceProvider().GetRequiredService().CreateScope().ServiceProvider; + + var disposables = new TestDisposable[] + { + scope.GetRequiredKeyedService("throws"), + scope.GetRequiredKeyedService("doesnotthrow") + }; + + var exception = Assert.Throws(() => ((IDisposable)scope).Dispose()); + Assert.Equal(TestDisposable.ErrorMessage, exception.Message); + Assert.All(disposables, disposable => Assert.True(disposable.IsDisposed)); + } + + [Fact] + public void Dispose_TwoServicesThrows_DisposesAllAndThrowsAggregateException() + { + var services = new ServiceCollection(); + services.AddKeyedTransient("throws", (_, _) => new TestDisposable(true)); + services.AddKeyedTransient("doesnotthrow", (_, _) => new TestDisposable(false)); + + var scope = services.BuildServiceProvider().GetRequiredService().CreateScope().ServiceProvider; + + var disposables = new TestDisposable[] + { + scope.GetRequiredKeyedService("throws"), + scope.GetRequiredKeyedService("doesnotthrow"), + scope.GetRequiredKeyedService("throws"), + scope.GetRequiredKeyedService("doesnotthrow"), + }; + + var exception = Assert.Throws(() => ((IDisposable)scope).Dispose()); + Assert.Equal(2, exception.InnerExceptions.Count); + Assert.All(exception.InnerExceptions, ex => Assert.IsType(ex)); + Assert.All(disposables, disposable => Assert.True(disposable.IsDisposed)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task DisposeAsync_ServiceThrows_DisposesAllAndThrows(bool synchronous) + { + var services = new ServiceCollection(); + services.AddKeyedTransient("throws", (_, _) => new TestDisposable(true, synchronous)); + services.AddKeyedTransient("doesnotthrow", (_, _) => new TestDisposable(false, synchronous)); + + var scope = services.BuildServiceProvider().GetRequiredService().CreateScope().ServiceProvider; + + var disposables = new TestDisposable[] + { + scope.GetRequiredKeyedService("throws"), + scope.GetRequiredKeyedService("doesnotthrow") + }; + + var exception = await Assert.ThrowsAsync(async () => await ((IAsyncDisposable)scope).DisposeAsync()); + Assert.Equal(TestDisposable.ErrorMessage, exception.Message); + Assert.All(disposables, disposable => Assert.True(disposable.IsDisposed)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task DisposeAsync_TwoServicesThrows_DisposesAllAndThrowsAggregateException(bool synchronous) + { + var services = new ServiceCollection(); + services.AddKeyedTransient("throws", (_, _) => new TestDisposable(true, synchronous)); + services.AddKeyedTransient("doesnotthrow", (_, _) => new TestDisposable(false, synchronous)); + + var scope = services.BuildServiceProvider().GetRequiredService().CreateScope().ServiceProvider; + + var disposables = new TestDisposable[] + { + scope.GetRequiredKeyedService("throws"), + scope.GetRequiredKeyedService("doesnotthrow"), + scope.GetRequiredKeyedService("throws"), + scope.GetRequiredKeyedService("doesnotthrow"), + }; + + var exception = await Assert.ThrowsAsync(async () => await ((IAsyncDisposable)scope).DisposeAsync()); + Assert.Equal(2, exception.InnerExceptions.Count); + Assert.All(exception.InnerExceptions, ex => Assert.IsType(ex)); + Assert.All(disposables, disposable => Assert.True(disposable.IsDisposed)); + } + + private class TestDisposable : IDisposable, IAsyncDisposable + { + public const string ErrorMessage = "Dispose failed."; + + private readonly bool _throwsOnDispose; + private readonly bool _synchronous; + + public bool IsDisposed { get; private set; } + + public TestDisposable(bool throwsOnDispose = false, bool synchronous = false) + { + _throwsOnDispose = throwsOnDispose; + _synchronous = synchronous; + } + + public void Dispose() + { + IsDisposed = true; + + if (_throwsOnDispose) + { + throw new InvalidOperationException(ErrorMessage); + } + } + + public ValueTask DisposeAsync() + { + if (_synchronous) + { + Dispose(); + return default; + } + + return new ValueTask(DisposeAsyncInternal()); + + async Task DisposeAsyncInternal() + { + await Task.Yield(); + Dispose(); + } + } + } } }