diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index 14ecead9a7f88e..68964fc30688d6 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -16,6 +16,7 @@ internal sealed class MockStream : QuicStreamProvider private readonly bool _isInitiator; private readonly StreamState _streamState; + private bool _writesCanceled; internal MockStream(StreamState streamState, bool isInitiator) { @@ -84,6 +85,10 @@ internal override async ValueTask ReadAsync(Memory buffer, Cancellati internal override void Write(ReadOnlySpan buffer) { CheckDisposed(); + if (Volatile.Read(ref _writesCanceled)) + { + throw new OperationCanceledException(); + } StreamBuffer? streamBuffer = WriteStreamBuffer; if (streamBuffer is null) @@ -102,6 +107,11 @@ internal override ValueTask WriteAsync(ReadOnlyMemory buffer, Cancellation internal override async ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) { CheckDisposed(); + if (Volatile.Read(ref _writesCanceled)) + { + cancellationToken.ThrowIfCancellationRequested(); + throw new OperationCanceledException(); + } StreamBuffer? streamBuffer = WriteStreamBuffer; if (streamBuffer is null) @@ -109,6 +119,12 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory buffer, bool e throw new NotSupportedException(); } + using var registration = cancellationToken.UnsafeRegister(static s => + { + var stream = (MockStream)s!; + Volatile.Write(ref stream._writesCanceled, true); + }, this); + await streamBuffer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); if (endStream) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 6102235f7e1fe4..bb1468cd525586 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -216,18 +216,14 @@ private async ValueTask HandleWriteStartState(Can throw new InvalidOperationException(SR.net_quic_writing_notallowed); } - lock (_state) + // Make sure start has completed + if (!_started) { - if (_state.SendState == SendState.Aborted) - { - throw new OperationCanceledException(SR.net_quic_sending_aborted); - } - else if (_state.SendState == SendState.ConnectionClosed) - { - throw GetConnectionAbortedException(_state); - } + await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false); + _started = true; } + // if token was already cancelled, this would execute syncronously CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => { var state = (State)s!; @@ -248,11 +244,17 @@ private async ValueTask HandleWriteStartState(Can } }, _state); - // Make sure start has completed - if (!_started) + lock (_state) { - await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false); - _started = true; + if (_state.SendState == SendState.Aborted) + { + cancellationToken.ThrowIfCancellationRequested(); + throw new OperationCanceledException(SR.net_quic_sending_aborted); + } + else if (_state.SendState == SendState.ConnectionClosed) + { + throw GetConnectionAbortedException(_state); + } } return registration; @@ -262,7 +264,7 @@ private void HandleWriteCompletedState() { lock (_state) { - if (_state.SendState == SendState.Finished || _state.SendState == SendState.Aborted) + if (_state.SendState == SendState.Finished) { _state.SendState = SendState.None; } @@ -827,6 +829,9 @@ private static uint HandleEventPeerSendShutdown(State state) private static uint HandleEventSendComplete(State state, ref StreamEvent evt) { + StreamEventDataSendComplete sendCompleteEvent = evt.Data.SendComplete; + bool canceled = sendCompleteEvent.Canceled != 0; + bool complete = false; lock (state) @@ -836,13 +841,26 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt) state.SendState = SendState.Finished; complete = true; } + + if (canceled) + { + state.SendState = SendState.Aborted; + } } if (complete) { CleanupSendState(state); - // TODO throw if a write was canceled. - state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + + if (!canceled) + { + state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); + } + else + { + state.SendResettableCompletionSource.CompleteException( + ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Write was canceled"))); + } } return MsQuicStatusCodes.Success; diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 72243c3bdb723d..4eee9b459d9fbb 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -434,6 +435,138 @@ await Task.Run(async () => Assert.Equal(ExpectedErrorCode, ex.ErrorCode); }).WaitAsync(TimeSpan.FromSeconds(15)); } + + [ActiveIssue("https://github.com/dotnet/runtime/issues/53530")] + [Fact] + public async Task StreamAbortedWithoutWriting_ReadThrows() + { + long expectedErrorCode = 1234; + + await RunClientServer( + clientFunction: async connection => + { + await using QuicStream stream = connection.OpenUnidirectionalStream(); + stream.AbortWrite(expectedErrorCode); + + await stream.ShutdownCompleted(); + }, + serverFunction: async connection => + { + await using QuicStream stream = await connection.AcceptStreamAsync(); + + byte[] buffer = new byte[1]; + + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(stream, buffer)); + Assert.Equal(expectedErrorCode, ex.ErrorCode); + + await stream.ShutdownCompleted(); + } + ); + } + + [Fact] + public async Task WritePreCanceled_Throws() + { + long expectedErrorCode = 1234; + + await RunClientServer( + clientFunction: async connection => + { + await using QuicStream stream = connection.OpenUnidirectionalStream(); + + CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], cts.Token).AsTask()); + + // next write would also throw + await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1]).AsTask()); + + // manual write abort is still required + stream.AbortWrite(expectedErrorCode); + + await stream.ShutdownCompleted(); + }, + serverFunction: async connection => + { + await using QuicStream stream = await connection.AcceptStreamAsync(); + + byte[] buffer = new byte[1024 * 1024]; + + // TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530 + //QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadAll(stream, buffer)); + try + { + await ReadAll(stream, buffer); + } + catch (QuicStreamAbortedException) { } + + await stream.ShutdownCompleted(); + } + ); + } + + [Fact] + public async Task WriteCanceled_NextWriteThrows() + { + long expectedErrorCode = 1234; + + await RunClientServer( + clientFunction: async connection => + { + await using QuicStream stream = connection.OpenUnidirectionalStream(); + + CancellationTokenSource cts = new CancellationTokenSource(500); + + async Task WriteUntilCanceled() + { + var buffer = new byte[64 * 1024]; + while (true) + { + await stream.WriteAsync(buffer, cancellationToken: cts.Token); + } + } + + // a write would eventually be canceled + await Assert.ThrowsAsync(() => WriteUntilCanceled().WaitAsync(TimeSpan.FromSeconds(3))); + + // next write would also throw + await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1]).AsTask()); + + // manual write abort is still required + stream.AbortWrite(expectedErrorCode); + + await stream.ShutdownCompleted(); + }, + serverFunction: async connection => + { + await using QuicStream stream = await connection.AcceptStreamAsync(); + + async Task ReadUntilAborted() + { + var buffer = new byte[1024]; + while (true) + { + int res = await stream.ReadAsync(buffer); + if (res == 0) + { + break; + } + } + } + + // TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530 + //QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => ReadUntilAborted()); + try + { + await ReadUntilAborted().WaitAsync(TimeSpan.FromSeconds(3)); + } + catch (QuicStreamAbortedException) { } + + await stream.ShutdownCompleted(); + } + ); + } } public sealed class QuicStreamTests_MockProvider : QuicStreamTests { }