diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 4a909e43a46300..3faafa799c5852 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -388,6 +388,7 @@ public void Listen(int backlog) { } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; } + public int ReceiveMessageFrom(System.Span buffer, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; } public System.Threading.Tasks.Task ReceiveMessageFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveMessageFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default) { throw null; } public bool ReceiveMessageFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index fdd1055aaef557..d82934aadc5cda 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1603,6 +1603,100 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla return bytesTransferred; } + /// + /// Receives the specified number of bytes of data into the specified location of the data buffer, + /// using the specified , and stores the endpoint and packet information. + /// + /// + /// An of type that is the storage location for received data. + /// + /// + /// A bitwise combination of the values. + /// + /// + /// An , passed by reference, that represents the remote server. + /// + /// + /// An holding address and interface information. + /// + /// + /// The number of bytes received. + /// + /// The object has been closed. + /// The remoteEP is null. + /// The of the used in + /// + /// needs to match the of the used in SendTo. + /// + /// The object is not in blocking mode and cannot accept this synchronous call. + /// You must call the Bind method before performing this operation. + public int ReceiveMessageFrom(Span buffer, ref SocketFlags socketFlags, ref EndPoint remoteEP, out IPPacketInformation ipPacketInformation) + { + ThrowIfDisposed(); + + if (remoteEP == null) + { + throw new ArgumentNullException(nameof(remoteEP)); + } + if (!CanTryAddressFamily(remoteEP.AddressFamily)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP)); + } + if (_rightEndPoint == null) + { + throw new InvalidOperationException(SR.net_sockets_mustbind); + } + + SocketPal.CheckDualModeReceiveSupport(this); + ValidateBlockingMode(); + + // We don't do a CAS demand here because the contents of remoteEP aren't used by + // WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress + // with the right address family. + EndPoint endPointSnapshot = remoteEP; + Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot); + + // Save a copy of the original EndPoint. + Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); + + SetReceivingPacketInformation(); + + Internals.SocketAddress receiveAddress; + int bytesTransferred; + SocketError errorCode = SocketPal.ReceiveMessageFrom(this, _handle, buffer, ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred); + + UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); + // Throw an appropriate SocketException if the native call fails. + if (errorCode != SocketError.Success && errorCode != SocketError.MessageSize) + { + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesReceived(bytesTransferred); + if (errorCode == SocketError.Success && SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); + } + + if (!socketAddressOriginal.Equals(receiveAddress)) + { + try + { + remoteEP = endPointSnapshot.Create(receiveAddress); + } + catch + { + } + if (_rightEndPoint == null) + { + // Save a copy of the EndPoint so we can use it for Create(). + _rightEndPoint = endPointSnapshot; + } + } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode); + return bytesTransferred; + } + // Receives a datagram into a specific location in the data buffer and stores // the end point. public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 2f01b6b543ceeb..4d15435d3567ef 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -565,6 +565,31 @@ public override void InvokeCallback(bool allowPooling) => Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); } + private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation + { + public byte* BufferPtr; + public int Length; + public SocketFlags Flags; + public int BytesTransferred; + public SocketFlags ReceivedFlags; + + public bool IsIPv4; + public bool IsIPv6; + public IPPacketInformation IPPacketInformation; + + public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { } + + protected sealed override void Abort() { } + + public Action? Callback { get; set; } + + protected override bool DoTryComplete(SocketAsyncContext context) => + SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span(BufferPtr, Length), null, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode); + + public override void InvokeCallback(bool allowPooling) => + Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode); + } + private sealed class AcceptOperation : ReadOperation { public IntPtr AcceptedFileDescriptor; @@ -1696,7 +1721,7 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla } public SocketError ReceiveMessageFrom( - Memory buffer, IList>? buffers, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) + Memory buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) { Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); @@ -1704,7 +1729,7 @@ public SocketError ReceiveMessageFrom( SocketError errorCode; int observedSequenceNumber; if (_receiveQueue.IsReady(this, out observedSequenceNumber) && - (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || + (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || !ShouldRetrySyncOperation(out errorCode))) { flags = receivedFlags; @@ -1714,7 +1739,7 @@ public SocketError ReceiveMessageFrom( var operation = new ReceiveMessageFromOperation(this) { Buffer = buffer, - Buffers = buffers, + Buffers = null, Flags = flags, SocketAddress = socketAddress, SocketAddressLen = socketAddressLen, @@ -1731,6 +1756,45 @@ public SocketError ReceiveMessageFrom( return operation.ErrorCode; } + public unsafe SocketError ReceiveMessageFrom( + Span buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived) + { + Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}"); + + SocketFlags receivedFlags; + SocketError errorCode; + int observedSequenceNumber; + if (_receiveQueue.IsReady(this, out observedSequenceNumber) && + (SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) || + !ShouldRetrySyncOperation(out errorCode))) + { + flags = receivedFlags; + return errorCode; + } + + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) + { + var operation = new BufferPtrReceiveMessageFromOperation(this) + { + BufferPtr = bufferPtr, + Length = buffer.Length, + Flags = flags, + SocketAddress = socketAddress, + SocketAddressLen = socketAddressLen, + IsIPv4 = isIPv4, + IsIPv6 = isIPv6, + }; + + PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber); + + socketAddressLen = operation.SocketAddressLen; + flags = operation.ReceivedFlags; + ipPacketInformation = operation.IPPacketInformation; + bytesReceived = operation.BytesTransferred; + return operation.ErrorCode; + } + } + public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action callback, CancellationToken cancellationToken = default) { SetNonBlocking(); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 019bd5f165db77..8b1f87071194ab 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1172,7 +1172,7 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han SocketError errorCode; if (!handle.IsNonBlocking) { - errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), null, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory(buffer, offset, count), ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); } else { @@ -1187,6 +1187,33 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han return errorCode; } + + public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) + { + byte[] socketAddressBuffer = socketAddress.Buffer; + int socketAddressLen = socketAddress.Size; + + bool isIPv4, isIPv6; + Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6); + + SocketError errorCode; + if (!handle.IsNonBlocking) + { + errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred); + } + else + { + if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode)) + { + errorCode = SocketError.WouldBlock; + } + } + + socketAddress.InternalSize = socketAddressLen; + receiveAddress = socketAddress; + return errorCode; + } + public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred) { if (!handle.IsNonBlocking) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index ab701286e435c6..09858566bc28eb 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -451,6 +451,11 @@ public static unsafe IPPacketInformation GetIPPacketInformation(Interop.Winsock. } public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int size, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) + { + return ReceiveMessageFrom(socket, handle, new Span(buffer, offset, size), ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred); + } + + public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred) { bool ipv4, ipv6; Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out ipv4, out ipv6); @@ -458,7 +463,7 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan bytesTransferred = 0; receiveAddress = socketAddress; ipPacketInformation = default(IPPacketInformation); - fixed (byte* ptrBuffer = buffer) + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) fixed (byte* ptrSocketAddress = socketAddress.Buffer) { Interop.Winsock.WSAMsg wsaMsg; @@ -467,8 +472,8 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan wsaMsg.flags = socketFlags; WSABuffer wsaBuffer; - wsaBuffer.Length = size; - wsaBuffer.Pointer = (IntPtr)(ptrBuffer + offset); + wsaBuffer.Length = buffer.Length; + wsaBuffer.Pointer = (IntPtr)bufferPtr; wsaMsg.buffers = (IntPtr)(&wsaBuffer); wsaMsg.count = 1; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs index f948b1be0c6d52..49d95c395a65a7 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs @@ -54,6 +54,8 @@ public async Task ReceiveSent_TCP_Success(bool ipv6) [InlineData(true)] public async Task ReceiveSentMessages_UDP_Success(bool ipv4) { + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47637")] + int Offset = UsesSync || !PlatformDetection.IsWindows ? 10 : 0; const int DatagramSize = 256; const int DatagramsToSend = 16; @@ -69,7 +71,9 @@ public async Task ReceiveSentMessages_UDP_Success(bool ipv4) sender.BindToAnonymousPort(address); byte[] sendBuffer = new byte[DatagramSize]; - byte[] receiveBuffer = new byte[DatagramSize]; + var receiveInternalBuffer = new byte[DatagramSize + Offset]; + var emptyBuffer = new byte[Offset]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, Offset, DatagramSize); Random rnd = new Random(0); IPEndPoint remoteEp = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0); @@ -83,7 +87,8 @@ public async Task ReceiveSentMessages_UDP_Success(bool ipv4) IPPacketInformation packetInformation = result.PacketInformation; Assert.Equal(DatagramSize, result.ReceivedBytes); - AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer); + AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan(receiveInternalBuffer, 0, Offset)); + AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan(receiveInternalBuffer, Offset, DatagramSize)); Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint); Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, packetInformation.Address); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index a401f50d27f03d..615cca73aca0f1 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -182,7 +182,7 @@ public override Task ReceiveMessageFromAsync(Soc tcs.TrySetResult(result); } catch (Exception e) { tcs.TrySetException(e); } - + }, null); return tcs.Task; } @@ -440,6 +440,20 @@ public override Task ReceiveAsync(Socket s, ArraySegment buffer) => Task.Run(() => s.Receive((Span)buffer, SocketFlags.None)); public override Task SendAsync(Socket s, ArraySegment buffer) => Task.Run(() => s.Send((ReadOnlySpan)buffer, SocketFlags.None)); + public override Task ReceiveMessageFromAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => + Task.Run(() => + { + SocketFlags socketFlags = SocketFlags.None; + IPPacketInformation ipPacketInformation; + int received = s.ReceiveMessageFrom((Span)buffer, ref socketFlags, ref endPoint, out ipPacketInformation); + return new SocketReceiveMessageFromResult + { + ReceivedBytes = received, + SocketFlags = socketFlags, + RemoteEndPoint = endPoint, + PacketInformation = ipPacketInformation + }; + }); public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => Task.Run(() => s.SendFile(fileName, preBuffer, postBuffer, flags)); public override bool UsesSync => true; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index b0c0abb4a03377..ffd474d4bcdb52 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -1,4 +1,4 @@ - + true true