diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ScopeState.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ScopeState.cs index edc06906a8e2d0..14570e149ee107 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ScopeState.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ScopeState.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using Microsoft.Extensions.DependencyInjection.ServiceLookup; @@ -10,16 +9,15 @@ namespace Microsoft.Extensions.DependencyInjection { internal class ScopeState { - public IDictionary ResolvedServices { get; } + public Dictionary ResolvedServices { get; } public List Disposables { get; set; } public int DisposableServicesCount => Disposables?.Count ?? 0; public int ResolvedServicesCount => ResolvedServices.Count; - public ScopeState(bool isRoot) + public ScopeState() { - // When isRoot is true to reduce lock contention for singletons upon resolve we use a concurrent dictionary. - ResolvedServices = isRoot ? new ConcurrentDictionary() : new Dictionary(); + ResolvedServices = new Dictionary(); } public void Track(ServiceProviderEngine engine) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index 19ed4b81091a2d..82ee191aade689 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -16,7 +16,7 @@ internal sealed class CallSiteFactory { private const int DefaultSlot = 0; private readonly ServiceDescriptor[] _descriptors; - private readonly ConcurrentDictionary _callSiteCache = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _callSiteCache = new ConcurrentDictionary(); private readonly Dictionary _descriptorLookup = new Dictionary(); private readonly StackGuard _stackGuard; @@ -77,7 +77,7 @@ private void Populate() } internal ServiceCallSite GetCallSite(Type serviceType, CallSiteChain callSiteChain) => - _callSiteCache.TryGetValue(serviceType, out ServiceCallSite site) ? site : + _callSiteCache.TryGetValue(new ServiceCacheKey(serviceType, DefaultSlot), out ServiceCallSite site) ? site : CreateCallSite(serviceType, callSiteChain); internal ServiceCallSite GetCallSite(ServiceDescriptor serviceDescriptor, CallSiteChain callSiteChain) @@ -104,8 +104,6 @@ private ServiceCallSite CreateCallSite(Type serviceType, CallSiteChain callSiteC TryCreateOpenGeneric(serviceType, callSiteChain) ?? TryCreateEnumerable(serviceType, callSiteChain); - _callSiteCache[serviceType] = callSite; - return callSite; } @@ -132,6 +130,12 @@ private ServiceCallSite TryCreateOpenGeneric(Type serviceType, CallSiteChain cal private ServiceCallSite TryCreateEnumerable(Type serviceType, CallSiteChain callSiteChain) { + ServiceCacheKey callSiteKey = new ServiceCacheKey(serviceType, DefaultSlot); + if (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite serviceCallSite)) + { + return serviceCallSite; + } + try { callSiteChain.Add(serviceType); @@ -188,10 +192,10 @@ private ServiceCallSite TryCreateEnumerable(Type serviceType, CallSiteChain call ResultCache resultCache = ResultCache.None; if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root) { - resultCache = new ResultCache(cacheLocation, new ServiceCacheKey(serviceType, DefaultSlot)); + resultCache = new ResultCache(cacheLocation, callSiteKey); } - return new IEnumerableCallSite(resultCache, itemType, callSites.ToArray()); + return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites.ToArray()); } return null; @@ -211,6 +215,12 @@ private ServiceCallSite TryCreateExact(ServiceDescriptor descriptor, Type servic { if (serviceType == descriptor.ServiceType) { + ServiceCacheKey callSiteKey = new ServiceCacheKey(serviceType, slot); + if (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite serviceCallSite)) + { + return serviceCallSite; + } + ServiceCallSite callSite; var lifetime = new ResultCache(descriptor.Lifetime, serviceType, slot); if (descriptor.ImplementationInstance != null) @@ -230,7 +240,7 @@ private ServiceCallSite TryCreateExact(ServiceDescriptor descriptor, Type servic throw new InvalidOperationException(SR.InvalidServiceDescriptor); } - return callSite; + return _callSiteCache[callSiteKey] = callSite; } return null; @@ -241,6 +251,12 @@ private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type if (serviceType.IsConstructedGenericType && serviceType.GetGenericTypeDefinition() == descriptor.ServiceType) { + ServiceCacheKey callSiteKey = new ServiceCacheKey(serviceType, slot); + if (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite serviceCallSite)) + { + return serviceCallSite; + } + Debug.Assert(descriptor.ImplementationType != null, "descriptor.ImplementationType != null"); var lifetime = new ResultCache(descriptor.Lifetime, serviceType, slot); Type closedType; @@ -258,7 +274,7 @@ private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type return null; } - return CreateConstructorCallSite(lifetime, serviceType, closedType, callSiteChain); + return _callSiteCache[callSiteKey] = CreateConstructorCallSite(lifetime, serviceType, closedType, callSiteChain); } return null; @@ -406,7 +422,7 @@ private ServiceCallSite[] CreateArgumentCallSites( public void Add(Type type, ServiceCallSite serviceCallSite) { - _callSiteCache[type] = serviceCallSite; + _callSiteCache[new ServiceCacheKey(type, DefaultSlot)] = serviceCallSite; } private struct ServiceDescriptorCacheItem diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteRuntimeResolver.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteRuntimeResolver.cs index 36de790a24efde..7479759f2b527a 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteRuntimeResolver.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteRuntimeResolver.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Collections.Concurrent; using System.Reflection; using System.Runtime.ExceptionServices; using System.Threading; @@ -60,21 +59,31 @@ protected override object VisitConstructor(ConstructorCallSite constructorCallSi protected override object VisitRootCache(ServiceCallSite callSite, RuntimeResolverContext context) { - var lockType = RuntimeResolverLock.Root; - bool lockTaken = false; - - // using more granular locking (per singleton) for the root - Monitor.Enter(callSite, ref lockTaken); - try + if (callSite.Value is object value) { - return ResolveService(callSite, context, lockType, serviceProviderEngine: context.Scope.Engine.Root); + // Value already calculated, return it directly + return value; } - finally + + var lockType = RuntimeResolverLock.Root; + ServiceProviderEngineScope serviceProviderEngine = context.Scope.Engine.Root; + + lock (callSite) { - if (lockTaken) + // Lock the callsite and check if another thread already cached the value + if (callSite.Value is object resolved) { - Monitor.Exit(callSite); + return resolved; } + + resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext + { + Scope = serviceProviderEngine, + AcquiredLocks = context.AcquiredLocks | lockType + }); + serviceProviderEngine.CaptureDisposable(resolved); + callSite.Value = resolved; + return resolved; } } @@ -91,7 +100,7 @@ private object VisitCache(ServiceCallSite callSite, RuntimeResolverContext conte { bool lockTaken = false; object sync = serviceProviderEngine.Sync; - + Dictionary resolvedServices = serviceProviderEngine.ResolvedServices; // Taking locks only once allows us to fork resolution process // on another thread without causing the deadlock because we // always know that we are going to wait the other thread to finish before @@ -103,7 +112,21 @@ private object VisitCache(ServiceCallSite callSite, RuntimeResolverContext conte try { - return ResolveService(callSite, context, lockType, serviceProviderEngine); + // Note: This method has already taken lock by the caller for resolution and access synchronization. + // For scoped: takes a dictionary as both a resolution lock and a dictionary access lock. + if (resolvedServices.TryGetValue(callSite.Cache.Key, out object resolved)) + { + return resolved; + } + + resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext + { + Scope = serviceProviderEngine, + AcquiredLocks = context.AcquiredLocks | lockType + }); + serviceProviderEngine.CaptureDisposable(resolved); + resolvedServices.Add(callSite.Cache.Key, resolved); + return resolved; } finally { @@ -114,32 +137,6 @@ private object VisitCache(ServiceCallSite callSite, RuntimeResolverContext conte } } - private object ResolveService(ServiceCallSite callSite, RuntimeResolverContext context, RuntimeResolverLock lockType, ServiceProviderEngineScope serviceProviderEngine) - { - IDictionary resolvedServices = serviceProviderEngine.ResolvedServices; - - // Note: This method has already taken lock by the caller for resolution and access synchronization. - // For root: uses a concurrent dictionary and takes a per singleton lock for resolution. - // For scoped: takes a dictionary as both a resolution lock and a dictionary access lock. - Debug.Assert( - (lockType == RuntimeResolverLock.Root && resolvedServices is ConcurrentDictionary) || - (lockType == RuntimeResolverLock.Scope && Monitor.IsEntered(serviceProviderEngine.Sync))); - - if (resolvedServices.TryGetValue(callSite.Cache.Key, out object resolved)) - { - return resolved; - } - - resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext - { - Scope = serviceProviderEngine, - AcquiredLocks = context.AcquiredLocks | lockType - }); - serviceProviderEngine.CaptureDisposable(resolved); - resolvedServices.Add(callSite.Cache.Key, resolved); - return resolved; - } - protected override object VisitConstant(ConstantCallSite constantCallSite, RuntimeResolverContext context) { return constantCallSite.DefaultValue; diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs index 206d2691a878fc..a27cb110389d34 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs @@ -8,7 +8,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup internal sealed class ConstantCallSite : ServiceCallSite { private readonly Type _serviceType; - internal object DefaultValue { get; } + internal object DefaultValue => Value; public ConstantCallSite(Type serviceType, object defaultValue): base(ResultCache.None) { @@ -18,7 +18,7 @@ public ConstantCallSite(Type serviceType, object defaultValue): base(ResultCache throw new ArgumentException(SR.Format(SR.ConstantCantBeConvertedToServiceType, defaultValue.GetType(), serviceType)); } - DefaultValue = defaultValue; + Value = defaultValue; } public override Type ServiceType => DefaultValue?.GetType() ?? _serviceType; diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs index b8b5dde5dba068..f2c3baa2a6d471 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs @@ -5,7 +5,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { - internal struct ServiceCacheKey: IEquatable + internal readonly struct ServiceCacheKey : IEquatable { public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0); diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCallSite.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCallSite.cs index f63f4c43835679..626ecce4f25486 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCallSite.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCallSite.cs @@ -19,6 +19,7 @@ protected ServiceCallSite(ResultCache cache) public abstract Type ImplementationType { get; } public abstract CallSiteKind Kind { get; } public ResultCache Cache { get; } + public object Value { get; set; } public bool CaptureDisposable => ImplementationType == null || diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngine.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngine.cs index d6d804c6484482..ca334392841727 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngine.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngine.cs @@ -19,7 +19,7 @@ internal abstract class ServiceProviderEngine : IServiceProviderEngine, IService protected ServiceProviderEngine(IEnumerable serviceDescriptors) { _createServiceAccessor = CreateServiceAccessor; - Root = new ServiceProviderEngineScope(this, isRoot: true); + Root = new ServiceProviderEngineScope(this); RuntimeResolver = new CallSiteRuntimeResolver(); CallSiteFactory = new CallSiteFactory(serviceDescriptors); CallSiteFactory.Add(typeof(IServiceProvider), new ServiceProviderCallSite()); diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs index ab31fabff48e4a..fefb711be07bfa 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs @@ -16,13 +16,13 @@ internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvid private bool _disposed; private readonly ScopeState _state; - public ServiceProviderEngineScope(ServiceProviderEngine engine, bool isRoot = false) + public ServiceProviderEngineScope(ServiceProviderEngine engine) { Engine = engine; - _state = new ScopeState(isRoot); + _state = new ScopeState(); } - internal IDictionary ResolvedServices => _state.ResolvedServices; + internal Dictionary ResolvedServices => _state.ResolvedServices; // This lock protects state on the scope, in particular, for the root scope, it protects // the list of disposable entries only, since ResolvedServices is a concurrent dictionary. diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/CallSiteTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/CallSiteTests.cs index c83e27790b637d..da75b6af9a4a62 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/CallSiteTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/CallSiteTests.cs @@ -89,9 +89,11 @@ public void BuiltExpressionWillReturnResolvedServiceWhenAppropriate( var compiledCallSite = CompileCallSite(callSite, provider); var compiledCollectionCallSite = CompileCallSite(collectionCallSite, provider); - var service1 = Invoke(callSite, provider); - var service2 = compiledCallSite(provider.Root); - var serviceEnumerator = ((IEnumerable)compiledCollectionCallSite(provider.Root)).GetEnumerator(); + using var scope = (ServiceProviderEngineScope)provider.CreateScope(); + + var service1 = Invoke(callSite, scope); + var service2 = compiledCallSite(scope); + var serviceEnumerator = ((IEnumerable)compiledCollectionCallSite(scope)).GetEnumerator(); Assert.NotNull(service1); Assert.True(compare(service1, service2)); @@ -114,10 +116,12 @@ public void BuiltExpressionCanResolveNestedScopedService() var callSite = provider.CallSiteFactory.GetCallSite(typeof(ServiceC), new CallSiteChain()); var compiledCallSite = CompileCallSite(callSite, provider); - var serviceC = (ServiceC)compiledCallSite(provider.Root); + using var scope = (ServiceProviderEngineScope)provider.CreateScope(); + + var serviceC = (ServiceC)compiledCallSite(scope); Assert.NotNull(serviceC.ServiceB.ServiceA); - Assert.Equal(serviceC, Invoke(callSite, provider)); + Assert.Equal(serviceC, Invoke(callSite, scope)); } [Theory] @@ -371,9 +375,9 @@ public void Dispose() } } - private static object Invoke(ServiceCallSite callSite, ServiceProviderEngine provider) + private static object Invoke(ServiceCallSite callSite, ServiceProviderEngineScope scope) { - return CallSiteRuntimeResolver.Resolve(callSite, provider.Root); + return CallSiteRuntimeResolver.Resolve(callSite, scope); } private static Func CompileCallSite(ServiceCallSite callSite, ServiceProviderEngine engine)