diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index 0eeea7e127dfd0..81d9f982c4e7ea 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -15,42 +15,6 @@ namespace System.Runtime.CompilerServices { - internal struct ExecutionAndSyncBlockStore - { - // Store current ExecutionContext and SynchronizationContext as "previousXxx". - // This allows us to restore them and undo any Context changes made in stateMachine.MoveNext - // so that they won't "leak" out of the first await. - public ExecutionContext? _previousExecutionCtx; - public SynchronizationContext? _previousSyncCtx; - public Thread _thread; - - public void Push() - { - _thread = Thread.CurrentThread; - // Here we get the execution context for synchronous restoring, - // not for flowing across suspension to potentially another thread. - // Therefore we do not need to worry about IsFlowSuppressed - _previousExecutionCtx = _thread._executionContext; - _previousSyncCtx = _thread._synchronizationContext; - } - - public void Pop() - { - // The common case is that these have not changed, so avoid the cost of a write barrier if not needed. - if (_previousSyncCtx != _thread._synchronizationContext) - { - // Restore changed SynchronizationContext back to previous - _thread._synchronizationContext = _previousSyncCtx; - } - - ExecutionContext? currentExecutionCtx = _thread._executionContext; - if (_previousExecutionCtx != currentExecutionCtx) - { - ExecutionContext.RestoreChangedContextToThread(_thread, _previousExecutionCtx, currentExecutionCtx); - } - } - } - [Flags] // Keep in sync with CORINFO_CONTINUATION_FLAGS internal enum ContinuationFlags @@ -206,12 +170,13 @@ public static partial class AsyncHelpers [Intrinsic] private static void TailAwait() => throw new UnreachableException(); - // Used during suspensions to hold the continuation chain and on what we are waiting. - // Methods like FinalizeTaskReturningThunk will unlink the state and wrap into a Task. - private struct RuntimeAsyncAwaitState + // This is state used by suspension/resumption machinery and stored in + // the two places that initiate runtime async chains: either a + // task-returning thunk, or DispatchContinuations. A pointer to this + // state is kept in the runtime async TLS. This storage method avoids + // costly write barriers on the hot path of suspension/resumption. + private ref struct RuntimeAsyncStackState { - public Continuation? SentinelContinuation; - // The following are the possible introducers of asynchrony into a chain of awaits. // In other words - when we build a chain of continuations it would be logicaly attached // to one of these notifiers. @@ -220,17 +185,82 @@ private struct RuntimeAsyncAwaitState public ValueTaskSourceNotifier? ValueTaskSourceNotifier; public Task? TaskNotifier; - public ExecutionContext? ExecutionContext; - public SynchronizationContext? SynchronizationContext; + // When we suspend in the leaf, the contexts are captured into these fields. + public ExecutionContext? LeafExecutionContext; + public SynchronizationContext? LeafSynchronizationContext; + + // When we enter the root of the async chain (either an async thunk + // or DispatchContinuations), the contexts are captured into these + // fields. + public ExecutionContext? RootExecutionContext; + public SynchronizationContext? RootSynchronizationContext; + + public unsafe RuntimeAsyncStackState* Next; + + public void Push(Thread thread) + { + RootExecutionContext = thread._executionContext; + RootSynchronizationContext = thread._synchronizationContext; + } + + public void Pop(Thread thread) + { + // The common case is that these have not changed, so avoid the cost of a write barrier if not needed. + if (RootSynchronizationContext != thread._synchronizationContext) + { + // Restore changed SynchronizationContext back to previous + thread._synchronizationContext = RootSynchronizationContext; + } + + ExecutionContext? currentExecutionCtx = thread._executionContext; + if (RootExecutionContext != currentExecutionCtx) + { + ExecutionContext.RestoreChangedContextToThread(thread, RootExecutionContext, currentExecutionCtx); + } + } + } + + // Used during suspensions to hold the continuation chain and on what we are waiting. + // Methods like FinalizeTaskReturningThunk will unlink the state and wrap into a Task. + private unsafe struct RuntimeAsyncAwaitState + { + public Continuation? SentinelContinuation; + + // We cache the thread here to avoid unnecessary repeated TLS lookups. + public Thread CurrentThread; + + public RuntimeAsyncStackState* StackState; public void CaptureContexts() { - Thread curThread = Thread.CurrentThreadAssumedInitialized; + // CaptureContext is called from leaf await helpers. We either just started a runtime async chain + // (from a thunk), or we came from DispatchContinuations (on resumption). + // Both cases have already initialized CurrentThread. + Thread curThread = CurrentThread; + Debug.Assert(curThread != null); + Debug.Assert(StackState != null); // Here we get the execution context for presenting to the notifier, // not for flowing across suspension to potentially another thread. // Therefore we do not need to worry about IsFlowSuppressed - ExecutionContext = curThread._executionContext; - SynchronizationContext = curThread._synchronizationContext; + StackState->LeafExecutionContext = curThread._executionContext; + StackState->LeafSynchronizationContext = curThread._synchronizationContext; + } + + // At the start of an async chain (task-returning thunk or DispatchContinuations) this function + // is called + public void Push(RuntimeAsyncStackState* stackState) + { + stackState->Next = StackState; + StackState = stackState; + stackState->Push(CurrentThread ??= Thread.CurrentThread); + } + + // This function is called at the end of an async chain + public void Pop() + { + Debug.Assert(CurrentThread != null); + StackState->Pop(CurrentThread); + StackState = StackState->Next; } } @@ -297,20 +327,18 @@ private static unsafe Continuation AllocContinuationClass(Continuation prevConti /// Task or a ValueTaskNotifier whose completion we are awaiting. [BypassReadyToRun] [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] - private static void TransparentAwait(object o) + private static unsafe void TransparentAwait(object o) { ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - Continuation? sentinelContinuation = state.SentinelContinuation; - if (sentinelContinuation == null) - state.SentinelContinuation = sentinelContinuation = new Continuation(); + Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); if (o is Task t) { - state.TaskNotifier = t; + state.StackState->TaskNotifier = t; } else { - state.ValueTaskSourceNotifier = (ValueTaskSourceNotifier)o; + state.StackState->ValueTaskSourceNotifier = (ValueTaskSourceNotifier)o; } state.CaptureContexts(); @@ -343,13 +371,19 @@ void ITaskCompletionAction.Invoke(Task completingTask) bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true; - private Action GetContinuationAction() => (Action)m_action!; + private Action GetContinuationAction() + { + object? action = m_action; + Debug.Assert(action is Action); + return Unsafe.As(action); + } private Continuation MoveContinuationState() { - Continuation continuation = (Continuation)m_stateObject!; + object? stateObject = m_stateObject; + Debug.Assert(stateObject is Continuation); m_stateObject = null; - return continuation; + return Unsafe.As(stateObject); } private void SetContinuationState(Continuation value) @@ -358,23 +392,23 @@ private void SetContinuationState(Continuation value) m_stateObject = value; } - internal bool HandleSuspended() + internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) { - ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + Thread currentThread = state.CurrentThread; + Debug.Assert(currentThread != null); - RestoreContextsOnSuspension(false, state.ExecutionContext, state.SynchronizationContext); - - ICriticalNotifyCompletion? critNotifier = state.CriticalNotifier; - INotifyCompletion? notifier = state.Notifier; - ValueTaskSourceNotifier? vtsNotifier = state.ValueTaskSourceNotifier; - Task? taskNotifier = state.TaskNotifier; + RuntimeAsyncStackState* stackState = state.StackState; + ExecutionContext? suspendingExecutionContext = stackState->LeafExecutionContext; + SynchronizationContext? suspendingSyncContext = stackState->LeafSynchronizationContext; + if (suspendingExecutionContext != currentThread._executionContext) + { + currentThread._executionContext = suspendingExecutionContext; + } - state.CriticalNotifier = null; - state.Notifier = null; - state.ValueTaskSourceNotifier = null; - state.TaskNotifier = null; - state.ExecutionContext = null; - state.SynchronizationContext = null; + if (suspendingSyncContext != currentThread._synchronizationContext) + { + currentThread._synchronizationContext = suspendingSyncContext; + } Continuation sentinelContinuation = state.SentinelContinuation!; Continuation headContinuation = sentinelContinuation.Next!; @@ -393,11 +427,11 @@ internal bool HandleSuspended() try { - if (critNotifier != null) + if (stackState->CriticalNotifier is { } critNotifier) { critNotifier.UnsafeOnCompleted(GetContinuationAction()); } - else if (taskNotifier != null) + else if (stackState->TaskNotifier is { } taskNotifier) { // Runtime async callable wrapper for task returning // method. This implements the context transparent @@ -407,7 +441,7 @@ internal bool HandleSuspended() ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true); } } - else if (vtsNotifier != null) + else if (stackState->ValueTaskSourceNotifier is { } valueTaskSourceNotifier) { // The awaiter must inform the ValueTaskSource on whether the continuation // wants to run on a context, although the source may decide to ignore the suggestion. @@ -442,12 +476,12 @@ internal bool HandleSuspended() // Clear continuation flags, so that continuation runs transparently nextUserContinuation.Flags &= ~continueFlags; - vtsNotifier.OnCompleted(s_runContinuationAction, this, configFlags); + valueTaskSourceNotifier.OnCompleted(s_runContinuationAction, this, configFlags); } else { - Debug.Assert(notifier != null); - notifier.OnCompleted(GetContinuationAction()); + Debug.Assert(stackState->Notifier != null); + stackState->Notifier!.OnCompleted(GetContinuationAction()); } return true; @@ -460,7 +494,7 @@ internal bool HandleSuspended() return false; } - internal void InstrumentedHandleSuspended(AsyncInstrumentation.Flags flags, Continuation? newContinuation = null) + internal void InstrumentedHandleSuspended(AsyncInstrumentation.Flags flags, ref RuntimeAsyncAwaitState state, Continuation? newContinuation = null) { if (AsyncInstrumentation.IsEnabled.AsyncDebugger(flags)) { @@ -468,7 +502,7 @@ internal void InstrumentedHandleSuspended(AsyncInstrumentation.Flags flags, Cont AsyncDebugger.HandleSuspended(nextContinuation, newContinuation); - if (!HandleSuspended()) + if (!HandleSuspended(ref state)) { AsyncDebugger.HandleSuspendedFailed(this, nextContinuation); } @@ -476,7 +510,7 @@ internal void InstrumentedHandleSuspended(AsyncInstrumentation.Flags flags, Cont return; } - HandleSuspended(); + HandleSuspended(ref state); } #pragma warning disable CA1822 // Mark members as static @@ -500,13 +534,17 @@ private unsafe void DispatchContinuations() } } - ExecutionAndSyncBlockStore contexts = default; - contexts.Push(); + RuntimeAsyncStackState stackState = default; + + ref RuntimeAsyncAwaitState awaitState = ref t_runtimeAsyncAwaitState; + awaitState.Push(&stackState); + + ref AsyncDispatcherInfo* refDispatcherInfo = ref AsyncDispatcherInfo.t_current; AsyncDispatcherInfo asyncDispatcherInfo; - asyncDispatcherInfo.Next = AsyncDispatcherInfo.t_current; + asyncDispatcherInfo.Next = refDispatcherInfo; asyncDispatcherInfo.NextContinuation = MoveContinuationState(); - AsyncDispatcherInfo.t_current = &asyncDispatcherInfo; + refDispatcherInfo = &asyncDispatcherInfo; while (true) { @@ -524,10 +562,10 @@ private unsafe void DispatchContinuations() if (newContinuation != null) { newContinuation.Next = nextContinuation; - HandleSuspended(); + HandleSuspended(ref awaitState); - contexts.Pop(); - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; return; } } @@ -542,9 +580,8 @@ private unsafe void DispatchContinuations() TrySetCanceled(oce.CancellationToken, oce) : TrySetException(ex); - contexts.Pop(); - - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; if (!successfullySet) { @@ -562,9 +599,8 @@ private unsafe void DispatchContinuations() { bool successfullySet = TrySetResult(m_result); - contexts.Pop(); - - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; if (!successfullySet) { @@ -576,8 +612,8 @@ private unsafe void DispatchContinuations() if (QueueContinuationFollowUpActionIfNecessary(asyncDispatcherInfo.NextContinuation)) { - contexts.Pop(); - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; return; } @@ -585,8 +621,8 @@ private unsafe void DispatchContinuations() { SetContinuationState(asyncDispatcherInfo.NextContinuation); - contexts.Pop(); - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; InstrumentedDispatchContinuations(AsyncInstrumentation.ActiveFlags); return; @@ -597,13 +633,17 @@ private unsafe void DispatchContinuations() [StackTraceHidden] private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags flags) { - ExecutionAndSyncBlockStore contexts = default; - contexts.Push(); + RuntimeAsyncStackState stackState = default; + + ref RuntimeAsyncAwaitState awaitState = ref t_runtimeAsyncAwaitState; + awaitState.Push(&stackState); + + ref AsyncDispatcherInfo* refDispatcherInfo = ref AsyncDispatcherInfo.t_current; AsyncDispatcherInfo asyncDispatcherInfo; - asyncDispatcherInfo.Next = AsyncDispatcherInfo.t_current; + asyncDispatcherInfo.Next = refDispatcherInfo; asyncDispatcherInfo.NextContinuation = MoveContinuationState(); - AsyncDispatcherInfo.t_current = &asyncDispatcherInfo; + refDispatcherInfo = &asyncDispatcherInfo; RuntimeAsyncInstrumentationHelpers.ResumeRuntimeAsyncContext(this, ref asyncDispatcherInfo, flags); @@ -625,10 +665,10 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags { newContinuation.Next = nextContinuation; RuntimeAsyncInstrumentationHelpers.SuspendRuntimeAsyncContext(flags, curContinuation, newContinuation); - InstrumentedHandleSuspended(flags, newContinuation); + InstrumentedHandleSuspended(flags, ref awaitState, newContinuation); - contexts.Pop(); - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; return; } @@ -647,9 +687,8 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags TrySetCanceled(oce.CancellationToken, oce) : TrySetException(ex); - contexts.Pop(); - - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; if (!successfullySet) { @@ -671,9 +710,8 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags bool successfullySet = TrySetResult(m_result); - contexts.Pop(); - - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; if (!successfullySet) { @@ -687,8 +725,8 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags { RuntimeAsyncInstrumentationHelpers.SuspendRuntimeAsyncContext(ref asyncDispatcherInfo, flags, curContinuation); - contexts.Pop(); - AsyncDispatcherInfo.t_current = asyncDispatcherInfo.Next; + awaitState.Pop(); + refDispatcherInfo = asyncDispatcherInfo.Next; return; } @@ -795,7 +833,7 @@ private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuatio }; } - private static void InstrumentedFinalizeRuntimeAsyncTask(RuntimeAsyncTask task, AsyncInstrumentation.Flags flags) + private static void InstrumentedFinalizeRuntimeAsyncTask(RuntimeAsyncTask task, ref RuntimeAsyncAwaitState state, AsyncInstrumentation.Flags flags) { if (AsyncInstrumentation.IsEnabled.CreateAsyncContext(flags)) { @@ -806,54 +844,54 @@ private static void InstrumentedFinalizeRuntimeAsyncTask(RuntimeAsyncTask } } - task.InstrumentedHandleSuspended(flags); + task.InstrumentedHandleSuspended(flags, ref state); return; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void FinalizeRuntimeAsyncTask(RuntimeAsyncTask task) + private static void FinalizeRuntimeAsyncTask(ref RuntimeAsyncAwaitState state, RuntimeAsyncTask task) { if (RuntimeAsyncInstrumentationHelpers.InstrumentCheckPoint) { AsyncInstrumentation.Flags flags = AsyncInstrumentation.SyncActiveFlags(); if (flags != AsyncInstrumentation.Flags.Disabled) { - InstrumentedFinalizeRuntimeAsyncTask(task, flags); + InstrumentedFinalizeRuntimeAsyncTask(task, ref state, flags); return; } } - task.HandleSuspended(); + task.HandleSuspended(ref state); } // Change return type to RuntimeAsyncTask -- no benefit since this is used for Task returning thunks only #pragma warning disable CA1859 // When a Task-returning thunk gets a continuation result // it calls here to make a Task that awaits on the current async state. - private static Task FinalizeTaskReturningThunk() + private static Task FinalizeTaskReturningThunk(ref RuntimeAsyncAwaitState state) { RuntimeAsyncTask result = new(); - FinalizeRuntimeAsyncTask(result!); + FinalizeRuntimeAsyncTask(ref state, result!); return result; } - private static Task FinalizeTaskReturningThunk() + private static Task FinalizeTaskReturningThunk(ref RuntimeAsyncAwaitState state) { RuntimeAsyncTask result = new(); - FinalizeRuntimeAsyncTask(result!); + FinalizeRuntimeAsyncTask(ref state, result!); return result; } - private static ValueTask FinalizeValueTaskReturningThunk() + private static ValueTask FinalizeValueTaskReturningThunk(ref RuntimeAsyncAwaitState state) { // We only come to these methods in the expensive case (already // suspended), so ValueTask optimization here is not relevant. - return new ValueTask(FinalizeTaskReturningThunk()); + return new ValueTask(FinalizeTaskReturningThunk(ref state)); } - private static ValueTask FinalizeValueTaskReturningThunk() + private static ValueTask FinalizeValueTaskReturningThunk(ref RuntimeAsyncAwaitState state) { - return new ValueTask(FinalizeTaskReturningThunk()); + return new ValueTask(FinalizeTaskReturningThunk(ref state)); } private static Task TaskFromException(Exception ex) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs index 2d9891ef57b874..d5f367c6498b0b 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs @@ -38,15 +38,22 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me ILLocalVariable returnTaskLocal = emitter.NewLocal(returnType); - TypeDesc executionAndSyncBlockStoreType = context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "ExecutionAndSyncBlockStore"u8); - ILLocalVariable executionAndSyncBlockStoreLocal = emitter.NewLocal(executionAndSyncBlockStoreType); + MetadataType asyncHelpersType = context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8); + TypeDesc stackStateType = asyncHelpersType.GetKnownNestedType("RuntimeAsyncStackState"u8); + ILLocalVariable stackStateLocal = emitter.NewLocal(stackStateType); + TypeDesc awaitStateType = asyncHelpersType.GetKnownNestedType("RuntimeAsyncAwaitState"u8); + ILLocalVariable refAwaitStateLocal = emitter.NewLocal(awaitStateType.MakeByRefType()); ILCodeLabel returnTaskLabel = emitter.NewCodeLabel(); ILCodeLabel suspendedLabel = emitter.NewCodeLabel(); ILCodeLabel finishedLabel = emitter.NewCodeLabel(); - codestream.EmitLdLoca(executionAndSyncBlockStoreLocal); - codestream.Emit(ILOpcode.call, emitter.NewToken(executionAndSyncBlockStoreType.GetKnownMethod("Push"u8, null))); + codestream.Emit(ILOpcode.ldsflda, emitter.NewToken(asyncHelpersType.GetKnownField("t_runtimeAsyncAwaitState"u8))); + codestream.EmitStLoc(refAwaitStateLocal); + + codestream.EmitLdLoc(refAwaitStateLocal); + codestream.EmitLdLoca(stackStateLocal); + codestream.Emit(ILOpcode.call, emitter.NewToken(awaitStateType.GetKnownMethod("Push"u8, null))); ILExceptionRegionBuilder tryFinallyRegion = emitter.NewFinallyRegion(); { @@ -90,9 +97,7 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me codestream.EmitStLoc(logicalResultLocal); } - MethodDesc asyncCallContinuationMd = context.SystemModule - .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) - .GetKnownMethod("AsyncCallContinuation"u8, null); + MethodDesc asyncCallContinuationMd = asyncHelpersType.GetKnownMethod("AsyncCallContinuation"u8, null); codestream.Emit(ILOpcode.call, emitter.NewToken(asyncCallContinuationMd)); @@ -161,8 +166,7 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me parameters: new[] { exceptionType } ); - fromExceptionMd = context.SystemModule - .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + fromExceptionMd = asyncHelpersType .GetKnownMethod(isValueTask ? "ValueTaskFromException"u8 : "TaskFromException"u8, fromExceptionSignature) .MakeInstantiatedMethod(new Instantiation(logicalReturnType)); } @@ -175,8 +179,7 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me parameters: new[] { exceptionType } ); - fromExceptionMd = context.SystemModule - .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + fromExceptionMd = asyncHelpersType .GetKnownMethod(isValueTask ? "ValueTaskFromException"u8 : "TaskFromException"u8, fromExceptionSignature); } @@ -195,11 +198,10 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me MethodSignatureFlags.Static, genericParameterCount: 1, returnType: ((MetadataType)returnType.GetTypeDefinition()).MakeInstantiatedType(context.GetSignatureVariable(0, true)), - parameters: Array.Empty() + parameters: [awaitStateType.MakeByRefType()] ); - finalizeTaskReturningThunkMd = context.SystemModule - .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + finalizeTaskReturningThunkMd = asyncHelpersType .GetKnownMethod(isValueTask ? "FinalizeValueTaskReturningThunk"u8 : "FinalizeTaskReturningThunk"u8, finalizeReturningThunkSignature) .MakeInstantiatedMethod(new Instantiation(logicalReturnType)); } @@ -209,14 +211,14 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me MethodSignatureFlags.Static, genericParameterCount: 0, returnType: returnType, - parameters: Array.Empty() + parameters: [awaitStateType.MakeByRefType()] ); - finalizeTaskReturningThunkMd = context.SystemModule - .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + finalizeTaskReturningThunkMd = asyncHelpersType .GetKnownMethod(isValueTask ? "FinalizeValueTaskReturningThunk"u8 : "FinalizeTaskReturningThunk"u8, finalizeReturningThunkSignature); } + codestream.EmitLdLoc(refAwaitStateLocal); codestream.Emit(ILOpcode.call, emitter.NewToken(finalizeTaskReturningThunkMd)); codestream.EmitStLoc(returnTaskLocal); codestream.Emit(ILOpcode.leave, returnTaskLabel); @@ -227,8 +229,8 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me { codestream.BeginHandler(tryFinallyRegion); - codestream.EmitLdLoca(executionAndSyncBlockStoreLocal); - codestream.Emit(ILOpcode.call, emitter.NewToken(executionAndSyncBlockStoreType.GetKnownMethod("Pop"u8, null))); + codestream.EmitLdLoc(refAwaitStateLocal); + codestream.Emit(ILOpcode.call, emitter.NewToken(awaitStateType.GetKnownMethod("Pop"u8, null))); codestream.Emit(ILOpcode.endfinally); codestream.EndHandler(tryFinallyRegion); } diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index 3fa631527cf29c..ef45a4a41c715d 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -93,8 +93,10 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& // Emits roughly the following code: // - // ExecutionAndSyncBlockStore store = default; - // store.Push(); + // RuntimeAsyncStackState stackState; + // ref RuntimeAsyncAwaitState awaitState = ref AsyncHelpers.t_runtimeAsyncAwaitState; + // awaitState.Push(&stackState); + // // try // { // try @@ -104,7 +106,7 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& // if (AsyncHelpers.AsyncCallContinuation() == null) // return Task.FromResult(result); // - // return FinalizeTaskReturningThunk(); + // return FinalizeTaskReturningThunk(ref awaitState); // } // catch (Exception ex) // { @@ -113,7 +115,7 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& // } // finally // { - // store.Pop(); + // awaitState.Pop(); // } ILCodeStream* pCode = pSL->NewCodeStream(ILStubLinker::kDispatch); @@ -132,15 +134,24 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& LocalDesc returnLocalDesc(thTaskRet); DWORD returnTaskLocal = pCode->NewLocal(returnLocalDesc); - LocalDesc executionAndSyncBlockStoreLocalDesc(CoreLibBinder::GetClass(CLASS__EXECUTIONANDSYNCBLOCKSTORE)); - DWORD executionAndSyncBlockStoreLocal = pCode->NewLocal(executionAndSyncBlockStoreLocalDesc); + + LocalDesc stackStateLocalDesc(TypeHandle(CoreLibBinder::GetClass(CLASS__RUNTIME_ASYNC_STACK_STATE))); + DWORD stackStateLocal = pCode->NewLocal(stackStateLocalDesc); + + LocalDesc refAwaitStateLocalDesc(TypeHandle(CoreLibBinder::GetClass(CLASS__RUNTIME_ASYNC_AWAIT_STATE))); + refAwaitStateLocalDesc.MakeByRef(); + DWORD refAwaitStateLocal = pCode->NewLocal(refAwaitStateLocalDesc); ILCodeLabel* returnTaskLabel = pCode->NewCodeLabel(); ILCodeLabel* suspendedLabel = pCode->NewCodeLabel(); ILCodeLabel* finishedLabel = pCode->NewCodeLabel(); - pCode->EmitLDLOCA(executionAndSyncBlockStoreLocal); - pCode->EmitCALL(pCode->GetToken(CoreLibBinder::GetMethod(METHOD__EXECUTIONANDSYNCBLOCKSTORE__PUSH)), 1, 0); + pCode->EmitLDSFLDA(pCode->GetToken(CoreLibBinder::GetField(FIELD__ASYNC_HELPERS__TLS_RUNTIME_ASYNC_AWAIT_STATE))); + pCode->EmitSTLOC(refAwaitStateLocal); + + pCode->EmitLDLOC(refAwaitStateLocal); + pCode->EmitLDLOCA(stackStateLocal); + pCode->EmitCALL(METHOD__RUNTIME_ASYNC_AWAIT_STATE__PUSH, 2, 0); { pCode->BeginTryBlock(); @@ -252,7 +263,8 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& md = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__FINALIZE_TASK_RETURNING_THUNK); finalizeTaskReturningThunkToken = pCode->GetToken(md); } - pCode->EmitCALL(finalizeTaskReturningThunkToken, 0, 1); + pCode->EmitLDLOC(refAwaitStateLocal); + pCode->EmitCALL(finalizeTaskReturningThunkToken, 1, 1); pCode->EmitSTLOC(returnTaskLocal); pCode->EmitLEAVE(returnTaskLabel); @@ -261,8 +273,8 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& // { pCode->BeginFinallyBlock(); - pCode->EmitLDLOCA(executionAndSyncBlockStoreLocal); - pCode->EmitCALL(pCode->GetToken(CoreLibBinder::GetMethod(METHOD__EXECUTIONANDSYNCBLOCKSTORE__POP)), 1, 0); + pCode->EmitLDLOC(refAwaitStateLocal); + pCode->EmitCALL(METHOD__RUNTIME_ASYNC_AWAIT_STATE__POP, 1, 0); pCode->EmitENDFINALLY(); pCode->EndFinallyBlock(); } diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index f75fdcce0c832d..674942ce365a61 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -683,10 +683,6 @@ DEFINE_CLASS(RESOURCE_MANAGER, Resources, ResourceManager) DEFINE_CLASS(RTFIELD, Reflection, RtFieldInfo) DEFINE_METHOD(RTFIELD, GET_FIELDESC, GetFieldDesc, IM_RetIntPtr) -DEFINE_CLASS(EXECUTIONANDSYNCBLOCKSTORE, CompilerServices, ExecutionAndSyncBlockStore) -DEFINE_METHOD(EXECUTIONANDSYNCBLOCKSTORE, PUSH, Push, NoSig) -DEFINE_METHOD(EXECUTIONANDSYNCBLOCKSTORE, POP, Pop, NoSig) - DEFINE_CLASS(RUNTIME_HELPERS, CompilerServices, RuntimeHelpers) DEFINE_METHOD(RUNTIME_HELPERS, IS_BITWISE_EQUATABLE, IsBitwiseEquatable, NoSig) DEFINE_METHOD(RUNTIME_HELPERS, GET_RAW_DATA, GetRawData, NoSig) @@ -711,10 +707,10 @@ DEFINE_METHOD(ASYNC_HELPERS, ALLOC_CONTINUATION, AllocContinuation, DEFINE_METHOD(ASYNC_HELPERS, ALLOC_CONTINUATION_METHOD, AllocContinuationMethod, NoSig) DEFINE_METHOD(ASYNC_HELPERS, ALLOC_CONTINUATION_CLASS, AllocContinuationClass, NoSig) -DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_TASK_RETURNING_THUNK, FinalizeTaskReturningThunk, SM_RetTask) -DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_TASK_RETURNING_THUNK_1, FinalizeTaskReturningThunk, GM_RetTaskOfT) -DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_VALUETASK_RETURNING_THUNK, FinalizeValueTaskReturningThunk, SM_RetValueTask) -DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_VALUETASK_RETURNING_THUNK_1, FinalizeValueTaskReturningThunk, GM_RetValueTaskOfT) +DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_TASK_RETURNING_THUNK, FinalizeTaskReturningThunk, SM_RefRuntimeAsyncAwaitState_RetTask) +DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_TASK_RETURNING_THUNK_1, FinalizeTaskReturningThunk, GM_RefRuntimeAsyncAwaitState_RetTaskOfT) +DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_VALUETASK_RETURNING_THUNK, FinalizeValueTaskReturningThunk, SM_RefRuntimeAsyncAwaitState_RetValueTask) +DEFINE_METHOD(ASYNC_HELPERS, FINALIZE_VALUETASK_RETURNING_THUNK_1, FinalizeValueTaskReturningThunk, GM_RefRuntimeAsyncAwaitState_RetValueTaskOfT) DEFINE_METHOD(ASYNC_HELPERS, TASK_FROM_EXCEPTION, TaskFromException, SM_Exception_RetTask) DEFINE_METHOD(ASYNC_HELPERS, TASK_FROM_EXCEPTION_1, TaskFromException, GM_Exception_RetTaskOfT) @@ -732,11 +728,18 @@ DEFINE_METHOD(ASYNC_HELPERS, RESTORE_CONTEXTS, RestoreContexts, No DEFINE_METHOD(ASYNC_HELPERS, RESTORE_CONTEXTS_ON_SUSPENSION, RestoreContextsOnSuspension, NoSig) DEFINE_METHOD(ASYNC_HELPERS, ASYNC_CALL_CONTINUATION, AsyncCallContinuation, NoSig) DEFINE_METHOD(ASYNC_HELPERS, TAIL_AWAIT, TailAwait, NoSig) +DEFINE_FIELD(ASYNC_HELPERS, TLS_RUNTIME_ASYNC_AWAIT_STATE, t_runtimeAsyncAwaitState) #ifdef FEATURE_INTERPRETER DEFINE_METHOD(ASYNC_HELPERS, RESUME_INTERPRETER_CONTINUATION, ResumeInterpreterContinuation, NoSig) #endif +DEFINE_CLASS(RUNTIME_ASYNC_AWAIT_STATE, CompilerServices, AsyncHelpers+RuntimeAsyncAwaitState) +DEFINE_METHOD(RUNTIME_ASYNC_AWAIT_STATE, PUSH, Push, NoSig) +DEFINE_METHOD(RUNTIME_ASYNC_AWAIT_STATE, POP, Pop, NoSig) + +DEFINE_CLASS(RUNTIME_ASYNC_STACK_STATE, CompilerServices, AsyncHelpers+RuntimeAsyncStackState) + DEFINE_CLASS_U(CompilerServices, Continuation, ContinuationObject) DEFINE_FIELD_U(Next, ContinuationObject, Next) DEFINE_FIELD_U(ResumeInfo, ContinuationObject, ResumeInfo) @@ -978,11 +981,11 @@ DEFINE_FIELD(EXECUTIONCONTEXT, DEFAULT_FLOW_SUPPRESSED, DefaultFlowSu DEFINE_CLASS(DIRECTONTHREADLOCALDATA, Threading, Thread+DirectOnThreadLocalData) DEFINE_CLASS(THREAD, Threading, Thread) -DEFINE_METHOD(THREAD, START_CALLBACK, StartCallback, SM_PtrThread_RetVoid) +DEFINE_METHOD(THREAD, START_CALLBACK, StartCallback, SM_PtrThread_RetVoid) DEFINE_METHOD(THREAD, POLLGC, PollGC, NoSig) DEFINE_METHOD(THREAD, ON_THREAD_EXITING, OnThreadExited, SM_PtrThread_PtrException_RetVoid) #ifdef FOR_ILLINK -DEFINE_METHOD(THREAD, CTOR, .ctor, IM_RetVoid) +DEFINE_METHOD(THREAD, CTOR, .ctor, IM_RetVoid) #endif // FOR_ILLINK #ifdef FEATURE_OBJCMARSHAL diff --git a/src/coreclr/vm/metasig.h b/src/coreclr/vm/metasig.h index 0d631bf34b4db0..d304d63ff9d908 100644 --- a/src/coreclr/vm/metasig.h +++ b/src/coreclr/vm/metasig.h @@ -547,8 +547,10 @@ DEFINE_METASIG(SM(PtrByte_RetStr, P(b), s)) DEFINE_METASIG(SM(Str_RetPtrByte, s, P(b))) DEFINE_METASIG(SM(PtrByte_RetVoid, P(b), v)) -DEFINE_METASIG_T(GM(RetTaskOfT, IMAGE_CEE_CS_CALLCONV_DEFAULT, 1, , GI(C(TASK_1), 1, M(0)))) -DEFINE_METASIG_T(GM(RetValueTaskOfT, IMAGE_CEE_CS_CALLCONV_DEFAULT, 1, , GI(g(VALUETASK_1), 1, M(0)))) +DEFINE_METASIG_T(SM(RefRuntimeAsyncAwaitState_RetTask, r(g(RUNTIME_ASYNC_AWAIT_STATE)), C(TASK))) +DEFINE_METASIG_T(SM(RefRuntimeAsyncAwaitState_RetValueTask, r(g(RUNTIME_ASYNC_AWAIT_STATE)), g(VALUETASK))) +DEFINE_METASIG_T(GM(RefRuntimeAsyncAwaitState_RetTaskOfT, IMAGE_CEE_CS_CALLCONV_DEFAULT, 1, r(g(RUNTIME_ASYNC_AWAIT_STATE)), GI(C(TASK_1), 1, M(0)))) +DEFINE_METASIG_T(GM(RefRuntimeAsyncAwaitState_RetValueTaskOfT, IMAGE_CEE_CS_CALLCONV_DEFAULT, 1, r(g(RUNTIME_ASYNC_AWAIT_STATE)), GI(g(VALUETASK_1), 1, M(0)))) // Undefine macros in case we include the file again in the compilation unit diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs index 0d5ff2b7ca8143..510b863d4dda8e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs @@ -25,14 +25,11 @@ public static partial class AsyncHelpers [BypassReadyToRun] [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] [StackTraceHidden] - public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion + public static unsafe void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion { ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - Continuation? sentinelContinuation = state.SentinelContinuation; - if (sentinelContinuation == null) - state.SentinelContinuation = sentinelContinuation = new Continuation(); - - state.Notifier = awaiter; + Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); + state.StackState->Notifier = awaiter; state.CaptureContexts(); AsyncSuspend(sentinelContinuation); } @@ -48,14 +45,11 @@ public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INo [BypassReadyToRun] [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] [StackTraceHidden] - public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion + public static unsafe void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion { ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - Continuation? sentinelContinuation = state.SentinelContinuation; - if (sentinelContinuation == null) - state.SentinelContinuation = sentinelContinuation = new Continuation(); - - state.CriticalNotifier = awaiter; + Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); + state.StackState->CriticalNotifier = awaiter; state.CaptureContexts(); AsyncSuspend(sentinelContinuation); }