diff --git a/src/System.Private.CoreLib/shared/System/Threading/Tasks/Task.cs b/src/System.Private.CoreLib/shared/System/Threading/Tasks/Task.cs index 3f580c4bb09c..7e5277423559 100644 --- a/src/System.Private.CoreLib/shared/System/Threading/Tasks/Task.cs +++ b/src/System.Private.CoreLib/shared/System/Threading/Tasks/Task.cs @@ -5370,12 +5370,18 @@ public static Task Delay(int millisecondsDelay, CancellationToken cancellationTo } // Construct a promise-style Task to encapsulate our return value - var promise = new DelayPromise(cancellationToken); + DelayPromise promise; - // Register our cancellation token, if necessary. if (cancellationToken.CanBeCanceled) { - promise.Registration = cancellationToken.UnsafeRegister(state => ((DelayPromise)state).Complete(), promise); + var promiseWithCancellation = new DelayPromiseWithCancellation(cancellationToken); + // Register our cancellation token, if necessary. + promiseWithCancellation.Registration = cancellationToken.UnsafeRegister(state => ((DelayPromise)state).Complete(), promiseWithCancellation); + promise = promiseWithCancellation; + } + else + { + promise = new DelayPromise(); } // ... and create our timer and make sure that it stays rooted. @@ -5389,12 +5395,10 @@ public static Task Delay(int millisecondsDelay, CancellationToken cancellationTo } /// Task that also stores the completion closure and logic for Task.Delay implementation. - private sealed class DelayPromise : Task + private class DelayPromise : Task { - internal DelayPromise(CancellationToken token) - : base() + internal DelayPromise() : base() { - this.Token = token; if (AsyncCausalityTracer.LoggingOn) AsyncCausalityTracer.TraceOperationCreation(this, "Task.Delay"); @@ -5402,36 +5406,63 @@ internal DelayPromise(CancellationToken token) AddToActiveTasks(this); } - internal readonly CancellationToken Token; - internal CancellationTokenRegistration Registration; internal TimerQueueTimer Timer; - internal void Complete() + internal virtual bool Complete() { // Transition the task to completed. bool setSucceeded; + if (AsyncCausalityTracer.LoggingOn) + AsyncCausalityTracer.TraceOperationCompletion(this, AsyncCausalityStatus.Completed); + + if (s_asyncDebuggingEnabled) + RemoveFromActiveTasks(this); + + setSucceeded = TrySetResult(default); + + // If we set the value, also clean up. + if (setSucceeded) + { + Timer?.Close(); + } + + return setSucceeded; + } + } + + private sealed class DelayPromiseWithCancellation : DelayPromise + { + internal DelayPromiseWithCancellation(CancellationToken token) : base() + => Token = token; + + internal readonly CancellationToken Token; + internal CancellationTokenRegistration Registration; + + internal override bool Complete() + { + // Transition the task to completed. + bool setSucceeded; if (Token.IsCancellationRequested) { setSucceeded = TrySetCanceled(Token); + if (setSucceeded) + { + Timer?.Close(); + } } else { - if (AsyncCausalityTracer.LoggingOn) - AsyncCausalityTracer.TraceOperationCompletion(this, AsyncCausalityStatus.Completed); - - if (s_asyncDebuggingEnabled) - RemoveFromActiveTasks(this); - - setSucceeded = TrySetResult(default); + setSucceeded = base.Complete(); } // If we set the value, also clean up. if (setSucceeded) { - Timer?.Close(); Registration.Dispose(); } + + return setSucceeded; } } #endregion