diff --git a/src/coreclr/jit/async.cpp b/src/coreclr/jit/async.cpp index 31384f45a64e12..dc5e45ffa8d43e 100644 --- a/src/coreclr/jit/async.cpp +++ b/src/coreclr/jit/async.cpp @@ -544,39 +544,35 @@ PhaseStatus AsyncTransformation::Run() PhaseStatus result = PhaseStatus::MODIFIED_NOTHING; ArrayStack blocksWithNormalAwaits(m_compiler->getAllocator(CMK_Async)); ArrayStack blocksWithTailAwaits(m_compiler->getAllocator(CMK_Async)); - int numNormalAwaits = 0; - int numTailAwaits = 0; - FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits, &numNormalAwaits, &numTailAwaits); + AggregatedAwaitInfo awaits = FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits); - if (numNormalAwaits + numTailAwaits > 1) + if (awaits.NumNormalAwaits + awaits.NumTailAwaits > 1) { CreateSharedReturnBB(); } // Transform all tail awaits first. They will not require running all of // our analyses. - if (numTailAwaits > 0) + if (awaits.NumTailAwaits > 0) { - JITDUMP("Found %d tail awaits in %d blocks\n", numTailAwaits, blocksWithTailAwaits.Height()); + JITDUMP("Found %u tail awaits in %d blocks\n", awaits.NumTailAwaits, blocksWithTailAwaits.Height()); TransformTailAwaits(blocksWithTailAwaits); m_compiler->fgInvalidateDfsTree(); - if (numNormalAwaits > 0) + if (awaits.NumNormalAwaits > 0) { // This may have changed blocks, so refind the normal awaits. blocksWithNormalAwaits.Reset(); blocksWithTailAwaits.Reset(); - numNormalAwaits = 0; - numTailAwaits = 0; - FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits, &numNormalAwaits, &numTailAwaits); + awaits = FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits); } result = PhaseStatus::MODIFIED_EVERYTHING; } - JITDUMP("Found %d awaits in %d blocks\n", numNormalAwaits, blocksWithNormalAwaits.Height()); + JITDUMP("Found %u awaits in %d blocks\n", awaits.NumNormalAwaits, blocksWithNormalAwaits.Height()); - if (numNormalAwaits <= 0) + if (awaits.NumNormalAwaits <= 0) { return result; } @@ -733,34 +729,40 @@ PhaseStatus AsyncTransformation::Run() // Parameters: // blocksWithNormalAwaits - [out] Blocks with normal awaits are pushed onto this stack // blocksWithTailAwaits - [out] Blocks with tail awaits are pushed onto this stack -// numNormalAwaits - [out] Number of normal awaits found -// numTailAwaits - [out] Number of tail awaits found // -void AsyncTransformation::FindAwaits(ArrayStack& blocksWithNormalAwaits, - ArrayStack& blocksWithTailAwaits, - int* numNormalAwaits, - int* numTailAwaits) +// Returns: +// Information about awaits in the function. +// +AggregatedAwaitInfo AsyncTransformation::FindAwaits(ArrayStack& blocksWithNormalAwaits, + ArrayStack& blocksWithTailAwaits) { + AggregatedAwaitInfo awaits; for (BasicBlock* block : m_compiler->Blocks()) { bool hasNormalAwait = false; bool hasTailAwait = false; for (GenTree* tree : LIR::AsRange(block)) { - if (!tree->IsCall() || !tree->AsCall()->IsAsync() || tree->AsCall()->IsTailCall()) + if (!tree->IsCall()) + { + continue; + } + + GenTreeCall* call = tree->AsCall(); + if (!call->IsAsync() || call->IsTailCall()) { continue; } - if (tree->AsCall()->GetAsyncInfo().IsTailAwait) + if (call->GetAsyncInfo().IsTailAwait) { hasTailAwait = true; - (*numTailAwaits)++; + awaits.NumTailAwaits++; } else { hasNormalAwait = true; - (*numNormalAwaits)++; + awaits.NumNormalAwaits++; } } @@ -774,6 +776,8 @@ void AsyncTransformation::FindAwaits(ArrayStack& blocksWithNormalAw blocksWithTailAwaits.Push(block); } } + + return awaits; } //------------------------------------------------------------------------ @@ -1235,8 +1239,12 @@ void AsyncTransformation::BuildContinuation(BasicBlock* block, JITDUMP(" Continuation will have keep alive object\n"); } - layoutBuilder->SetNeedsExecutionContext(); - JITDUMP(" Call has async-only save and restore of ExecutionContext; continuation will have ExecutionContext\n"); + if (call->GetAsyncInfo().NeedsToSaveAndRestoreExecutionContext()) + { + layoutBuilder->SetNeedsExecutionContext(); + JITDUMP( + " Call has async-only save and restore of ExecutionContext; continuation will have ExecutionContext\n"); + } } #ifdef DEBUG @@ -1645,22 +1653,28 @@ CallDefinitionInfo AsyncTransformation::CanonicalizeCallDefinition(BasicBlock* // BasicBlock* AsyncTransformation::CreateSuspensionBlock(BasicBlock* block, unsigned stateNum) { + BasicBlock* suspendBB; if (m_lastSuspensionBB == nullptr) { - m_lastSuspensionBB = m_compiler->fgLastBBInMainFunction(); + if (m_sharedReturnBB != nullptr) + { + suspendBB = m_compiler->fgNewBBbefore(BBJ_RETURN, m_sharedReturnBB, false); + } + else + { + m_lastSuspensionBB = m_compiler->fgLastBBInMainFunction(); + suspendBB = m_compiler->fgNewBBafter(BBJ_RETURN, m_lastSuspensionBB, false); + } + } + else + { + suspendBB = m_compiler->fgNewBBafter(BBJ_RETURN, m_lastSuspensionBB, false); } - BasicBlock* suspendBB = m_compiler->fgNewBBafter(BBJ_RETURN, m_lastSuspensionBB, false); suspendBB->clearTryIndex(); suspendBB->clearHndIndex(); suspendBB->inheritWeightPercentage(block, 0); m_lastSuspensionBB = suspendBB; - - if (m_sharedReturnBB != nullptr) - { - suspendBB->SetKindAndTargetEdge(BBJ_ALWAYS, m_compiler->fgAddRefPred(m_sharedReturnBB, suspendBB)); - } - JITDUMP(" Creating suspension " FMT_BB " for state %u\n", suspendBB->bbNum, stateNum); return suspendBB; @@ -1824,14 +1838,7 @@ void AsyncTransformation::CreateSuspension(BasicBlock* call FillInDataOnSuspension(call, layout, subLayout, suspendBB, mutatedSinceResumption, tailSaveSet); - FinishContextHandlingOnSuspension(callBlock, call, suspendBB, layout, subLayout); - - if (suspendBB->KindIs(BBJ_RETURN)) - { - newContinuation = m_compiler->gtNewLclvNode(newContinuationVar, TYP_REF); - GenTree* ret = m_compiler->gtNewOperNode(GT_RETURN_SUSPEND, TYP_VOID, newContinuation); - LIR::AsRange(suspendBB).InsertAtEnd(newContinuation, ret); - } + FinishContextHandlingAndSuspension(callBlock, call, suspendBB, layout, subLayout); } //------------------------------------------------------------------------ @@ -2005,13 +2012,56 @@ SaveSet AsyncTransformation::GetLocalSaveSet(const LclVarDsc* dsc, VARSET_VALARG } //------------------------------------------------------------------------ -// AsyncTransformation::FinishContextHandlingOnSuspension: -// Generate code to finish handling of contexts on suspension: +// AsyncTransformation::GetSuspensionContextHelper: +// Figure out what context handling helper can be used during suspension. +// +// Parameters: +// call - The async call +// +// Returns: +// Kind of helper that can be used, or None if no helper can be used. +// +// Remarks: +// - No helper exists for the case where there are no contexts to restore +// (when CORINFO_ASYNC_SAVE_CONTEXTS was not given by the EE), or when no +// execution context needs to be saved/restored. The former happens only in +// thunks where there is only one async call anyway, while the latter never +// currently happens. +// - We have two different helpers, depending on whether a continuation context needs to be captured. +// + For task awaits with ConfigureAwait(false), or for custom awaits, no continuation context is needed +// + For normal task awaits a continuation context is needed +// +SuspensionContextHelper AsyncTransformation::GetSuspensionContextHelper(GenTreeCall* call) +{ + CallArg* execContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext); + CallArg* syncContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext); + assert((execContextArg != nullptr) == (syncContextArg != nullptr)); + + // In most cases we can use a helper. It is not the case when the call has + // no contexts to restore, which is the case for task-returning thunks or + // more specifically when the EE told us !CORINFO_ASYNC_SAVE_CONTEXTS. + if ((execContextArg == nullptr) || !call->GetAsyncInfo().NeedsToSaveAndRestoreExecutionContext()) + { + return SuspensionContextHelper::None; + } + + if (call->GetAsyncInfo().ContinuationContextHandling == ContinuationContextHandling::ContinueOnCapturedContext) + { + return SuspensionContextHelper::WithContinuationContext; + } + + return SuspensionContextHelper::WithoutContinuationContext; +} + +//------------------------------------------------------------------------ +// AsyncTransformation::FinishContextHandlingAndSuspension: +// Generate code to finish handling of contexts on suspension, and finish the suspension: // - Capture SynchronizationContext or TaskScheduler into the continuation // if needed when later resuming // - Capture ExecutionContext into the continuation // - Restore current Thread._synchronizationContext and // Thread._executionContext from the state before the async call +// - Return continuation back to caller. // // Parameters: // callBlock - The block containing the async call @@ -2020,24 +2070,19 @@ SaveSet AsyncTransformation::GetLocalSaveSet(const LclVarDsc* dsc, VARSET_VALARG // layout - Information about the continuation layout. // subLayout - Per-call layout builder indicating which fields are needed. // -void AsyncTransformation::FinishContextHandlingOnSuspension(BasicBlock* callBlock, - GenTreeCall* call, - BasicBlock* suspendBB, - const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout) +void AsyncTransformation::FinishContextHandlingAndSuspension(BasicBlock* callBlock, + GenTreeCall* call, + BasicBlock* suspendBB, + const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout) { - CallArg* execContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext); - CallArg* syncContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext); - assert((execContextArg != nullptr) == (syncContextArg != nullptr)); + SuspensionContextHelper helper = GetSuspensionContextHelper(call); - // In most cases we can use a helper. It is not the case when the call has - // no contexts to restore, which is the case for task-returning thunks or - // more specifically when the EE told us !CORINFO_ASYNC_SAVE_CONTEXTS. - if (execContextArg != nullptr && subLayout.NeedsExecutionContext()) + if (helper != SuspensionContextHelper::None) { JITDUMP(" Call [%06u] has async context and captured execution context; using finish-suspension helper\n", Compiler::dspTreeID(call)); - FinishContextHandlingOnSuspensionWithHelper(callBlock, call, suspendBB, layout, subLayout); + FinishContextHandlingAndSuspensionWithHelper(callBlock, call, suspendBB, layout, subLayout, helper); return; } @@ -2108,10 +2153,23 @@ void AsyncTransformation::FinishContextHandlingOnSuspension(BasicBlock* } RestoreContexts(callBlock, call, suspendBB); + + assert(suspendBB->KindIs(BBJ_RETURN)); + + if (m_sharedReturnBB != nullptr) + { + suspendBB->SetKindAndTargetEdge(BBJ_ALWAYS, m_compiler->fgAddRefPred(m_sharedReturnBB, suspendBB)); + } + else + { + GenTree* newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); + GenTree* ret = m_compiler->gtNewOperNode(GT_RETURN_SUSPEND, TYP_VOID, newContinuation); + LIR::AsRange(suspendBB).InsertAtEnd(newContinuation, ret); + } } //------------------------------------------------------------------------ -// AsyncTransformation::FinishContextHandlingOnSuspensionWithHelper: +// AsyncTransformation::FinishContextHandlingAndSuspensionWithHelper: // Generate code to finish handling of contexts on suspension by calling into a helper. // // Parameters: @@ -2126,119 +2184,29 @@ void AsyncTransformation::FinishContextHandlingOnSuspension(BasicBlock* // context restores. We do that with a single helper call that does // everything, for both size and to avoid multiple loads of the Thread TLS. // -void AsyncTransformation::FinishContextHandlingOnSuspensionWithHelper(BasicBlock* callBlock, - GenTreeCall* call, - BasicBlock* suspendBB, - const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout) +void AsyncTransformation::FinishContextHandlingAndSuspensionWithHelper(BasicBlock* callBlock, + GenTreeCall* call, + BasicBlock* suspendBB, + const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout, + SuspensionContextHelper helper) { - CORINFO_METHOD_HANDLE helper = subLayout.NeedsContinuationContext() - ? m_asyncInfo->finishSuspensionWithContinuationContextMethHnd - : m_asyncInfo->finishSuspensionNoContinuationContextMethHnd; + assert(helper != SuspensionContextHelper::None); + assert((helper != SuspensionContextHelper::WithContinuationContext) || subLayout.NeedsContinuationContext()); - // Insert call - // finishSuspension[With|No]ContinuationContext( - // ref newContinuation.ContinuationContext, // optional - // ref newContinuation.Flags, // optional - // ref newContinuation.ExecutionContext, - // resumed, - // execContext, - // syncContext) - // + BasicBlock* sharedFinish = (helper == SuspensionContextHelper::WithContinuationContext) + ? m_sharedFinishContextHandlingWithContinuationContextBB + : m_sharedFinishContextHandlingWithoutContinuationContextBB; CallArg* execContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext); CallArg* syncContextArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext); assert((execContextArg != nullptr) && (syncContextArg != nullptr)); - GenTree* contContextAddrPlaceholder = nullptr; - GenTree* flagsPlaceholder = nullptr; - GenTree* execContextAddrPlaceholder = m_compiler->gtNewZeroConNode(TYP_BYREF); - GenTree* resumedPlaceholder = m_compiler->gtNewIconNode(0); - GenTree* execContextPlaceholder = m_compiler->gtNewNull(); - GenTree* syncContextPlaceholder = m_compiler->gtNewNull(); - - GenTreeCall* finishCall = m_compiler->gtNewCallNode(CT_USER_FUNC, helper, TYP_VOID); - SetCallEntrypointForR2R(finishCall, m_compiler, helper); - - finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(syncContextPlaceholder)); - finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(execContextPlaceholder)); - finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(resumedPlaceholder)); - finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(execContextAddrPlaceholder)); - - if (subLayout.NeedsContinuationContext()) - { - contContextAddrPlaceholder = m_compiler->gtNewZeroConNode(TYP_BYREF); - flagsPlaceholder = m_compiler->gtNewZeroConNode(TYP_BYREF); - finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(flagsPlaceholder)); - finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(contContextAddrPlaceholder)); - } - - m_compiler->compCurBB = suspendBB; - m_compiler->fgMorphTree(finishCall); - - LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_compiler, finishCall)); - - if (subLayout.NeedsContinuationContext()) - { - // Replace contContextAddrPlaceholder with actual address of the continuation context - LIR::Use use; - bool gotUse = LIR::AsRange(suspendBB).TryGetUse(contContextAddrPlaceholder, &use); - assert(gotUse); - - GenTree* newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); - unsigned contContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ContinuationContextOffset; - GenTree* contContextAddrOffset = - m_compiler->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, - m_compiler->gtNewIconNode((ssize_t)contContextOffset, TYP_I_IMPL)); - - LIR::AsRange(suspendBB).InsertBefore(contContextAddrPlaceholder, - LIR::SeqTree(m_compiler, contContextAddrOffset)); - use.ReplaceWith(contContextAddrOffset); - LIR::AsRange(suspendBB).Remove(contContextAddrPlaceholder); - - // Replace flagsPlaceholder with actual address of the flags - gotUse = LIR::AsRange(suspendBB).TryGetUse(flagsPlaceholder, &use); - assert(gotUse); - - newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); - unsigned flagsOffset = m_compiler->info.compCompHnd->getFieldOffset(m_asyncInfo->continuationFlagsFldHnd); - GenTree* flagsOffsetNode = - m_compiler->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, - m_compiler->gtNewIconNode((ssize_t)flagsOffset, TYP_I_IMPL)); - - LIR::AsRange(suspendBB).InsertBefore(flagsPlaceholder, LIR::SeqTree(m_compiler, flagsOffsetNode)); - use.ReplaceWith(flagsOffsetNode); - LIR::AsRange(suspendBB).Remove(flagsPlaceholder); - } - - // Replace execContextAddrPlaceholder with actual address of the execution context - LIR::Use use; - bool gotUse = LIR::AsRange(suspendBB).TryGetUse(execContextAddrPlaceholder, &use); - assert(gotUse); - - GenTree* newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); - unsigned execContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ExecutionContextOffset; - GenTree* execContextAddrOffset = - m_compiler->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, - m_compiler->gtNewIconNode((ssize_t)execContextOffset, TYP_I_IMPL)); - - LIR::AsRange(suspendBB).InsertBefore(execContextAddrPlaceholder, LIR::SeqTree(m_compiler, execContextAddrOffset)); - use.ReplaceWith(execContextAddrOffset); - LIR::AsRange(suspendBB).Remove(execContextAddrPlaceholder); - - // Replace resumedPlaceholder with actual "continuationParameter != null" arg - gotUse = LIR::AsRange(suspendBB).TryGetUse(resumedPlaceholder, &use); - assert(gotUse); - - GenTree* continuation = m_compiler->gtNewLclvNode(m_compiler->lvaAsyncContinuationArg, TYP_REF); - GenTree* null = m_compiler->gtNewNull(); - GenTree* resumed = m_compiler->gtNewOperNode(GT_NE, TYP_INT, continuation, null); - - LIR::AsRange(suspendBB).InsertBefore(resumedPlaceholder, LIR::SeqTree(m_compiler, resumed)); - use.ReplaceWith(resumed); - LIR::AsRange(suspendBB).Remove(resumedPlaceholder); - - // Replace execContextPlaceholder with actual value + // Get the contexts from the call node: + // 1. For shared finish, store it directly to the shared locals in the same block + // 2. For non-shared finish, just make sure it is a GT_LCL_VAR since we need to create + // a use in a different block. + // Also remove the nodes from the original block and the call args. GenTree* execContext = execContextArg->GetNode(); if (!execContext->OperIs(GT_LCL_VAR)) { @@ -2247,18 +2215,9 @@ void AsyncTransformation::FinishContextHandlingOnSuspensionWithHelper(BasicBlock use.ReplaceWithLclVar(m_compiler); execContext = use.Def(); } - - gotUse = LIR::AsRange(suspendBB).TryGetUse(execContextPlaceholder, &use); - assert(gotUse); - LIR::AsRange(callBlock).Remove(execContext); - LIR::AsRange(suspendBB).InsertBefore(execContextPlaceholder, execContext); - use.ReplaceWith(execContext); - LIR::AsRange(suspendBB).Remove(execContextPlaceholder); - call->gtArgs.RemoveUnsafe(execContextArg); - // Replace syncContextPlaceholder with actual value GenTree* syncContext = syncContextArg->GetNode(); if (!syncContext->OperIs(GT_LCL_VAR)) { @@ -2267,19 +2226,46 @@ void AsyncTransformation::FinishContextHandlingOnSuspensionWithHelper(BasicBlock use.ReplaceWithLclVar(m_compiler); syncContext = use.Def(); } + LIR::AsRange(callBlock).Remove(syncContext); + call->gtArgs.RemoveUnsafe(syncContextArg); - gotUse = LIR::AsRange(suspendBB).TryGetUse(syncContextPlaceholder, &use); - assert(gotUse); + if (sharedFinish != nullptr) + { + // Store the contexts to the shared locals that the shared finish block will take them from. + if (m_sharedFinishContextHandlingExecContextVar != BAD_VAR_NUM) + { + GenTree* storeExecContext = + m_compiler->gtNewStoreLclVarNode(m_sharedFinishContextHandlingExecContextVar, execContext); + LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_compiler, storeExecContext)); + } - LIR::AsRange(callBlock).Remove(syncContext); - LIR::AsRange(suspendBB).InsertBefore(syncContextPlaceholder, syncContext); - use.ReplaceWith(syncContext); - LIR::AsRange(suspendBB).Remove(syncContextPlaceholder); + if (m_sharedFinishContextHandlingSyncContextVar != BAD_VAR_NUM) + { + GenTree* storeSyncContext = + m_compiler->gtNewStoreLclVarNode(m_sharedFinishContextHandlingSyncContextVar, syncContext); + LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_compiler, storeSyncContext)); + } - call->gtArgs.RemoveUnsafe(syncContextArg); + // Then just finish by jumping. + suspendBB->SetKindAndTargetEdge(BBJ_ALWAYS, m_compiler->fgAddRefPred(sharedFinish, suspendBB)); + } + else + { + // Otherwise insert a new call + InsertFinishContextHandlingCall(suspendBB, layout, helper, execContext, syncContext); - JITDUMP(" Created FinishSuspension call on suspension:\n"); - DISPTREERANGE(LIR::AsRange(suspendBB), finishCall); + // And return either via a new GT_RETURN_SUSPEND or via the shared return BB. + if (m_sharedReturnBB != nullptr) + { + suspendBB->SetKindAndTargetEdge(BBJ_ALWAYS, m_compiler->fgAddRefPred(m_sharedReturnBB, suspendBB)); + } + else + { + GenTree* newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); + GenTree* ret = m_compiler->gtNewOperNode(GT_RETURN_SUSPEND, TYP_VOID, newContinuation); + LIR::AsRange(suspendBB).InsertAtEnd(newContinuation, ret); + } + } } //------------------------------------------------------------------------ @@ -2998,6 +2984,219 @@ void AsyncTransformation::CreateSharedReturnBB() DISPRANGE(LIR::AsRange(m_sharedReturnBB)); } +//------------------------------------------------------------------------ +// AsyncTransformation::CreateSharedFinishContextHandlingBB: +// Create a shared BB that finishes all necessary context handling and +// suspends the method. +// +// Parameters: +// helper - The type of helper to call +// layout - The continuation layout +// execContextMayVary - If true, callers may use different execution +// contexts, and thus we need a local to allow it to vary. +// syncContextMayVary - If true, callers may use different synchronization +// contexts, and thus we need a local to allow it to vary. +// +// Returns: +// Basic block that handles the shared finish logic. +// +BasicBlock* AsyncTransformation::CreateSharedFinishContextHandlingBB(SuspensionContextHelper helper, + const ContinuationLayout& layout, + bool execContextMayVary, + bool syncContextMayVary) +{ + assert(m_sharedReturnBB != nullptr); + BasicBlock* block = m_compiler->fgNewBBbefore(BBJ_ALWAYS, m_sharedReturnBB, false); + block->SetKindAndTargetEdge(BBJ_ALWAYS, m_compiler->fgAddRefPred(m_sharedReturnBB, block)); + block->bbSetRunRarely(); + block->clearTryIndex(); + block->clearHndIndex(); + + if (m_compiler->fgIsUsingProfileWeights()) + { + // All suspension BBs are cold, so we do not need to propagate any + // weights, but we do need to propagate the flag. + block->SetFlags(BBF_PROF_WEIGHT); + } + + unsigned execContextLclNum; + if (execContextMayVary) + { + if (m_sharedFinishContextHandlingExecContextVar == BAD_VAR_NUM) + { + m_sharedFinishContextHandlingExecContextVar = + m_compiler->lvaGrabTemp(false DEBUGARG("exec context for shared finish context handling")); + m_compiler->lvaGetDesc(m_sharedFinishContextHandlingExecContextVar)->lvType = TYP_REF; + } + + execContextLclNum = m_sharedFinishContextHandlingExecContextVar; + } + else + { + execContextLclNum = m_compiler->lvaAsyncExecutionContextVar; + } + + unsigned syncContextLclNum; + if (syncContextMayVary) + { + if (m_sharedFinishContextHandlingSyncContextVar == BAD_VAR_NUM) + { + m_sharedFinishContextHandlingSyncContextVar = + m_compiler->lvaGrabTemp(false DEBUGARG("sync context for shared finish context handling")); + m_compiler->lvaGetDesc(m_sharedFinishContextHandlingSyncContextVar)->lvType = TYP_REF; + } + + syncContextLclNum = m_sharedFinishContextHandlingSyncContextVar; + } + else + { + syncContextLclNum = m_compiler->lvaAsyncSynchronizationContextVar; + } + + InsertFinishContextHandlingCall(block, layout, helper, m_compiler->gtNewLclvNode(execContextLclNum, TYP_REF), + m_compiler->gtNewLclvNode(syncContextLclNum, TYP_REF)); + + return block; +} + +//------------------------------------------------------------------------ +// AsyncTransformation::InsertFinishContextHandlingCall: +// Insert a call to the specified context handling helper. +// +// Parameters: +// block - Block that should contain the call (inserted at the end) +// layout - The continuation layout +// helper - The type of helper +// execContext - The execution context tree to pass to the helper +// syncContext - The synchronization context tree to pass to the helper +// +void AsyncTransformation::InsertFinishContextHandlingCall(BasicBlock* block, + const ContinuationLayout& layout, + SuspensionContextHelper helper, + GenTree* execContext, + GenTree* syncContext) +{ + CORINFO_METHOD_HANDLE helperMethod = (helper == SuspensionContextHelper::WithContinuationContext) + ? m_asyncInfo->finishSuspensionWithContinuationContextMethHnd + : m_asyncInfo->finishSuspensionNoContinuationContextMethHnd; + + // Insert call + // finishSuspension[With|No]ContinuationContext( + // ref newContinuation.ContinuationContext, // optional + // ref newContinuation.Flags, // optional + // ref newContinuation.ExecutionContext, + // resumed, + // execContext, + // syncContext) + // + + GenTree* contContextAddrPlaceholder = nullptr; + GenTree* flagsPlaceholder = nullptr; + GenTree* execContextAddrPlaceholder = m_compiler->gtNewZeroConNode(TYP_BYREF); + GenTree* resumedPlaceholder = m_compiler->gtNewIconNode(0); + GenTree* execContextPlaceholder = m_compiler->gtNewNull(); + GenTree* syncContextPlaceholder = m_compiler->gtNewNull(); + + GenTreeCall* finishCall = m_compiler->gtNewCallNode(CT_USER_FUNC, helperMethod, TYP_VOID); + SetCallEntrypointForR2R(finishCall, m_compiler, helperMethod); + + finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(syncContextPlaceholder)); + finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(execContextPlaceholder)); + finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(resumedPlaceholder)); + finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(execContextAddrPlaceholder)); + + if (helper == SuspensionContextHelper::WithContinuationContext) + { + contContextAddrPlaceholder = m_compiler->gtNewZeroConNode(TYP_BYREF); + flagsPlaceholder = m_compiler->gtNewZeroConNode(TYP_BYREF); + finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(flagsPlaceholder)); + finishCall->gtArgs.PushFront(m_compiler, NewCallArg::Primitive(contContextAddrPlaceholder)); + } + + m_compiler->compCurBB = block; + m_compiler->fgMorphTree(finishCall); + + LIR::AsRange(block).InsertAtEnd(LIR::SeqTree(m_compiler, finishCall)); + + if (helper == SuspensionContextHelper::WithContinuationContext) + { + // Replace contContextAddrPlaceholder with actual address of the continuation context + LIR::Use use; + bool gotUse = LIR::AsRange(block).TryGetUse(contContextAddrPlaceholder, &use); + assert(gotUse); + + GenTree* newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); + unsigned contContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ContinuationContextOffset; + GenTree* contContextAddrOffset = + m_compiler->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, + m_compiler->gtNewIconNode((ssize_t)contContextOffset, TYP_I_IMPL)); + + LIR::AsRange(block).InsertBefore(contContextAddrPlaceholder, LIR::SeqTree(m_compiler, contContextAddrOffset)); + use.ReplaceWith(contContextAddrOffset); + LIR::AsRange(block).Remove(contContextAddrPlaceholder); + + // Replace flagsPlaceholder with actual address of the flags + gotUse = LIR::AsRange(block).TryGetUse(flagsPlaceholder, &use); + assert(gotUse); + + newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); + unsigned flagsOffset = m_compiler->info.compCompHnd->getFieldOffset(m_asyncInfo->continuationFlagsFldHnd); + GenTree* flagsOffsetNode = + m_compiler->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, + m_compiler->gtNewIconNode((ssize_t)flagsOffset, TYP_I_IMPL)); + + LIR::AsRange(block).InsertBefore(flagsPlaceholder, LIR::SeqTree(m_compiler, flagsOffsetNode)); + use.ReplaceWith(flagsOffsetNode); + LIR::AsRange(block).Remove(flagsPlaceholder); + } + + // Replace execContextAddrPlaceholder with actual address of the execution context + LIR::Use use; + bool gotUse = LIR::AsRange(block).TryGetUse(execContextAddrPlaceholder, &use); + assert(gotUse); + + GenTree* newContinuation = m_compiler->gtNewLclvNode(GetNewContinuationVar(), TYP_REF); + unsigned execContextOffset = OFFSETOF__CORINFO_Continuation__data + layout.ExecutionContextOffset; + GenTree* execContextAddrOffset = + m_compiler->gtNewOperNode(GT_ADD, TYP_BYREF, newContinuation, + m_compiler->gtNewIconNode((ssize_t)execContextOffset, TYP_I_IMPL)); + + LIR::AsRange(block).InsertBefore(execContextAddrPlaceholder, LIR::SeqTree(m_compiler, execContextAddrOffset)); + use.ReplaceWith(execContextAddrOffset); + LIR::AsRange(block).Remove(execContextAddrPlaceholder); + + // Replace resumedPlaceholder with actual "continuationParameter != null" arg + gotUse = LIR::AsRange(block).TryGetUse(resumedPlaceholder, &use); + assert(gotUse); + + GenTree* continuation = m_compiler->gtNewLclvNode(m_compiler->lvaAsyncContinuationArg, TYP_REF); + GenTree* null = m_compiler->gtNewNull(); + GenTree* resumed = m_compiler->gtNewOperNode(GT_NE, TYP_INT, continuation, null); + + LIR::AsRange(block).InsertBefore(resumedPlaceholder, LIR::SeqTree(m_compiler, resumed)); + use.ReplaceWith(resumed); + LIR::AsRange(block).Remove(resumedPlaceholder); + + // Replace execContextPlaceholder with actual value + gotUse = LIR::AsRange(block).TryGetUse(execContextPlaceholder, &use); + assert(gotUse); + + LIR::AsRange(block).InsertBefore(execContextPlaceholder, execContext); + use.ReplaceWith(execContext); + LIR::AsRange(block).Remove(execContextPlaceholder); + + // Replace syncContextPlaceholder with actual value + gotUse = LIR::AsRange(block).TryGetUse(syncContextPlaceholder, &use); + assert(gotUse); + + LIR::AsRange(block).InsertBefore(syncContextPlaceholder, syncContext); + use.ReplaceWith(syncContext); + LIR::AsRange(block).Remove(syncContextPlaceholder); + + JITDUMP(" Created FinishSuspension call:\n"); + DISPTREERANGE(LIR::AsRange(block), finishCall); +} + //------------------------------------------------------------------------ // AsyncTransformation::CreateResumptionsAndSuspensions: // Walk all recorded async states and create the suspension and resumption @@ -3014,6 +3213,65 @@ void AsyncTransformation::CreateResumptionsAndSuspensions() ContinuationLayoutBuilder* sharedLayoutBuilder = ContinuationLayoutBuilder::CreateSharedLayout(m_compiler, m_states); sharedLayout = sharedLayoutBuilder->Create(); + + unsigned numSharedSuspensionsWithContinuationContext = 0; + unsigned numSharedSuspensionsWithoutContinuationContext = 0; + + bool execContextMayVary = false; + bool syncContextMayVary = false; + + for (const AsyncState& state : m_states) + { + SuspensionContextHelper helper = GetSuspensionContextHelper(state.Call); + switch (helper) + { + case SuspensionContextHelper::WithContinuationContext: + numSharedSuspensionsWithContinuationContext++; + break; + case SuspensionContextHelper::WithoutContinuationContext: + numSharedSuspensionsWithoutContinuationContext++; + break; + default: + break; + } + + // If all calls still have the async context vars we created early + // then avoid round tripping through a local which will create + // unnecessary additional register moves. This is a common case. + if (helper != SuspensionContextHelper::None) + { + CallArg* execContextArg = state.Call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext); + CallArg* syncContextArg = + state.Call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext); + assert((execContextArg != nullptr) && (syncContextArg != nullptr)); + + execContextMayVary |= + !execContextArg->GetNode()->OperIsScalarLocal() || + (execContextArg->GetNode()->AsLclVar()->GetLclNum() != m_compiler->lvaAsyncExecutionContextVar); + syncContextMayVary |= !syncContextArg->GetNode()->OperIsScalarLocal() || + (syncContextArg->GetNode()->AsLclVar()->GetLclNum() != + m_compiler->lvaAsyncSynchronizationContextVar); + } + } + + if (numSharedSuspensionsWithContinuationContext > 1) + { + JITDUMP("Using shared path for final context handling with continuation context -- needed by %u awaits\n", + numSharedSuspensionsWithContinuationContext); + m_sharedFinishContextHandlingWithContinuationContextBB = + CreateSharedFinishContextHandlingBB(SuspensionContextHelper::WithContinuationContext, *sharedLayout, + execContextMayVary, syncContextMayVary); + } + + if (numSharedSuspensionsWithoutContinuationContext > 1) + { + JITDUMP( + "Using shared path for final context handling without continuation context -- needed by %u awaits\n", + numSharedSuspensionsWithoutContinuationContext); + m_sharedFinishContextHandlingWithoutContinuationContextBB = + CreateSharedFinishContextHandlingBB(SuspensionContextHelper::WithoutContinuationContext, *sharedLayout, + execContextMayVary, syncContextMayVary); + } } JITDUMP("Creating suspensions and resumptions for %zu states\n", m_states.size()); diff --git a/src/coreclr/jit/async.h b/src/coreclr/jit/async.h index c1df2f21e3fc6a..ce25cd991362cf 100644 --- a/src/coreclr/jit/async.h +++ b/src/coreclr/jit/async.h @@ -331,6 +331,19 @@ enum class SaveSet MutatedLocals, }; +enum class SuspensionContextHelper +{ + None, + WithContinuationContext, + WithoutContinuationContext, +}; + +struct AggregatedAwaitInfo +{ + unsigned NumNormalAwaits = 0; + unsigned NumTailAwaits = 0; +}; + class AsyncTransformation { friend class AsyncAnalysis; @@ -349,10 +362,16 @@ class AsyncTransformation BasicBlock* m_lastResumptionBB = nullptr; BasicBlock* m_sharedReturnBB = nullptr; - void FindAwaits(ArrayStack& blocksWithNormalAwaits, - ArrayStack& blocksWithTailAwaits, - int* numNormalAwaits, - int* numTailAwaits); + // Shared basic blocks used by suspensions that handle required context + // saves/restores and then suspend. + BasicBlock* m_sharedFinishContextHandlingWithContinuationContextBB = nullptr; + BasicBlock* m_sharedFinishContextHandlingWithoutContinuationContextBB = nullptr; + // Variables that shared suspension finishing BBs take the exec/sync contexts in + unsigned m_sharedFinishContextHandlingExecContextVar = BAD_VAR_NUM; + unsigned m_sharedFinishContextHandlingSyncContextVar = BAD_VAR_NUM; + + AggregatedAwaitInfo FindAwaits(ArrayStack& blocksWithNormalAwaits, + ArrayStack& blocksWithTailAwaits); void TransformTailAwaits(ArrayStack& blocksWithTailAwaits); void TransformTailAwait(BasicBlock* block, GenTreeCall* call, BasicBlock** remainder); @@ -399,36 +418,38 @@ class AsyncTransformation GenTree* prevContinuation, const ContinuationLayout& layout); - void FillInDataOnSuspension(GenTreeCall* call, - const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout, - BasicBlock* suspendBB, - VARSET_VALARG_TP mutatedSinceResumption, - SaveSet saveSet); - SaveSet GetLocalSaveSet(const LclVarDsc* dsc, VARSET_VALARG_TP mutatedSinceResumption); - void FinishContextHandlingOnSuspension(BasicBlock* callBlock, - GenTreeCall* call, - BasicBlock* suspendBB, - const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout); - void FinishContextHandlingOnSuspensionWithHelper(BasicBlock* callBlock, - GenTreeCall* call, - BasicBlock* suspendBB, - const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout); - void RestoreContexts(BasicBlock* block, GenTreeCall* call, BasicBlock* insertionBB); - void CreateCheckAndSuspendAfterCall(BasicBlock* block, - GenTreeCall* call, - const CallDefinitionInfo& callDefInfo, - BasicBlock* suspendBB, - BasicBlock** remainder); - BasicBlock* CreateResumptionBlock(BasicBlock* remainder, unsigned stateNum); - void CreateResumption(BasicBlock* callBlock, - GenTreeCall* call, - BasicBlock* resumeBB, - const CallDefinitionInfo& callDefInfo, - const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout); + void FillInDataOnSuspension(GenTreeCall* call, + const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout, + BasicBlock* suspendBB, + VARSET_VALARG_TP mutatedSinceResumption, + SaveSet saveSet); + SaveSet GetLocalSaveSet(const LclVarDsc* dsc, VARSET_VALARG_TP mutatedSinceResumption); + SuspensionContextHelper GetSuspensionContextHelper(GenTreeCall* call); + void FinishContextHandlingAndSuspension(BasicBlock* callBlock, + GenTreeCall* call, + BasicBlock* suspendBB, + const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout); + void FinishContextHandlingAndSuspensionWithHelper(BasicBlock* callBlock, + GenTreeCall* call, + BasicBlock* suspendBB, + const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout, + SuspensionContextHelper helper); + void RestoreContexts(BasicBlock* block, GenTreeCall* call, BasicBlock* insertionBB); + void CreateCheckAndSuspendAfterCall(BasicBlock* block, + GenTreeCall* call, + const CallDefinitionInfo& callDefInfo, + BasicBlock* suspendBB, + BasicBlock** remainder); + BasicBlock* CreateResumptionBlock(BasicBlock* remainder, unsigned stateNum); + void CreateResumption(BasicBlock* callBlock, + GenTreeCall* call, + BasicBlock* resumeBB, + const CallDefinitionInfo& callDefInfo, + const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout); void RestoreFromDataOnResumption(const ContinuationLayout& layout, const ContinuationLayoutBuilder& subLayout, @@ -449,16 +470,25 @@ class AsyncTransformation var_types storeType, GenTreeFlags indirFlags = GTF_IND_NONFAULTING); - void CreateDebugInfoForSuspensionPoint(const ContinuationLayout& layout, - const ContinuationLayoutBuilder& subLayout); - unsigned GetReturnedContinuationVar(); - unsigned GetNewContinuationVar(); - unsigned GetResultBaseVar(); - unsigned GetExceptionVar(); - void CreateSharedReturnBB(); - bool ReuseContinuations(); - void CreateResumptionsAndSuspensions(); - void CreateResumptionSwitch(); + void CreateDebugInfoForSuspensionPoint(const ContinuationLayout& layout, + const ContinuationLayoutBuilder& subLayout); + unsigned GetReturnedContinuationVar(); + unsigned GetNewContinuationVar(); + unsigned GetResultBaseVar(); + unsigned GetExceptionVar(); + void CreateSharedReturnBB(); + BasicBlock* CreateSharedFinishContextHandlingBB(SuspensionContextHelper helper, + const ContinuationLayout& layout, + bool execContextMayVary, + bool syncContextMayVary); + void InsertFinishContextHandlingCall(BasicBlock* block, + const ContinuationLayout& layout, + SuspensionContextHelper helper, + GenTree* execContext, + GenTree* syncContext); + bool ReuseContinuations(); + void CreateResumptionsAndSuspensions(); + void CreateResumptionSwitch(); public: AsyncTransformation(Compiler* comp) diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 1e2d4e3e48c7e2..31af5e11feec05 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -4491,6 +4491,11 @@ struct AsyncCallInfo // Tail awaits do not generate suspension points and the JIT instead // directly returns the callee's continuation to the caller. bool IsTailAwait = false; + + bool NeedsToSaveAndRestoreExecutionContext() const + { + return true; + } }; // Return type descriptor of a GT_CALL node.