diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Fcntl.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Fcntl.cs index cec6451b509135..0877ca707c5759 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Fcntl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Fcntl.cs @@ -17,6 +17,9 @@ internal static partial class Fcntl [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FcntlSetIsNonBlocking", SetLastError=true)] internal static extern int SetIsNonBlocking(SafeHandle fd, int isNonBlocking); + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FcntlGetIsNonBlocking", SetLastError = true)] + internal static extern int GetIsNonBlocking(SafeHandle fd, out bool isNonBlocking); + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FcntlSetFD", SetLastError=true)] internal static extern int SetFD(SafeHandle fd, int flags); diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.GetSocketType.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.GetSocketType.cs new file mode 100644 index 00000000000000..f857bf32f35375 --- /dev/null +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.GetSocketType.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Sockets; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static partial class Sys + { + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetSocketType")] + internal static extern Error GetSocketType(SafeSocketHandle socket, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType); + } +} diff --git a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Stat.cs b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Stat.cs index d06fbda71852ab..628bd9c70ce40b 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Stat.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Native/Interop.Stat.cs @@ -55,7 +55,7 @@ internal enum FileStatusFlags } [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_FStat", SetLastError = true)] - internal static extern int FStat(SafeFileHandle fd, out FileStatus output); + internal static extern int FStat(SafeHandle fd, out FileStatus output); [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_Stat", SetLastError = true)] internal static extern int Stat(string path, out FileStatus output); diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSADuplicateSocket.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSADuplicateSocket.cs index 23fa0328c440d3..d079c1c2509841 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSADuplicateSocket.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSADuplicateSocket.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -10,42 +9,6 @@ internal static partial class Interop { internal static partial class Winsock { - [StructLayout(LayoutKind.Sequential)] - internal unsafe struct WSAPROTOCOLCHAIN - { - private const int MAX_PROTOCOL_CHAIN = 7; - - internal int ChainLen; - internal fixed uint ChainEntries[MAX_PROTOCOL_CHAIN]; - } - - [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] - internal unsafe struct WSAPROTOCOL_INFOW - { - private const int WSAPROTOCOL_LEN = 255; - - internal uint dwServiceFlags1; - internal uint dwServiceFlags2; - internal uint dwServiceFlags3; - internal uint dwServiceFlags4; - internal uint dwProviderFlags; - internal Guid ProviderId; - internal uint dwCatalogEntryId; - internal WSAPROTOCOLCHAIN ProtocolChain; - internal int iVersion; - internal AddressFamily iAddressFamily; - internal int iMaxSockAddr; - internal int iMinSockAddr; - internal SocketType iSocketType; - internal ProtocolType iProtocol; - internal int iProtocolMaxOffset; - internal int iNetworkByteOrder; - internal int iSecurityScheme; - internal uint dwMessageSize; - internal uint dwProviderReserved; - internal fixed char szProtocol[WSAPROTOCOL_LEN + 1]; - } - [DllImport(Interop.Libraries.Ws2_32, CharSet = CharSet.Unicode, SetLastError = true)] internal static extern unsafe int WSADuplicateSocket( [In] SafeSocketHandle s, diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAPROTOCOL_INFOW.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAPROTOCOL_INFOW.cs new file mode 100644 index 00000000000000..502b28466fbdef --- /dev/null +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.WSAPROTOCOL_INFOW.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; +using System.Net.Sockets; + +internal static partial class Interop +{ + internal static partial class Winsock + { + public const int SO_PROTOCOL_INFOW = 0x2005; + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + internal unsafe struct WSAPROTOCOL_INFOW + { + private const int WSAPROTOCOL_LEN = 255; + + internal uint dwServiceFlags1; + internal uint dwServiceFlags2; + internal uint dwServiceFlags3; + internal uint dwServiceFlags4; + internal uint dwProviderFlags; + internal Guid ProviderId; + internal uint dwCatalogEntryId; + internal WSAPROTOCOLCHAIN ProtocolChain; + internal int iVersion; + internal AddressFamily iAddressFamily; + internal int iMaxSockAddr; + internal int iMinSockAddr; + internal SocketType iSocketType; + internal ProtocolType iProtocol; + internal int iProtocolMaxOffset; + internal int iNetworkByteOrder; + internal int iSecurityScheme; + internal uint dwMessageSize; + internal uint dwProviderReserved; + internal fixed char szProtocol[WSAPROTOCOL_LEN + 1]; + } + + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct WSAPROTOCOLCHAIN + { + private const int MAX_PROTOCOL_CHAIN = 7; + + internal int ChainLen; + internal fixed uint ChainEntries[MAX_PROTOCOL_CHAIN]; + } + } +} diff --git a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.getsockname.cs b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.getsockname.cs index 4d44869467cafe..5471f62dcb7f80 100644 --- a/src/libraries/Common/src/Interop/Windows/WinSock/Interop.getsockname.cs +++ b/src/libraries/Common/src/Interop/Windows/WinSock/Interop.getsockname.cs @@ -10,9 +10,9 @@ internal static partial class Interop internal static partial class Winsock { [DllImport(Interop.Libraries.Ws2_32, SetLastError = true)] - internal static extern SocketError getsockname( - [In] SafeSocketHandle socketHandle, - [Out] byte[] socketAddress, - [In, Out] ref int socketAddressSize); + internal static extern unsafe SocketError getsockname( + SafeSocketHandle socketHandle, + byte* socketAddress, + int* socketAddressSize); } } diff --git a/src/libraries/Common/src/System/Net/ByteOrder.cs b/src/libraries/Common/src/System/Net/ByteOrder.cs index 1783debe7a4f82..5bf33658ffb11e 100644 --- a/src/libraries/Common/src/System/Net/ByteOrder.cs +++ b/src/libraries/Common/src/System/Net/ByteOrder.cs @@ -12,7 +12,7 @@ public static void HostToNetworkBytes(this ushort host, byte[] bytes, int index) bytes[index + 1] = unchecked((byte)host); } - public static ushort NetworkBytesToHostUInt16(this byte[] bytes, int index) + public static ushort NetworkBytesToHostUInt16(this ReadOnlySpan bytes, int index) { return (ushort)(((ushort)bytes[index] << 8) | (ushort)bytes[index + 1]); } diff --git a/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs b/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs index 4c095f7b3559a4..89becdde08ba5a 100644 --- a/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs +++ b/src/libraries/Common/src/System/Net/Internals/IPEndPointExtensions.cs @@ -55,7 +55,7 @@ private static Internals.SocketAddress GetInternalSocketAddress(System.Net.Socke return result; } - private static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address) + internal static System.Net.SocketAddress GetNetSocketAddress(Internals.SocketAddress address) { var result = new System.Net.SocketAddress(address.Family, address.Size); for (int index = 0; index < address.Size; index++) diff --git a/src/libraries/Common/src/System/Net/SocketAddress.cs b/src/libraries/Common/src/System/Net/SocketAddress.cs index a3a93bdb3bedba..b02d4a799d36c9 100644 --- a/src/libraries/Common/src/System/Net/SocketAddress.cs +++ b/src/libraries/Common/src/System/Net/SocketAddress.cs @@ -130,6 +130,12 @@ internal SocketAddress(IPAddress ipaddress, int port) SocketAddressPal.SetPort(Buffer, unchecked((ushort)port)); } + internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan buffer) + { + Buffer = buffer.ToArray(); + InternalSize = Buffer.Length; + } + internal IPAddress GetIPAddress() { if (Family == AddressFamily.InterNetworkV6) diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs index 027f050cec8557..c28b62ce70ba8a 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -52,7 +52,7 @@ private static void ThrowOnFailure(Interop.Error err) } } - public static unsafe AddressFamily GetAddressFamily(byte[] buffer) + public static unsafe AddressFamily GetAddressFamily(ReadOnlySpan buffer) { AddressFamily family; Interop.Error err; @@ -76,7 +76,7 @@ public static unsafe void SetAddressFamily(byte[] buffer, AddressFamily family) ThrowOnFailure(err); } - public static unsafe ushort GetPort(byte[] buffer) + public static unsafe ushort GetPort(ReadOnlySpan buffer) { ushort port; Interop.Error err; diff --git a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs index b28bad4eeebb1f..82f0a6d6bbde75 100644 --- a/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/libraries/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -11,9 +11,9 @@ internal static class SocketAddressPal public const int IPv6AddressSize = 28; public const int IPv4AddressSize = 16; - public static unsafe AddressFamily GetAddressFamily(byte[] buffer) + public static unsafe AddressFamily GetAddressFamily(ReadOnlySpan buffer) { - return (AddressFamily)BitConverter.ToInt16(buffer, 0); + return (AddressFamily)BitConverter.ToInt16(buffer); } public static unsafe void SetAddressFamily(byte[] buffer, AddressFamily family) @@ -35,7 +35,7 @@ public static unsafe void SetAddressFamily(byte[] buffer, AddressFamily family) #endif } - public static unsafe ushort GetPort(byte[] buffer) + public static unsafe ushort GetPort(ReadOnlySpan buffer) { return buffer.NetworkBytesToHostUInt16(2); } diff --git a/src/libraries/Native/Unix/System.Native/pal_io.c b/src/libraries/Native/Unix/System.Native/pal_io.c index d3690fa2f5b858..609084644c5c4c 100644 --- a/src/libraries/Native/Unix/System.Native/pal_io.c +++ b/src/libraries/Native/Unix/System.Native/pal_io.c @@ -594,6 +594,24 @@ int32_t SystemNative_FcntlSetIsNonBlocking(intptr_t fd, int32_t isNonBlocking) return fcntl(fileDescriptor, F_SETFL, flags); } +int32_t SystemNative_FcntlGetIsNonBlocking(intptr_t fd, int32_t* isNonBlocking) +{ + if (isNonBlocking == NULL) + { + return Error_EFAULT; + } + + int flags = fcntl(ToFileDescriptor(fd), F_GETFL); + if (flags == -1) + { + *isNonBlocking = 0; + return -1; + } + + *isNonBlocking = ((flags & O_NONBLOCK) == O_NONBLOCK) ? 1 : 0; + return 0; +} + int32_t SystemNative_MkDir(const char* path, int32_t mode) { int32_t result; diff --git a/src/libraries/Native/Unix/System.Native/pal_io.h b/src/libraries/Native/Unix/System.Native/pal_io.h index 136f534e163e4e..9f68aa3d62f9ea 100644 --- a/src/libraries/Native/Unix/System.Native/pal_io.h +++ b/src/libraries/Native/Unix/System.Native/pal_io.h @@ -471,6 +471,13 @@ PALEXPORT int32_t SystemNative_FcntlSetPipeSz(intptr_t fd, int32_t size); */ PALEXPORT int32_t SystemNative_FcntlSetIsNonBlocking(intptr_t fd, int32_t isNonBlocking); +/** + * Gets whether or not a file descriptor is non-blocking. + * + * Returns 0 for success, -1 for failure. Sets errno for failure. + */ +PALEXPORT int32_t SystemNative_FcntlGetIsNonBlocking(intptr_t fd, int32_t* isNonBlocking); + /** * Create a directory. Implemented as a shim to mkdir(2). * diff --git a/src/libraries/Native/Unix/System.Native/pal_networking.c b/src/libraries/Native/Unix/System.Native/pal_networking.c index 7faa21b3923ac8..08540e24b18e54 100644 --- a/src/libraries/Native/Unix/System.Native/pal_networking.c +++ b/src/libraries/Native/Unix/System.Native/pal_networking.c @@ -1524,7 +1524,6 @@ int32_t SystemNative_GetPeerName(intptr_t socket, uint8_t* socketAddress, int32_ return SystemNative_ConvertErrorPlatformToPal(errno); } - assert(addrLen <= (socklen_t)*socketAddressLen); *socketAddressLen = (int32_t)addrLen; return Error_SUCCESS; } @@ -2254,6 +2253,128 @@ static bool TryConvertProtocolTypePalToPlatform(int32_t palAddressFamily, int32_ } } +static bool TryConvertProtocolTypePlatformToPal(int32_t palAddressFamily, int platformProtocolType, int32_t* palProtocolType) +{ + assert(palProtocolType != NULL); + + switch (palAddressFamily) + { +#ifdef AF_PACKET + case AddressFamily_AF_PACKET: + // protocol is the IEEE 802.3 protocol number in network order. + *palProtocolType = platformProtocolType; + return true; +#endif +#if HAVE_LINUX_CAN_H + case AddressFamily_AF_CAN: + switch (platformProtocolType) + { + case 0: + *palProtocolType = ProtocolType_PT_UNSPECIFIED; + return true; + + case CAN_RAW: + *palProtocolType = ProtocolType_PT_RAW; + return true; + + default: + *palProtocolType = (int)platformProtocolType; + return false; + } +#endif + case AddressFamily_AF_INET: + switch (platformProtocolType) + { + case 0: + *palProtocolType = ProtocolType_PT_UNSPECIFIED; + return true; + + case IPPROTO_ICMP: + *palProtocolType = ProtocolType_PT_ICMP; + return true; + + case IPPROTO_TCP: + *palProtocolType = ProtocolType_PT_TCP; + return true; + + case IPPROTO_UDP: + *palProtocolType = ProtocolType_PT_UDP; + return true; + + case IPPROTO_IGMP: + *palProtocolType = ProtocolType_PT_IGMP; + return true; + + case IPPROTO_RAW: + *palProtocolType = ProtocolType_PT_RAW; + return true; + + default: + *palProtocolType = (int)palProtocolType; + return false; + } + + case AddressFamily_AF_INET6: + switch (platformProtocolType) + { + case 0: + *palProtocolType = ProtocolType_PT_UNSPECIFIED; + return true; + + case IPPROTO_ICMPV6: + *palProtocolType = ProtocolType_PT_ICMPV6; + return true; + + case IPPROTO_TCP: + *palProtocolType = ProtocolType_PT_TCP; + return true; + + case IPPROTO_UDP: + *palProtocolType = ProtocolType_PT_UDP; + return true; + + case IPPROTO_IGMP: + *palProtocolType = ProtocolType_PT_IGMP; + return true; + + case IPPROTO_RAW: + *palProtocolType = ProtocolType_PT_RAW; + return true; + + case IPPROTO_DSTOPTS: + *palProtocolType = ProtocolType_PT_DSTOPTS; + return true; + + case IPPROTO_NONE: + *palProtocolType = ProtocolType_PT_NONE; + return true; + + case IPPROTO_ROUTING: + *palProtocolType = ProtocolType_PT_ROUTING; + return true; + + case IPPROTO_FRAGMENT: + *palProtocolType = ProtocolType_PT_FRAGMENT; + return true; + + default: + *palProtocolType = (int)platformProtocolType; + return false; + } + + default: + switch (platformProtocolType) + { + case 0: + *palProtocolType = ProtocolType_PT_UNSPECIFIED; + return true; + default: + *palProtocolType = (int)platformProtocolType; + return false; + } + } +} + int32_t SystemNative_Socket(int32_t addressFamily, int32_t socketType, int32_t protocolType, intptr_t* createdSocket) { if (createdSocket == NULL) @@ -2297,6 +2418,48 @@ int32_t SystemNative_Socket(int32_t addressFamily, int32_t socketType, int32_t p return Error_SUCCESS; } +int32_t SystemNative_GetSocketType(intptr_t socket, int32_t* addressFamily, int32_t* socketType, int32_t* protocolType) +{ + if (addressFamily == NULL || socketType == NULL || protocolType == NULL) + { + return Error_EFAULT; + } + + int fd = ToFileDescriptor(socket); + +#ifdef SO_DOMAIN + int domainValue; + socklen_t domainLength = sizeof(int); + if (getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &domainValue, &domainLength) != 0 || + !TryConvertAddressFamilyPlatformToPal((sa_family_t)domainValue, addressFamily)) +#endif + { + *addressFamily = AddressFamily_AF_UNKNOWN; + } + +#ifdef SO_TYPE + int typeValue; + socklen_t typeLength = sizeof(int); + if (getsockopt(fd, SOL_SOCKET, SO_TYPE, &typeValue, &typeLength) != 0 || + !TryConvertSocketTypePlatformToPal(typeValue, socketType)) +#endif + { + *socketType = SocketType_UNKNOWN; + } + +#ifdef SO_PROTOCOL + int protocolValue; + socklen_t protocolLength = sizeof(int); + if (getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, &protocolValue, &protocolLength) != 0 || + !TryConvertProtocolTypePlatformToPal(*addressFamily, protocolValue, protocolType)) +#endif + { + *protocolType = ProtocolType_PT_UNKNOWN; + } + + return Error_SUCCESS; +} + int32_t SystemNative_GetAtOutOfBandMark(intptr_t socket, int32_t* atMark) { if (atMark == NULL) diff --git a/src/libraries/Native/Unix/System.Native/pal_networking.h b/src/libraries/Native/Unix/System.Native/pal_networking.h index dbc4ebc211066d..3dbc08932b9208 100644 --- a/src/libraries/Native/Unix/System.Native/pal_networking.h +++ b/src/libraries/Native/Unix/System.Native/pal_networking.h @@ -59,6 +59,7 @@ typedef enum */ typedef enum { + AddressFamily_AF_UNKNOWN = -1, // System.Net.AddressFamily.Unknown AddressFamily_AF_UNSPEC = 0, // System.Net.AddressFamily.Unspecified AddressFamily_AF_UNIX = 1, // System.Net.AddressFamily.Unix AddressFamily_AF_INET = 2, // System.Net.AddressFamily.InterNetwork @@ -74,6 +75,7 @@ typedef enum */ typedef enum { + SocketType_UNKNOWN = -1, // System.Net.SocketType.Unknown SocketType_SOCK_STREAM = 1, // System.Net.SocketType.Stream SocketType_SOCK_DGRAM = 2, // System.Net.SocketType.Dgram SocketType_SOCK_RAW = 3, // System.Net.SocketType.Raw @@ -88,6 +90,7 @@ typedef enum */ typedef enum { + ProtocolType_PT_UNKNOWN = -1, // System.Net.ProtocolType.Unknown ProtocolType_PT_UNSPECIFIED = 0, // System.Net.ProtocolType.Unspecified ProtocolType_PT_ICMP = 1, // System.Net.ProtocolType.Icmp ProtocolType_PT_TCP = 6, // System.Net.ProtocolType.Tcp @@ -402,6 +405,8 @@ PALEXPORT int32_t SystemNative_SetRawSockOpt( PALEXPORT int32_t SystemNative_Socket(int32_t addressFamily, int32_t socketType, int32_t protocolType, intptr_t* createdSocket); +PALEXPORT int32_t SystemNative_GetSocketType(intptr_t socket, int32_t* addressFamily, int32_t* socketType, int32_t* protocolType); + PALEXPORT int32_t SystemNative_GetAtOutOfBandMark(intptr_t socket, int32_t* available); PALEXPORT int32_t SystemNative_GetBytesAvailable(intptr_t socket, int32_t* available); diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index e72a4e22755b5a..629e16cabb2cb3 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -222,6 +222,7 @@ public SendPacketsElement(string filepath, long offset, int count, bool endOfPac } public partial class Socket : System.IDisposable { + public Socket(System.Net.Sockets.SafeSocketHandle handle) { } public Socket(System.Net.Sockets.AddressFamily addressFamily, System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { } public Socket(System.Net.Sockets.SocketInformation socketInformation) { } public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { } diff --git a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx index 3b274c84c10541..3078fabf5d5fd6 100644 --- a/src/libraries/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Sockets/src/Resources/Strings.resx @@ -57,6 +57,9 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + Invalid handle. + This protocol version is not supported. diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj index 89b10587443878..3d74a783912353 100644 --- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -234,6 +234,9 @@ Common\Interop\Windows\WinSock\WSABuffer.cs + + Common\Interop\Windows\WinSock\Interop.WSAPROTOCOL_INFOW.cs + Common\Interop\Windows\Interop.CancelIoEx.cs @@ -315,6 +318,9 @@ Common\Interop\Unix\System.Native\Interop.GetSocketErrorOption.cs + + Common\Interop\Unix\System.Native\Interop.GetSocketType.cs + Common\Interop\Unix\System.Native\Interop.GetSockName.cs @@ -339,6 +345,9 @@ Common\Interop\Unix\System.Native\Interop.SetReceiveTimeout.cs + + Common\Interop\Unix\Interop.Stat.cs + Common\Interop\Unix\System.Native\Interop.Listen.cs diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.cs index e39df10ba1b6b3..dff1d2739b2c8c 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SafeSocketHandle.cs @@ -35,14 +35,17 @@ public sealed partial class SafeSocketHandle : SafeHandleMinusOneIsInvalid public SafeSocketHandle(IntPtr preexistingHandle, bool ownsHandle) : base(ownsHandle) { + OwnsHandle = ownsHandle; SetHandleAndValid(preexistingHandle); } - private SafeSocketHandle() : base(true) { } + private SafeSocketHandle() : base(ownsHandle: true) => OwnsHandle = true; + + internal bool OwnsHandle { get; } private bool TryOwnClose() { - return Interlocked.CompareExchange(ref _ownClose, 1, 0) == 0; + return OwnsHandle && Interlocked.CompareExchange(ref _ownClose, 1, 0) == 0; } private volatile bool _released; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Unix.cs index 804b9d35343e33..d5092637f52447 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Unix.cs @@ -47,6 +47,36 @@ partial void ValidateForMultiConnect(bool isMultiEndpoint) Debug.Assert(!_handle.LastConnectFailed); } + private static unsafe void LoadSocketTypeFromHandle( + SafeSocketHandle handle, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType, out bool blocking) + { + // Validate that the supplied handle is indeed a socket. + if (Interop.Sys.FStat(handle, out Interop.Sys.FileStatus stat) == -1 || + (stat.Mode & Interop.Sys.FileTypes.S_IFSOCK) != Interop.Sys.FileTypes.S_IFSOCK) + { + throw new SocketException((int)SocketError.NotSocket); + } + + // On Linux, GetSocketType will be able to query SO_DOMAIN, SO_TYPE, and SO_PROTOCOL to get the + // address family, socket type, and protocol type, respectively. On macOS, this will only succeed + // in getting the socket type, and the others will be unknown. Subsequently the Socket ctor + // can use getsockname to retrieve the address family as part of trying to get the local end point. + Interop.Error e = Interop.Sys.GetSocketType(handle, out addressFamily, out socketType, out protocolType); + Debug.Assert(e == Interop.Error.SUCCESS, e.ToString()); + + // Get whether the socket is in non-blocking mode. On Unix, we automatically put the underlying + // Socket into non-blocking mode whenever an async method is first invoked on the instance, but we + // maintain a shadow bool that maintains the Socket.Blocking value set by the developer. Because + // we're querying the underlying socket here, and don't have access to the original Socket instance + // (if there even was one... the Socket(SafeSocketHandle) ctor is likely being used because there + // wasn't one, Socket.Blocking will end up reflecting the actual state of the socket even if the + // developer didn't set Blocking = false. + bool nonBlocking; + int rv = Interop.Sys.Fcntl.GetIsNonBlocking(handle, out nonBlocking); + blocking = !nonBlocking; + Debug.Assert(rv == 0 || blocking, e.ToString()); // ignore failures + } + internal void ReplaceHandleIfNecessaryAfterFailedConnect() { if (!_handle.LastConnectFailed) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index d6060321e5d273..2034cab6efb063 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.IO; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Threading; namespace System.Net.Sockets @@ -56,7 +57,14 @@ public Socket(SocketInformation socketInformation) IPEndPoint ep = new IPEndPoint(tempAddress, 0); Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep); - errorCode = SocketPal.GetSockName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize); + unsafe + { + fixed (byte* bufferPtr = socketAddress.Buffer) + fixed (int* sizePtr = &socketAddress.InternalSize) + { + errorCode = SocketPal.GetSockName(_handle, bufferPtr, sizePtr); + } + } if (errorCode == SocketError.Success) { @@ -76,6 +84,28 @@ public Socket(SocketInformation socketInformation) if (NetEventSource.IsEnabled) NetEventSource.Exit(this); } + private unsafe void LoadSocketTypeFromHandle( + SafeSocketHandle handle, out AddressFamily addressFamily, out SocketType socketType, out ProtocolType protocolType, out bool blocking) + { + Interop.Winsock.WSAPROTOCOL_INFOW info = default; + int optionLength = sizeof(Interop.Winsock.WSAPROTOCOL_INFOW); + + // Get the address family, socket type, and protocol type from the socket. + if (Interop.Winsock.getsockopt(handle, SocketOptionLevel.Socket, (SocketOptionName)Interop.Winsock.SO_PROTOCOL_INFOW, (byte*)&info, ref optionLength) == SocketError.SocketError) + { + throw new SocketException((int)SocketPal.GetLastSocketError()); + } + + addressFamily = info.iAddressFamily; + socketType = info.iSocketType; + protocolType = info.iProtocol; + + // There's no API to retrieve this (WSAIsBlocking isn't supported any more). Assume it's blocking, but we might be wrong. + // This affects the result of querying Socket.Blocking, which will mostly only affect user code that happens to query + // that property, though there are a few places we check it internally, e.g. as part of NetworkStream argument validation. + blocking = true; + } + public SocketInformation DuplicateAndClose(int targetProcessId) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, targetProcessId); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 5dc340ca928110..925e4e89c0f06f 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -107,23 +107,146 @@ public Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType p if (NetEventSource.IsEnabled) NetEventSource.Exit(this); } - // Called by the class to create a socket to accept an incoming request. - private Socket(SafeSocketHandle fd) + /// Initializes a new instance of the class for the specified socket handle. + /// The socket handle for the socket that the object will encapsulate. + /// is null. + /// is invalid. + /// is not a socket or information about the socket could not be accessed. + /// + /// This method populates the instance with data gathered from the supplied . + /// Different operating systems provide varying levels of support for querying a socket handle or file descriptor for its + /// properties and configuration, which means some of the public APIs on the resulting instance may + /// differ based on operating system, such as and . + /// + public Socket(SafeSocketHandle handle) : + this(ValidateHandle(handle), loadPropertiesFromHandle: true) { - // NOTE: If this ctor is ever made public/protected, this check will need - // to be converted into a runtime exception. - Debug.Assert(fd != null && !fd.IsInvalid); + } - if (NetEventSource.IsEnabled) NetEventSource.Enter(this); + private unsafe Socket(SafeSocketHandle handle, bool loadPropertiesFromHandle) + { InitializeSockets(); - _handle = fd; + _handle = handle; + _addressFamily = AddressFamily.Unknown; + _socketType = SocketType.Unknown; + _protocolType = ProtocolType.Unknown; - _addressFamily = Sockets.AddressFamily.Unknown; - _socketType = Sockets.SocketType.Unknown; - _protocolType = Sockets.ProtocolType.Unknown; - if (NetEventSource.IsEnabled) NetEventSource.Exit(this); + if (!loadPropertiesFromHandle) + { + return; + } + + try + { + // Get properties like address family and blocking mode from the OS. + LoadSocketTypeFromHandle(handle, out _addressFamily, out _socketType, out _protocolType, out _willBlockInternal); + + // Determine whether the socket is in listening mode. + _isListening = + SocketPal.GetSockOpt(_handle, SocketOptionLevel.Socket, SocketOptionName.AcceptConnection, out int isListening) == SocketError.Success && + isListening != 0; + + // Try to get the address of the socket. + Span buffer = stackalloc byte[512]; // arbitrary high limit that should suffice for almost all scenarios + int bufferLength = buffer.Length; + fixed (byte* bufferPtr = buffer) + { + if (SocketPal.GetSockName(handle, bufferPtr, &bufferLength) != SocketError.Success) + { + return; + } + } + + if (bufferLength > buffer.Length) + { + buffer = new byte[buffer.Length]; + fixed (byte* bufferPtr = buffer) + { + if (SocketPal.GetSockName(handle, bufferPtr, &bufferLength) != SocketError.Success || + bufferLength > buffer.Length) + { + return; + } + } + } + + buffer = buffer.Slice(0, bufferLength); + if (_addressFamily == AddressFamily.Unknown) + { + _addressFamily = SocketAddressPal.GetAddressFamily(buffer); + } +#if DEBUG + else + { + Debug.Assert(_addressFamily == SocketAddressPal.GetAddressFamily(buffer)); + } +#endif + + // Try to get the local end point. That will in turn enable the remote + // end point to be retrieved on-demand when the property is accessed. + Internals.SocketAddress? socketAddress = null; + switch (_addressFamily) + { + case AddressFamily.InterNetwork: + _rightEndPoint = new IPEndPoint( + new IPAddress((long)SocketAddressPal.GetIPv4Address(buffer) & 0x0FFFFFFFF), + SocketAddressPal.GetPort(buffer)); + break; + + case AddressFamily.InterNetworkV6: + Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + SocketAddressPal.GetIPv6Address(buffer, address, out uint scope); + _rightEndPoint = new IPEndPoint( + new IPAddress(address, scope), + SocketAddressPal.GetPort(buffer)); + break; + + case AddressFamily.Unix: + socketAddress = new Internals.SocketAddress(_addressFamily, buffer); + _rightEndPoint = new UnixDomainSocketEndPoint(IPEndPointExtensions.GetNetSocketAddress(socketAddress)); + break; + } + + // Try to determine if we're connected, based on querying for a peer, just as we would in RemoteEndPoint, + // but ignoring any failures; this is best-effort (RemoteEndPoint also does a catch-all around the Create call). + if (_rightEndPoint != null) + { + try + { + socketAddress ??= new Internals.SocketAddress(_addressFamily, buffer); + if (SocketPal.GetPeerName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize) != SocketError.Success) + { + return; + } + + if (socketAddress.InternalSize > socketAddress.Buffer.Length) + { + socketAddress.Buffer = new byte[socketAddress.InternalSize]; + if (SocketPal.GetPeerName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize) != SocketError.Success) + { + return; + } + } + + _remoteEndPoint = _rightEndPoint.Create(socketAddress); + _isConnected = true; + } + catch { } + } + } + catch + { + _handle = null!; + GC.SuppressFinalize(this); + throw; + } } + + private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) => + handle is null ? throw new ArgumentNullException(nameof(handle)) : + handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) : + handle; #endregion #region Properties @@ -209,15 +332,18 @@ public EndPoint? LocalEndPoint Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(_rightEndPoint); - // This may throw ObjectDisposedException. - SocketError errorCode = SocketPal.GetSockName( - _handle, - socketAddress.Buffer, - ref socketAddress.InternalSize); - - if (errorCode != SocketError.Success) + unsafe { - UpdateStatusAfterSocketErrorAndThrowException(errorCode); + fixed (byte* buffer = socketAddress.Buffer) + fixed (int* bufferSize = &socketAddress.InternalSize) + { + // This may throw ObjectDisposedException. + SocketError errorCode = SocketPal.GetSockName(_handle, buffer, bufferSize); + if (errorCode != SocketError.Success) + { + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } + } } return _rightEndPoint.Create(socketAddress); @@ -4361,6 +4487,13 @@ protected virtual void Dispose(bool disposing) SetToDisconnected(); + // If the safe handle doesn't own the underlying handle, we're done. + SafeSocketHandle handle = _handle; + if (handle != null && !handle.OwnsHandle) + { + return; + } + // Close the handle in one of several ways depending on the timeout. // Ignore ObjectDisposedException just in case the handle somehow gets disposed elsewhere. try @@ -4961,7 +5094,7 @@ internal Socket CreateAcceptSocket(SafeSocketHandle fd, EndPoint remoteEP) { // Internal state of the socket is inherited from listener. Debug.Assert(fd != null && !fd.IsInvalid); - Socket socket = new Socket(fd); + Socket socket = new Socket(fd, loadPropertiesFromHandle: false); return UpdateAcceptSocket(socket, remoteEP); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index b012817e5cbd2e..3c43b02163d1eb 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -837,16 +837,9 @@ public static SocketError SetBlocking(SafeSocketHandle handle, bool shouldBlock, return SocketError.Success; } - public static unsafe SocketError GetSockName(SafeSocketHandle handle, byte[] buffer, ref int nameLen) + public static unsafe SocketError GetSockName(SafeSocketHandle handle, byte* buffer, int* nameLen) { - Interop.Error err; - int addrLen = nameLen; - fixed (byte* rawBuffer = buffer) - { - err = Interop.Sys.GetSockName(handle, rawBuffer, &addrLen); - } - - nameLen = addrLen; + Interop.Error err = Interop.Sys.GetSockName(handle, buffer, nameLen); return err == Interop.Error.SUCCESS ? SocketError.Success : GetSocketErrorForErrorCode(err); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index db5f880bbe4194..c3e113b0841249 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -138,9 +138,9 @@ public static SocketError SetBlocking(SafeSocketHandle handle, bool shouldBlock, return errorCode; } - public static SocketError GetSockName(SafeSocketHandle handle, byte[] buffer, ref int nameLen) + public static unsafe SocketError GetSockName(SafeSocketHandle handle, byte* buffer, int* nameLen) { - SocketError errorCode = Interop.Winsock.getsockname(handle, buffer, ref nameLen); + SocketError errorCode = Interop.Winsock.getsockname(handle, buffer, nameLen); return errorCode == SocketError.SocketError ? GetLastSocketError() : SocketError.Success; } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/CreateSocketTests.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/CreateSocketTests.cs index 9fc617fd20bc9e..86e1b2dea2fc6d 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/CreateSocketTests.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/CreateSocketTests.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.IO; using System.IO.Pipes; +using System.Runtime.InteropServices; using System.Threading.Tasks; using Microsoft.DotNet.RemoteExecutor; using Xunit; @@ -233,5 +235,236 @@ public void Ctor_Netcoreapp_Success(AddressFamily addressFamily) } s.Close(); } + + [Fact] + public void Ctor_SafeHandle_Invalid_ThrowsException() + { + AssertExtensions.Throws("handle", () => new Socket(null)); + AssertExtensions.Throws("handle", () => new Socket(new SafeSocketHandle((IntPtr)(-1), false))); + + using (var pipe = new AnonymousPipeServerStream()) + { + SocketException se = Assert.Throws(() => new Socket(new SafeSocketHandle(pipe.ClientSafePipeHandle.DangerousGetHandle(), false))); + Assert.Equal(SocketError.NotSocket, se.SocketErrorCode); + } + } + + [Theory] + [InlineData(AddressFamily.ControllerAreaNetwork, SocketType.Raw, ProtocolType.Unspecified)] + [InlineData(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)] + [InlineData(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)] + [InlineData(AddressFamily.InterNetwork, SocketType.Raw, ProtocolType.Unspecified)] + [InlineData(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp)] + [InlineData(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)] + [InlineData(AddressFamily.InterNetworkV6, SocketType.Raw, ProtocolType.Unspecified)] + [InlineData(AddressFamily.Packet, SocketType.Raw, ProtocolType.Raw)] + [InlineData(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified)] + public void Ctor_SafeHandle_BasicPropertiesPropagate_Success(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + Socket tmpOrig; + try + { + tmpOrig = new Socket(addressFamily, socketType, protocolType); + } + catch (SocketException e) when ( + e.SocketErrorCode == SocketError.AccessDenied || + e.SocketErrorCode == SocketError.ProtocolNotSupported || + e.SocketErrorCode == SocketError.AddressFamilyNotSupported) + { + // We can't test this combination on this platform. + return; + } + + using Socket orig = tmpOrig; + using var copy = new Socket(orig.SafeHandle); + + Assert.False(orig.Connected); + Assert.False(copy.Connected); + + Assert.Null(orig.LocalEndPoint); + Assert.Null(orig.RemoteEndPoint); + Assert.False(orig.IsBound); + if (copy.IsBound) + { + // On Unix, we may successfully obtain an (empty) local end point, even though Bind wasn't called. + Debug.Assert(!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) // OSX gets some strange results in some cases, e.g. "@\0\0\0\0\0\0\0\0\0\0\0\0\0" for a UDS + { + switch (addressFamily) + { + case AddressFamily.InterNetwork: + Assert.Equal(new IPEndPoint(IPAddress.Any, 0), copy.LocalEndPoint); + break; + + case AddressFamily.InterNetworkV6: + Assert.Equal(new IPEndPoint(IPAddress.IPv6Any, 0), copy.LocalEndPoint); + break; + + case AddressFamily.Unix: + Assert.IsType(copy.LocalEndPoint); + Assert.Equal("", copy.LocalEndPoint.ToString()); + break; + + default: + Assert.Null(copy.LocalEndPoint); + break; + } + } + } + else + { + Assert.Equal(orig.LocalEndPoint, copy.LocalEndPoint); + Assert.Equal(orig.LocalEndPoint, copy.RemoteEndPoint); + } + + Assert.Equal(addressFamily, orig.AddressFamily); + Assert.Equal(socketType, orig.SocketType); + Assert.Equal(protocolType, orig.ProtocolType); + + Assert.Equal(addressFamily, copy.AddressFamily); + Assert.Equal(socketType, copy.SocketType); + Assert.True(copy.ProtocolType == orig.ProtocolType || copy.ProtocolType == ProtocolType.Unknown, $"Expected: {protocolType} or Unknown, Actual: {copy.ProtocolType}"); + + Assert.True(orig.Blocking); + Assert.True(copy.Blocking); + + if (orig.AddressFamily == copy.AddressFamily) + { + AssertEqualOrSameException(() => orig.DontFragment, () => copy.DontFragment); + AssertEqualOrSameException(() => orig.MulticastLoopback, () => copy.MulticastLoopback); + AssertEqualOrSameException(() => orig.Ttl, () => copy.Ttl); + } + + AssertEqualOrSameException(() => orig.EnableBroadcast, () => copy.EnableBroadcast); + AssertEqualOrSameException(() => orig.LingerState.Enabled, () => copy.LingerState.Enabled); + AssertEqualOrSameException(() => orig.LingerState.LingerTime, () => copy.LingerState.LingerTime); + AssertEqualOrSameException(() => orig.NoDelay, () => copy.NoDelay); + + Assert.Equal(orig.Available, copy.Available); + Assert.Equal(orig.ExclusiveAddressUse, copy.ExclusiveAddressUse); + Assert.Equal(orig.Handle, copy.Handle); + Assert.Equal(orig.ReceiveBufferSize, copy.ReceiveBufferSize); + Assert.Equal(orig.ReceiveTimeout, copy.ReceiveTimeout); + Assert.Equal(orig.SendBufferSize, copy.SendBufferSize); + Assert.Equal(orig.SendTimeout, copy.SendTimeout); + Assert.Equal(orig.UseOnlyOverlappedIO, copy.UseOnlyOverlappedIO); + } + + [Theory] + [InlineData(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)] + [InlineData(AddressFamily.InterNetworkV6, SocketType.Stream, ProtocolType.Tcp)] + public async Task Ctor_SafeHandle_Tcp_SendReceive_Success(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + using var orig = new Socket(addressFamily, socketType, protocolType); + using var listener = new Socket(addressFamily, socketType, protocolType); + listener.Bind(new IPEndPoint(addressFamily == AddressFamily.InterNetwork ? IPAddress.Loopback : IPAddress.IPv6Loopback, 0)); + listener.Listen(1); + await orig.ConnectAsync(listener.LocalEndPoint); + using var server = await listener.AcceptAsync(); + + using var client = new Socket(orig.SafeHandle); + + Assert.True(client.Connected); + Assert.Equal(orig.AddressFamily, client.AddressFamily); + Assert.Equal(orig.SocketType, client.SocketType); + Assert.True(client.ProtocolType == orig.ProtocolType || client.ProtocolType == ProtocolType.Unknown, $"Expected: {protocolType} or Unknown, Actual: {client.ProtocolType}"); + + // Validate accessing end points + Assert.Equal(orig.LocalEndPoint, client.LocalEndPoint); + Assert.Equal(orig.RemoteEndPoint, client.RemoteEndPoint); + + // Validating accessing other properties + Assert.Equal(orig.Available, client.Available); + Assert.True(orig.Blocking); + Assert.True(client.Blocking); + AssertEqualOrSameException(() => orig.DontFragment, () => client.DontFragment); + AssertEqualOrSameException(() => orig.EnableBroadcast, () => client.EnableBroadcast); + Assert.Equal(orig.ExclusiveAddressUse, client.ExclusiveAddressUse); + Assert.Equal(orig.Handle, client.Handle); + Assert.Equal(orig.IsBound, client.IsBound); + Assert.Equal(orig.LingerState.Enabled, client.LingerState.Enabled); + Assert.Equal(orig.LingerState.LingerTime, client.LingerState.LingerTime); + AssertEqualOrSameException(() => orig.MulticastLoopback, () => client.MulticastLoopback); + Assert.Equal(orig.NoDelay, client.NoDelay); + Assert.Equal(orig.ReceiveBufferSize, client.ReceiveBufferSize); + Assert.Equal(orig.ReceiveTimeout, client.ReceiveTimeout); + Assert.Equal(orig.SendBufferSize, client.SendBufferSize); + Assert.Equal(orig.SendTimeout, client.SendTimeout); + Assert.Equal(orig.Ttl, client.Ttl); + Assert.Equal(orig.UseOnlyOverlappedIO, client.UseOnlyOverlappedIO); + + // Validate setting various properties on the new instance and seeing them roundtrip back to the original. + client.ReceiveTimeout = 42; + Assert.Equal(client.ReceiveTimeout, orig.ReceiveTimeout); + + // Validate sending and receiving + Assert.Equal(1, await client.SendAsync(new byte[1] { 42 }, SocketFlags.None)); + var buffer = new byte[1]; + Assert.Equal(1, await server.ReceiveAsync(buffer, SocketFlags.None)); + Assert.Equal(42, buffer[0]); + + Assert.Equal(1, await server.SendAsync(new byte[1] { 42 }, SocketFlags.None)); + buffer[0] = 0; + Assert.Equal(1, await client.ReceiveAsync(buffer, SocketFlags.None)); + Assert.Equal(42, buffer[0]); + } + + [PlatformSpecific(TestPlatforms.Windows | TestPlatforms.Linux)] // OSX/FreeBSD doesn't support SO_ACCEPTCONN, so we can't query for whether a socket is listening + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Ctor_SafeHandle_Listening_Success(bool shareSafeHandle) + { + using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(); + Assert.Equal(1, listener.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.AcceptConnection)); + + using var listenerCopy = new Socket(shareSafeHandle ? listener.SafeHandle : new SafeSocketHandle(listener.Handle, ownsHandle: false)); + Assert.Equal(1, listenerCopy.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.AcceptConnection)); + + Assert.Equal(listener.AddressFamily, listenerCopy.AddressFamily); + Assert.Equal(listener.Handle, listenerCopy.Handle); + Assert.Equal(listener.IsBound, listenerCopy.IsBound); + Assert.Equal(listener.LocalEndPoint, listener.LocalEndPoint); + Assert.True(listenerCopy.ProtocolType == listener.ProtocolType || listenerCopy.ProtocolType == ProtocolType.Unknown, $"Expected: {listener.ProtocolType} or Unknown, Actual: {listenerCopy.ProtocolType}"); + Assert.Equal(listener.SocketType, listenerCopy.SocketType); + + foreach (Socket listenerSocket in new[] { listener, listenerCopy }) + { + using (var client1 = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + Task connect1 = client1.ConnectAsync(listenerSocket.LocalEndPoint); + using (Socket server1 = listenerSocket.Accept()) + { + await connect1; + server1.Send(new byte[] { 42 }); + Assert.Equal(1, client1.Receive(new byte[1])); + } + } + } + } + + private static void AssertEqualOrSameException(Func expected, Func actual) + { + T r1 = default, r2 = default; + Exception e1 = null, e2 = null; + + try { r1 = expected(); } + catch (Exception e) { e1 = e; }; + + try { r2 = actual(); } + catch (Exception e) { e2 = e; }; + + Assert.Equal(e1 is null, e2 is null); + if (e1 is null) + { + Assert.Equal(r1, r2); + } + else + { + Assert.Equal(e1.GetType(), e2.GetType()); + } + } } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive.cs index 3725c4ce24692b..68fcd850ce537e 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive.cs @@ -44,11 +44,16 @@ public async Task InvalidArguments_Throws(int? length, int offset, int count) } } + public static IEnumerable LoopbackWithBool => + from addr in Loopbacks + from b in new[] { false, true } + select new object[] { addr[0], b }; + [ActiveIssue("https://github.com/dotnet/runtime/issues/1712")] [OuterLoop] [Theory] - [MemberData(nameof(Loopbacks))] - public async Task SendToRecvFrom_Datagram_UDP(IPAddress loopbackAddress) + [MemberData(nameof(LoopbackWithBool))] + public async Task SendToRecvFrom_Datagram_UDP(IPAddress loopbackAddress, bool useClone) { IPAddress leftAddress = loopbackAddress, rightAddress = loopbackAddress; @@ -57,11 +62,13 @@ public async Task SendToRecvFrom_Datagram_UDP(IPAddress loopbackAddress) const int AckTimeout = 10000; const int TestTimeout = 30000; - var left = new Socket(leftAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); - left.BindToAnonymousPort(leftAddress); + using var origLeft = new Socket(leftAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using var origRight = new Socket(rightAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + origLeft.BindToAnonymousPort(leftAddress); + origRight.BindToAnonymousPort(rightAddress); - var right = new Socket(rightAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp); - right.BindToAnonymousPort(rightAddress); + using var left = useClone ? new Socket(origLeft.SafeHandle) : origLeft; + using var right = useClone ? new Socket(origRight.SafeHandle) : origRight; var leftEndpoint = (IPEndPoint)left.LocalEndPoint; var rightEndpoint = (IPEndPoint)right.LocalEndPoint; @@ -74,25 +81,22 @@ public async Task SendToRecvFrom_Datagram_UDP(IPAddress loopbackAddress) var receivedChecksums = new uint?[DatagramsToSend]; Task leftThread = Task.Run(async () => { - using (left) + EndPoint remote = leftEndpoint.Create(leftEndpoint.Serialize()); + var recvBuffer = new byte[DatagramSize]; + for (int i = 0; i < DatagramsToSend; i++) { - EndPoint remote = leftEndpoint.Create(leftEndpoint.Serialize()); - var recvBuffer = new byte[DatagramSize]; - for (int i = 0; i < DatagramsToSend; i++) - { - SocketReceiveFromResult result = await ReceiveFromAsync( - left, new ArraySegment(recvBuffer), remote); - Assert.Equal(DatagramSize, result.ReceivedBytes); - Assert.Equal(rightEndpoint, result.RemoteEndPoint); - - int datagramId = recvBuffer[0]; - Assert.Null(receivedChecksums[datagramId]); - receivedChecksums[datagramId] = Fletcher32.Checksum(recvBuffer, 0, result.ReceivedBytes); - - receiverAck.Release(); - bool gotAck = await senderAck.WaitAsync(TestTimeout); - Assert.True(gotAck, $"{DateTime.Now}: Timeout waiting {TestTimeout} for senderAck in iteration {i}"); - } + SocketReceiveFromResult result = await ReceiveFromAsync( + left, new ArraySegment(recvBuffer), remote); + Assert.Equal(DatagramSize, result.ReceivedBytes); + Assert.Equal(rightEndpoint, result.RemoteEndPoint); + + int datagramId = recvBuffer[0]; + Assert.Null(receivedChecksums[datagramId]); + receivedChecksums[datagramId] = Fletcher32.Checksum(recvBuffer, 0, result.ReceivedBytes); + + receiverAck.Release(); + bool gotAck = await senderAck.WaitAsync(TestTimeout); + Assert.True(gotAck, $"{DateTime.Now}: Timeout waiting {TestTimeout} for senderAck in iteration {i}"); } }); diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs index ddead9a754b4a6..3e5e7d9d64e417 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs @@ -147,6 +147,51 @@ public void Socket_SendReceive_Success() } } + [ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))] + public void Socket_SendReceive_Clone_Success() + { + string path = GetRandomNonExistingFilePath(); + var endPoint = new UnixDomainSocketEndPoint(path); + try + { + using var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + using var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + { + server.Bind(endPoint); + server.Listen(1); + client.Connect(endPoint); + + using (Socket accepted = server.Accept()) + { + using var clientClone = new Socket(client.SafeHandle); + using var acceptedClone = new Socket(accepted.SafeHandle); + + Assert.Equal(client.LocalEndPoint.ToString(), clientClone.LocalEndPoint.ToString()); + Assert.Equal(client.RemoteEndPoint.ToString(), clientClone.RemoteEndPoint.ToString()); + Assert.Equal(accepted.LocalEndPoint.ToString(), acceptedClone.LocalEndPoint.ToString()); + Assert.Equal(accepted.RemoteEndPoint.ToString(), acceptedClone.RemoteEndPoint.ToString()); + + var data = new byte[1]; + for (int i = 0; i < 10; i++) + { + data[0] = (byte)i; + + acceptedClone.Send(data); + data[0] = 0; + + Assert.Equal(1, clientClone.Receive(data)); + Assert.Equal(i, data[0]); + } + } + } + } + finally + { + try { File.Delete(path); } + catch { } + } + } + [ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))] public async Task Socket_SendReceiveAsync_Success() {