diff --git a/src/fsharp/FSharp.Core/async.fs b/src/fsharp/FSharp.Core/async.fs index 1f20f7c8110..6a225de0872 100644 --- a/src/fsharp/FSharp.Core/async.fs +++ b/src/fsharp/FSharp.Core/async.fs @@ -984,7 +984,15 @@ namespace Microsoft.FSharp.Control else ctxt.cont completedTask.Result) |> unfake - task.ContinueWith(Action>(continuation)) |> ignore |> fake + let cancelContinuation (_: Task) : unit = + ctxt.trampolineHolder.ExecuteWithTrampoline (fun () -> + ctxt.OnCancellation () + ) |> unfake + + task + .ContinueWith(Action>(continuation), ctxt.token) + .ContinueWith(Action(cancelContinuation), TaskContinuationOptions.OnlyOnCanceled) + |> ignore |> fake [] let taskContinueWithUnit (task: Task) (ctxt: AsyncActivation) useCcontForTaskCancellation = @@ -1003,7 +1011,15 @@ namespace Microsoft.FSharp.Control else ctxt.cont ()) |> unfake - task.ContinueWith(Action(continuation)) |> ignore |> fake + let cancelContinuation (_: Task) : unit = + ctxt.trampolineHolder.ExecuteWithTrampoline (fun () -> + ctxt.OnCancellation () + ) |> unfake + + task + .ContinueWith(Action(continuation), ctxt.token) + .ContinueWith(Action(cancelContinuation), TaskContinuationOptions.OnlyOnCanceled) + |> ignore |> fake [] type AsyncIAsyncResult<'T>(callback: System.AsyncCallback, state:obj) = diff --git a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs index c0797038f56..08a6173f874 100644 --- a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs +++ b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs @@ -143,7 +143,7 @@ type AsyncType() = Assert.AreEqual(s, t.Result) [] - member this.StartAsTaskCancellation () = + member this.StartAsTaskCancelAsync () = let cts = new CancellationTokenSource() let tcs = TaskCompletionSource() let a = async { @@ -155,12 +155,32 @@ type AsyncType() = use t : Task = #endif Async.StartAsTask(a, cancellationToken = cts.Token) + + try + this.WaitASec t + with :? AggregateException as a -> + match a.InnerException with + | :? TaskCanceledException as t -> () + | _ -> reraise() + Assert.IsTrue (t.IsCompleted, "Task is not completed") + + [] + member this.StartAsTaskCancelTask () = + let tcs = TaskCompletionSource() + let a = async { + do! tcs.Task |> Async.AwaitTask } +#if !NET46 + let t : Task = +#else + use t : Task = +#endif + Async.StartAsTask(a) // Should not finish try let result = t.Wait(300) Assert.IsFalse (result) - with :? AggregateException -> Assert.Fail "Task should not finish, jet" + with :? AggregateException -> Assert.Fail "Task should not finish, yet" tcs.SetCanceled() @@ -382,7 +402,7 @@ type AsyncType() = Async.RunSynchronously(a, 1000) |> Assert.IsTrue [] - member this.AwaitTaskCancellation () = + member this.AwaitTaskTaskCancellation () = let test() = async { let tcs = new System.Threading.Tasks.TaskCompletionSource() tcs.SetCanceled() @@ -392,8 +412,22 @@ type AsyncType() = with :? System.OperationCanceledException -> return true } - Async.RunSynchronously(test()) |> Assert.IsTrue - + Async.RunSynchronously(test()) |> Assert.IsTrue + + [] + member this.AwaitTaskAsyncCancellation () = + let tcs = new System.Threading.Tasks.TaskCompletionSource() + let test = Async.AwaitTask tcs.Task + + use cts = new CancellationTokenSource() + cts.CancelAfter(250) + try + Async.RunSynchronously(test, cancellationToken=cts.Token) |> ignore + Assert.Fail("Expected async to throw") + with + | :? TaskCanceledException -> Assert.Fail("Did not expect TaskCanceledException") + | :? System.OperationCanceledException -> () + [] member this.AwaitTaskCancellationUntyped () = let test() = async {