From a512e8b215939d4ddd6a451104deffcb9cc530eb Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 27 Oct 2020 09:12:08 -0400 Subject: [PATCH] Reimplement Socket.Begin/EndSend/Receive on Send/ReceiveAsync --- .../Windows/WinSock/Interop.WSARecv.cs | 18 +- .../Windows/WinSock/Interop.WSASend.cs | 18 +- .../src/System/Threading/Tasks/TaskToApm.cs | 9 +- .../tests/HttpResponseStreamTests.cs | 2 + .../src/System/Net/Sockets/Socket.Tasks.cs | 20 +- .../src/System/Net/Sockets/Socket.cs | 573 +++--------------- .../Sockets/SocketAsyncEventArgs.Windows.cs | 4 +- .../src/System/Net/Sockets/SocketPal.Unix.cs | 46 -- .../System/Net/Sockets/SocketPal.Windows.cs | 75 --- 9 files changed, 116 insertions(+), 649 deletions(-) diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSARecv.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSARecv.cs index 3a07fab5ce1d55..e775ebcd775b00 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSARecv.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSARecv.cs @@ -11,7 +11,7 @@ internal static partial class Interop { internal static partial class Winsock { - [DllImport(Interop.Libraries.Ws2_32, SetLastError = true)] + [DllImport(Libraries.Ws2_32, SetLastError = true)] internal static extern unsafe SocketError WSARecv( SafeHandle socketHandle, WSABuffer* buffer, @@ -21,22 +21,6 @@ internal static extern unsafe SocketError WSARecv( NativeOverlapped* overlapped, IntPtr completionRoutine); - internal static unsafe SocketError WSARecv( - SafeHandle socketHandle, - ref WSABuffer buffer, - int bufferCount, - out int bytesTransferred, - ref SocketFlags socketFlags, - NativeOverlapped* overlapped, - IntPtr completionRoutine) - { - // We intentionally do NOT copy this back after the function completes: - // We don't want to cause a race in async scenarios. - // The WSABuffer struct should be unchanged anyway. - WSABuffer localBuffer = buffer; - return WSARecv(socketHandle, &localBuffer, bufferCount, out bytesTransferred, ref socketFlags, overlapped, completionRoutine); - } - internal static unsafe SocketError WSARecv( SafeHandle socketHandle, Span buffers, diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSASend.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSASend.cs index f5af2516948c31..30c2786b664e20 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSASend.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSASend.cs @@ -11,7 +11,7 @@ internal static partial class Interop { internal static partial class Winsock { - [DllImport(Interop.Libraries.Ws2_32, SetLastError = true)] + [DllImport(Libraries.Ws2_32, SetLastError = true)] internal static extern unsafe SocketError WSASend( SafeHandle socketHandle, WSABuffer* buffers, @@ -21,22 +21,6 @@ internal static extern unsafe SocketError WSASend( NativeOverlapped* overlapped, IntPtr completionRoutine); - internal static unsafe SocketError WSASend( - SafeHandle socketHandle, - ref WSABuffer buffer, - int bufferCount, - out int bytesTransferred, - SocketFlags socketFlags, - NativeOverlapped* overlapped, - IntPtr completionRoutine) - { - // We intentionally do NOT copy this back after the function completes: - // We don't want to cause a race in async scenarios. - // The WSABuffer struct should be unchanged anyway. - WSABuffer localBuffer = buffer; - return WSASend(socketHandle, &localBuffer, bufferCount, out bytesTransferred, socketFlags, overlapped, completionRoutine); - } - internal static unsafe SocketError WSASend( SafeHandle socketHandle, Span buffers, diff --git a/src/libraries/Common/src/System/Threading/Tasks/TaskToApm.cs b/src/libraries/Common/src/System/Threading/Tasks/TaskToApm.cs index 0e77f88861f5ae..7dc13d0179c282 100644 --- a/src/libraries/Common/src/System/Threading/Tasks/TaskToApm.cs +++ b/src/libraries/Common/src/System/Threading/Tasks/TaskToApm.cs @@ -37,9 +37,9 @@ public static IAsyncResult Begin(Task task, AsyncCallback? callback, object? sta /// The IAsyncResult to unwrap. public static void End(IAsyncResult asyncResult) { - if (asyncResult is TaskAsyncResult twar) + if (GetTask(asyncResult) is Task t) { - twar._task.GetAwaiter().GetResult(); + t.GetAwaiter().GetResult(); return; } @@ -50,7 +50,7 @@ public static void End(IAsyncResult asyncResult) /// The IAsyncResult to unwrap. public static TResult End(IAsyncResult asyncResult) { - if (asyncResult is TaskAsyncResult twar && twar._task is Task task) + if (GetTask(asyncResult) is Task task) { return task.GetAwaiter().GetResult(); } @@ -59,6 +59,9 @@ public static TResult End(IAsyncResult asyncResult) return default!; // unreachable } + /// Gets the task represented by the IAsyncResult. + public static Task? GetTask(IAsyncResult asyncResult) => (asyncResult as TaskAsyncResult)?._task; + /// Throws an argument exception for the invalid . [DoesNotReturn] private static void ThrowArgumentException(IAsyncResult asyncResult) => diff --git a/src/libraries/System.Net.HttpListener/tests/HttpResponseStreamTests.cs b/src/libraries/System.Net.HttpListener/tests/HttpResponseStreamTests.cs index 59bbc82af42fe9..d3478200198d6c 100644 --- a/src/libraries/System.Net.HttpListener/tests/HttpResponseStreamTests.cs +++ b/src/libraries/System.Net.HttpListener/tests/HttpResponseStreamTests.cs @@ -547,6 +547,7 @@ public async Task EndWrite_NullAsyncResult_ThrowsArgumentNullException(bool igno } } + [PlatformSpecific(TestPlatforms.Windows)] // Unix implementation uses Socket.Begin/EndSend, which doesn't fail in this case [Fact] public async Task EndWrite_InvalidAsyncResult_ThrowsArgumentException() { @@ -562,6 +563,7 @@ public async Task EndWrite_InvalidAsyncResult_ThrowsArgumentException() } } + [PlatformSpecific(TestPlatforms.Windows)] // Unix implementation uses Socket.Begin/EndSend, which doesn't fail in this case [Fact] public async Task EndWrite_CalledTwice_ThrowsInvalidOperationException() { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 6c6655f0ed94ce..a0267dd494dbb1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -359,6 +359,22 @@ internal Task SendToAsync(ArraySegment buffer, SocketFlags socketFlag return tcs.Task; } + private static void ValidateBufferArguments(byte[] buffer, int offset, int size) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if ((uint)offset > (uint)buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + if ((uint)size > (uint)(buffer.Length - offset)) + { + throw new ArgumentOutOfRangeException(nameof(size)); + } + } + /// Validates the supplied array segment, throwing if its array or indices are null or out-of-bounds, respectively. private static void ValidateBuffer(ArraySegment buffer) { @@ -366,11 +382,11 @@ private static void ValidateBuffer(ArraySegment buffer) { throw new ArgumentNullException(nameof(buffer.Array)); } - if ((uint)buffer.Offset > buffer.Array.Length) + if ((uint)buffer.Offset > (uint)buffer.Array.Length) { throw new ArgumentOutOfRangeException(nameof(buffer.Offset)); } - if ((uint)buffer.Count > buffer.Array.Length - buffer.Offset) + if ((uint)buffer.Count > (uint)(buffer.Array.Length - buffer.Offset)) { throw new ArgumentOutOfRangeException(nameof(buffer.Count)); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 1496abb598a8b2..c3a94d426a503d 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1198,19 +1198,7 @@ public int Send(byte[] buffer, int offset, int size, SocketFlags socketFlags, ou { ThrowIfDisposed(); - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } + ValidateBufferArguments(buffer, offset, size); errorCode = SocketError.Success; ValidateBlockingMode(); @@ -1306,23 +1294,11 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, { ThrowIfDisposed(); - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } + ValidateBufferArguments(buffer, offset, size); if (remoteEP == null) { throw new ArgumentNullException(nameof(remoteEP)); } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } ValidateBlockingMode(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SRC:{LocalEndPoint} size:{size} remoteEP:{remoteEP}"); @@ -1402,21 +1378,7 @@ public int Receive(byte[] buffer, int offset, int size, SocketFlags socketFlags) public int Receive(byte[] buffer, int offset, int size, SocketFlags socketFlags, out SocketError errorCode) { ThrowIfDisposed(); - - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } - + ValidateBufferArguments(buffer, offset, size); ValidateBlockingMode(); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SRC:{LocalEndPoint} DST:{RemoteEndPoint} size:{size}"); @@ -1538,10 +1500,7 @@ public int Receive(IList> buffers, SocketFlags socketFlags, o public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFlags socketFlags, ref EndPoint remoteEP, out IPPacketInformation ipPacketInformation) { ThrowIfDisposed(); - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } + ValidateBufferArguments(buffer, offset, size); if (remoteEP == null) { throw new ArgumentNullException(nameof(remoteEP)); @@ -1550,14 +1509,6 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla { throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); @@ -1618,12 +1569,7 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP) { ThrowIfDisposed(); - - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } + ValidateBufferArguments(buffer, offset, size); if (remoteEP == null) { throw new ArgumentNullException(nameof(remoteEP)); @@ -1633,14 +1579,6 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); @@ -2166,220 +2104,62 @@ public void EndDisconnect(IAsyncResult asyncResult) } } - // Routine Description: - // - // BeginSend - Async implementation of Send call, mirrored after BeginReceive - // This routine may go pending at which time, - // but any case the callback Delegate will be called upon completion - // - // Arguments: - // - // WriteBuffer - status line that we wish to parse - // Index - Offset into WriteBuffer to begin sending from - // Size - Size of Buffer to transmit - // Callback - Delegate function that holds callback, called on completion of I/O - // State - State used to track callback, set by caller, not required - // - // Return Value: - // - // IAsyncResult - Async result used to retrieve result public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state) - { - SocketError errorCode; - IAsyncResult? result = BeginSend(buffer, offset, size, socketFlags, out errorCode, callback, state); - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - throw new SocketException((int)errorCode); - } - return result!; - } - - public IAsyncResult? BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, out SocketError errorCode, AsyncCallback? callback, object? state) { ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, size); - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } - - // We need to flow the context here. But we don't need to lock the context - we don't use it until the callback. - OverlappedAsyncResult? asyncResult = new OverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Run the send with this asyncResult. - errorCode = DoBeginSend(buffer, offset, size, socketFlags, asyncResult); - - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - asyncResult = null; - } - else - { - // We're not throwing, so finish the async op posting code so we can return to the user. - // If the operation already finished, the callback will be called from here. - asyncResult.FinishPostingAsyncOp(ref Caches.SendClosureCache); - } - - return asyncResult; + return TaskToApm.Begin(SendAsync(new ReadOnlyMemory(buffer, offset, size), socketFlags, default).AsTask(), callback, state); } - private SocketError DoBeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) + public IAsyncResult? BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, out SocketError errorCode, AsyncCallback? callback, object? state) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SRC:{LocalEndPoint} DST:{RemoteEndPoint} size:{size} asyncResult:{asyncResult}"); - - SocketError errorCode = SocketPal.SendAsync(_handle, buffer, offset, size, socketFlags, asyncResult); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SendAsync returns:{errorCode} size:{size} AsyncResult:{asyncResult}"); + ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, size); - // If the call failed, update our status - if (!CheckErrorAndUpdateStatus(errorCode)) + Task t = SendAsync(new ReadOnlyMemory(buffer, offset, size), socketFlags, default).AsTask(); + if (t.IsFaulted || t.IsCanceled) { - UpdateSendSocketErrorForDisposed(ref errorCode); + errorCode = GetSocketErrorFromFaultedTask(t); + return null; } - return errorCode; + errorCode = SocketError.Success; + return TaskToApm.Begin(t, callback, state); } public IAsyncResult BeginSend(IList> buffers, SocketFlags socketFlags, AsyncCallback? callback, object? state) - { - SocketError errorCode; - IAsyncResult? result = BeginSend(buffers, socketFlags, out errorCode, callback, state); - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - throw new SocketException((int)errorCode); - } - return result!; - } - - public IAsyncResult? BeginSend(IList> buffers, SocketFlags socketFlags, out SocketError errorCode, AsyncCallback? callback, object? state) { ThrowIfDisposed(); - // Validate input parameters. - if (buffers == null) - { - throw new ArgumentNullException(nameof(buffers)); - } - - if (buffers.Count == 0) - { - throw new ArgumentException(SR.Format(SR.net_sockets_zerolist, nameof(buffers)), nameof(buffers)); - } - - // We need to flow the context here. But we don't need to lock the context - we don't use it until the callback. - OverlappedAsyncResult? asyncResult = new OverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Run the send with this asyncResult. - errorCode = DoBeginSend(buffers, socketFlags, asyncResult); - - // We're not throwing, so finish the async op posting code so we can return to the user. - // If the operation already finished, the callback will be called from here. - asyncResult.FinishPostingAsyncOp(ref Caches.SendClosureCache); - - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - asyncResult = null; - } - - return asyncResult; + return TaskToApm.Begin(SendAsync(buffers, socketFlags), callback, state); } - private SocketError DoBeginSend(IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) + public IAsyncResult? BeginSend(IList> buffers, SocketFlags socketFlags, out SocketError errorCode, AsyncCallback? callback, object? state) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SRC:{LocalEndPoint} DST:{RemoteEndPoint} buffers:{buffers} asyncResult:{asyncResult}"); - - SocketError errorCode = SocketPal.SendAsync(_handle, buffers, socketFlags, asyncResult); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SendAsync returns:{errorCode} returning AsyncResult:{asyncResult}"); + ThrowIfDisposed(); - // If the call failed, update our status - if (!CheckErrorAndUpdateStatus(errorCode)) + Task t = SendAsync(buffers, socketFlags); + if (t.IsFaulted || t.IsCanceled) { - UpdateSendSocketErrorForDisposed(ref errorCode); + errorCode = GetSocketErrorFromFaultedTask(t); + return null; } - return errorCode; + errorCode = SocketError.Success; + return TaskToApm.Begin(t, callback, state); } - // Routine Description: - // - // EndSend - Called by user code after I/O is done or the user wants to wait. - // until Async completion, needed to retrieve error result from call - // - // Arguments: - // - // AsyncResult - the AsyncResult Returned from BeginSend call - // - // Return Value: - // - // int - Number of bytes transferred public int EndSend(IAsyncResult asyncResult) - { - SocketError errorCode; - int bytesTransferred = EndSend(asyncResult, out errorCode); - if (errorCode != SocketError.Success) - { - throw new SocketException((int)errorCode); - } - return bytesTransferred; - } - - public int EndSend(IAsyncResult asyncResult, out SocketError errorCode) { ThrowIfDisposed(); - // Validate input parameters. - if (asyncResult == null) - { - throw new ArgumentNullException(nameof(asyncResult)); - } - - OverlappedAsyncResult? castedAsyncResult = asyncResult as OverlappedAsyncResult; - if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this) - { - throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); - } - if (castedAsyncResult.EndCalled) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndSend")); - } - - int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result(); - castedAsyncResult.EndCalled = true; - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"bytesTransffered:{bytesTransferred}"); - - // Throw an appropriate SocketException if the native call failed asynchronously. - errorCode = (SocketError)castedAsyncResult.ErrorCode; - - if (errorCode != SocketError.Success) - { - UpdateSendSocketErrorForDisposed(ref errorCode); - // Update the internal state of this socket according to the error before throwing. - UpdateStatusAfterSocketError(errorCode); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, new SocketException((int)errorCode)); - return 0; - } - else if (SocketsTelemetry.Log.IsEnabled()) - { - SocketsTelemetry.Log.BytesSent(bytesTransferred); - if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent(); - } - - return bytesTransferred; + return TaskToApm.End(asyncResult); } + public int EndSend(IAsyncResult asyncResult, out SocketError errorCode) => + EndSendReceive(asyncResult, out errorCode); + public IAsyncResult BeginSendFile(string? fileName, AsyncCallback? callback, object? state) { return BeginSendFile(fileName, null, null, TransmitFileOptions.UseDefaultWorkerThread, callback, state); @@ -2434,24 +2214,11 @@ public void EndSendFile(IAsyncResult asyncResult) public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint remoteEP, AsyncCallback? callback, object? state) { ThrowIfDisposed(); - - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } + ValidateBufferArguments(buffer, offset, size); if (remoteEP == null) { throw new ArgumentNullException(nameof(remoteEP)); } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } Internals.SocketAddress socketAddress = Serialize(ref remoteEP); @@ -2562,244 +2329,82 @@ public int EndSendTo(IAsyncResult asyncResult) return bytesTransferred; } - // Routine Description: - // - // BeginReceive - Async implementation of Recv call, - // - // Called when we want to start an async receive. - // We kick off the receive, and if it completes synchronously we'll - // call the callback. Otherwise we'll return an IASyncResult, which - // the caller can use to wait on or retrieve the final status, as needed. - // - // Uses Winsock 2 overlapped I/O. - // - // Arguments: - // - // ReadBuffer - status line that we wish to parse - // Index - Offset into ReadBuffer to begin reading from - // Size - Size of Buffer to recv - // Callback - Delegate function that holds callback, called on completion of I/O - // State - State used to track callback, set by caller, not required - // - // Return Value: - // - // IAsyncResult - Async result used to retrieve result public IAsyncResult BeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state) { - SocketError errorCode; - IAsyncResult? result = BeginReceive(buffer, offset, size, socketFlags, out errorCode, callback, state); - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - throw new SocketException((int)errorCode); - } - return result!; + ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, size); + return TaskToApm.Begin(ReceiveAsync(new ArraySegment(buffer, offset, size), socketFlags, fromNetworkStream: false, default).AsTask(), callback, state); } public IAsyncResult? BeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags, out SocketError errorCode, AsyncCallback? callback, object? state) { ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, size); + Task t = ReceiveAsync(new ArraySegment(buffer, offset, size), socketFlags, fromNetworkStream: false, default).AsTask(); - // Validate input parameters. - if (buffer == null) + if (t.IsFaulted || t.IsCanceled) { - throw new ArgumentNullException(nameof(buffer)); + errorCode = GetSocketErrorFromFaultedTask(t); + return null; } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } - - // We need to flow the context here. But we don't need to lock the context - we don't use it until the callback. - OverlappedAsyncResult? asyncResult = new OverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - // Run the receive with this asyncResult. - errorCode = DoBeginReceive(buffer, offset, size, socketFlags, asyncResult); - - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - asyncResult = null; - } - else - { - // We're not throwing, so finish the async op posting code so we can return to the user. - // If the operation already finished, the callback will be called from here. - asyncResult.FinishPostingAsyncOp(ref Caches.ReceiveClosureCache); - } - - return asyncResult; - } - - private SocketError DoBeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); - -#if DEBUG - IntPtr lastHandle = _handle.DangerousGetHandle(); -#endif - SocketError errorCode = SocketPal.ReceiveAsync(_handle, buffer, offset, size, socketFlags, asyncResult); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"ReceiveAsync returns:{errorCode} returning AsyncResult:{asyncResult}"); - - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred: 0); - if (CheckErrorAndUpdateStatus(errorCode)) - { -#if DEBUG - _lastReceiveHandle = lastHandle; - _lastReceiveThread = Environment.CurrentManagedThreadId; - _lastReceiveTick = Environment.TickCount; -#endif - } - - return errorCode; + errorCode = SocketError.Success; + return TaskToApm.Begin(t, callback, state); } public IAsyncResult BeginReceive(IList> buffers, SocketFlags socketFlags, AsyncCallback? callback, object? state) { - SocketError errorCode; - IAsyncResult? result = BeginReceive(buffers, socketFlags, out errorCode, callback, state); - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - throw new SocketException((int)errorCode); - } - return result!; + ThrowIfDisposed(); + return TaskToApm.Begin(ReceiveAsync(buffers, socketFlags), callback, state); } public IAsyncResult? BeginReceive(IList> buffers, SocketFlags socketFlags, out SocketError errorCode, AsyncCallback? callback, object? state) { ThrowIfDisposed(); + Task t = ReceiveAsync(buffers, socketFlags); - // Validate input parameters. - if (buffers == null) - { - throw new ArgumentNullException(nameof(buffers)); - } - - if (buffers.Count == 0) - { - throw new ArgumentException(SR.Format(SR.net_sockets_zerolist, nameof(buffers)), nameof(buffers)); - } - - // We need to flow the context here. But we don't need to lock the context - we don't use it until the callback. - OverlappedAsyncResult? asyncResult = new OverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Run the receive with this asyncResult. - errorCode = DoBeginReceive(buffers, socketFlags, asyncResult); - - if (errorCode != SocketError.Success && errorCode != SocketError.IOPending) - { - asyncResult = null; - } - else - { - // We're not throwing, so finish the async op posting code so we can return to the user. - // If the operation already finished, the callback will be called from here. - asyncResult.FinishPostingAsyncOp(ref Caches.ReceiveClosureCache); - } - - return asyncResult; - } - - private SocketError DoBeginReceive(IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { -#if DEBUG - IntPtr lastHandle = _handle.DangerousGetHandle(); -#endif - SocketError errorCode = SocketPal.ReceiveAsync(_handle, buffers, socketFlags, asyncResult); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"ReceiveAsync returns:{errorCode} returning AsyncResult:{asyncResult}"); - - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred: 0); - if (!CheckErrorAndUpdateStatus(errorCode)) + if (t.IsFaulted || t.IsCanceled) { + errorCode = GetSocketErrorFromFaultedTask(t); + return null; } -#if DEBUG - else - { - _lastReceiveHandle = lastHandle; - _lastReceiveThread = Environment.CurrentManagedThreadId; - _lastReceiveTick = Environment.TickCount; - } -#endif - return errorCode; + errorCode = SocketError.Success; + return TaskToApm.Begin(t, callback, state); } -#if DEBUG - private IntPtr _lastReceiveHandle; - private int _lastReceiveThread; - private int _lastReceiveTick; -#endif - - // Routine Description: - // - // EndReceive - Called when I/O is done or the user wants to wait. If - // the I/O isn't done, we'll wait for it to complete, and then we'll return - // the bytes of I/O done. - // - // Arguments: - // - // AsyncResult - the AsyncResult Returned from BeginSend call - // - // Return Value: - // - // int - Number of bytes transferred public int EndReceive(IAsyncResult asyncResult) { - SocketError errorCode; - int bytesTransferred = EndReceive(asyncResult, out errorCode); - if (errorCode != SocketError.Success) - { - throw new SocketException((int)errorCode); - } - return bytesTransferred; + ThrowIfDisposed(); + return TaskToApm.End(asyncResult); } - public int EndReceive(IAsyncResult asyncResult, out SocketError errorCode) + public int EndReceive(IAsyncResult asyncResult, out SocketError errorCode) => + EndSendReceive(asyncResult, out errorCode); + + private int EndSendReceive(IAsyncResult asyncResult, out SocketError errorCode) { ThrowIfDisposed(); - // Validate input parameters. - if (asyncResult == null) + if (TaskToApm.GetTask(asyncResult) is not Task ti) { - throw new ArgumentNullException(nameof(asyncResult)); + throw new ArgumentException(null, nameof(asyncResult)); } - OverlappedAsyncResult? castedAsyncResult = asyncResult as OverlappedAsyncResult; - if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this) + if (!ti.IsCompleted) { - throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); + // TODO https://github.com/dotnet/runtime/issues/17148: Wait without throwing + ((IAsyncResult)ti).AsyncWaitHandle.WaitOne(); } - if (castedAsyncResult.EndCalled) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndReceive")); - } - - int bytesTransferred = castedAsyncResult.InternalWaitForCompletionInt32Result(); - castedAsyncResult.EndCalled = true; - // Throw an appropriate SocketException if the native call failed asynchronously. - errorCode = (SocketError)castedAsyncResult.ErrorCode; - - UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); - if (errorCode != SocketError.Success) - { - // Update the internal state of this socket according to the error before throwing. - UpdateStatusAfterSocketError(errorCode); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, new SocketException((int)errorCode)); - return 0; - } - else if (SocketsTelemetry.Log.IsEnabled()) + if (ti.IsCompletedSuccessfully) { - SocketsTelemetry.Log.BytesReceived(bytesTransferred); - if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); + errorCode = SocketError.Success; + return ti.Result; } - return bytesTransferred; + + errorCode = GetSocketErrorFromFaultedTask(ti); + return 0; } public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP, AsyncCallback? callback, object? state) @@ -2807,10 +2412,7 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"size:{size}"); ThrowIfDisposed(); - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } + ValidateBufferArguments(buffer, offset, size); if (remoteEP == null) { throw new ArgumentNullException(nameof(remoteEP)); @@ -2819,14 +2421,6 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, { throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); @@ -2916,7 +2510,6 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, public int EndReceiveMessageFrom(IAsyncResult asyncResult, ref SocketFlags socketFlags, ref EndPoint endPoint, out IPPacketInformation ipPacketInformation) { - ThrowIfDisposed(); if (endPoint == null) { @@ -3008,12 +2601,7 @@ public int EndReceiveMessageFrom(IAsyncResult asyncResult, ref SocketFlags socke public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP, AsyncCallback? callback, object? state) { ThrowIfDisposed(); - - // Validate input parameters. - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } + ValidateBufferArguments(buffer, offset, size); if (remoteEP == null) { throw new ArgumentNullException(nameof(remoteEP)); @@ -3022,14 +2610,6 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket { throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); } - if (offset < 0 || offset > buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(offset)); - } - if (size < 0 || size > buffer.Length - offset) - { - throw new ArgumentOutOfRangeException(nameof(size)); - } if (_rightEndPoint == null) { throw new InvalidOperationException(SR.net_sockets_mustbind); @@ -4549,6 +4129,25 @@ internal static void SocketListDangerousReleaseRefs(IList? socketList, ref int r } } + private static SocketError GetSocketErrorFromFaultedTask(Task t) + { + Debug.Assert(t.IsCanceled || t.IsFaulted); + + if (t.IsCanceled) + { + return SocketError.OperationAborted; + } + + Debug.Assert(t.Exception != null); + return t.Exception.InnerException switch + { + SocketException se => se.SocketErrorCode, + ObjectDisposedException => SocketError.OperationAborted, + OperationCanceledException => SocketError.OperationAborted, + _ => SocketError.SocketError + }; + } + #endregion } } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index fa5b1b2b597c22..a4b92a40fa4a55 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -364,7 +364,7 @@ internal unsafe SocketError DoOperationReceiveSingleBuffer(SafeSocketHandle hand SocketFlags flags = _socketFlags; SocketError socketError = Interop.Winsock.WSARecv( handle, - ref wsaBuffer, + &wsaBuffer, 1, out int bytesTransferred, ref flags, @@ -598,7 +598,7 @@ internal unsafe SocketError DoOperationSendSingleBuffer(SafeSocketHandle handle, SocketError socketError = Interop.Winsock.WSASend( handle, - ref wsaBuffer, + &wsaBuffer, 1, out int bytesTransferred, _socketFlags, diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 7c5e38bcde2ad6..653164730d760c 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1770,28 +1770,6 @@ public static SocketError Shutdown(SafeSocketHandle handle, bool isConnected, bo return GetSocketErrorForErrorCode(err); } - public static SocketError SendAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - int bytesSent; - SocketError socketError = handle.AsyncContext.SendAsync(buffer, offset, count, socketFlags, out bytesSent, asyncResult.CompletionCallback, CancellationToken.None); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesSent, null, 0, SocketFlags.None, SocketError.Success); - } - return socketError; - } - - public static SocketError SendAsync(SafeSocketHandle handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - int bytesSent; - SocketError socketError = handle.AsyncContext.SendAsync(buffers, socketFlags, out bytesSent, asyncResult.CompletionCallback); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesSent, null, 0, SocketFlags.None, SocketError.Success); - } - return socketError; - } - public static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, Action callback) => SendFileAsync(handle, fileStream, 0, fileStream.Length, callback); @@ -1893,30 +1871,6 @@ public static SocketError SendToAsync(SafeSocketHandle handle, byte[] buffer, in return socketError; } - public static SocketError ReceiveAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - int bytesReceived; - SocketFlags receivedFlags; - SocketError socketError = handle.AsyncContext.ReceiveAsync(new Memory(buffer, offset, count), socketFlags, out bytesReceived, out receivedFlags, asyncResult.CompletionCallback, CancellationToken.None); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesReceived, null, 0, receivedFlags, SocketError.Success); - } - return socketError; - } - - public static SocketError ReceiveAsync(SafeSocketHandle handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - int bytesReceived; - SocketFlags receivedFlags; - SocketError socketError = handle.AsyncContext.ReceiveAsync(buffers, socketFlags, out bytesReceived, out receivedFlags, asyncResult.CompletionCallback); - if (socketError == SocketError.Success) - { - asyncResult.CompletionCallback(bytesReceived, null, 0, receivedFlags, SocketError.Success); - } - return socketError; - } - public static SocketError ReceiveFromAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) { asyncResult.SocketAddress = socketAddress; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 430574a42f4559..f111ac03d4927f 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -1026,56 +1026,6 @@ public static SocketError Shutdown(SafeSocketHandle handle, bool isConnected, bo return err; } - public static unsafe SocketError SendAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - // Set up unmanaged structures for overlapped WSASend. - asyncResult.SetUnmanagedStructures(buffer, offset, count, null); - try - { - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSASend( - handle, - ref asyncResult._singleBuffer, - 1, // There is only ever 1 buffer being sent. - out bytesTransferred, - socketFlags, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - - public static unsafe SocketError SendAsync(SafeSocketHandle handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - // Set up asyncResult for overlapped WSASend. - asyncResult.SetUnmanagedStructures(buffers); - try - { - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSASend( - handle, - asyncResult._wsaBuffers, - asyncResult._wsaBuffers!.Length, - out bytesTransferred, - socketFlags, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - // This assumes preBuffer/postBuffer are pinned already private static unsafe bool TransmitFileHelper( @@ -1175,31 +1125,6 @@ public static unsafe SocketError SendToAsync(SafeSocketHandle handle, byte[] buf } } - public static unsafe SocketError ReceiveAsync(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) - { - // Set up asyncResult for overlapped WSARecv. - asyncResult.SetUnmanagedStructures(buffer, offset, count, null); - try - { - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSARecv( - handle, - ref asyncResult._singleBuffer, - 1, - out bytesTransferred, - ref socketFlags, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - IntPtr.Zero); - - return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - public static unsafe SocketError ReceiveAsync(SafeSocketHandle handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { // Set up asyncResult for overlapped WSASend.