Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ public void Connect(string host, int port) { }
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public void Disconnect(bool reuseSocket) { }
public bool DisconnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.Threading.Tasks.ValueTask DisconnectAsync(bool reuseSocket, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
<Compile Include="System\Net\Sockets\UdpReceiveResult.cs" />
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\UnixDomainSocketEndPoint.cs" />
<!-- Logging -->
<Compile Include="$(CommonPath)System\Net\Logging\NetEventSource.Common.cs"
Expand Down Expand Up @@ -187,7 +186,6 @@
<ItemGroup Condition="'$(TargetsUnix)' == 'true'">
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\SafeSocketHandle.Unix.cs" />
<Compile Include="System\Net\Sockets\Socket.Unix.cs" />
<Compile Include="System\Net\Sockets\SocketAsyncContext.Unix.cs" />
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,29 @@ public ValueTask ConnectAsync(string host, int port, CancellationToken cancellat
return ConnectAsync(ep, cancellationToken);
}

/// <summary>
/// Disconnects a connected socket from the remote host.
/// </summary>
/// <param name="reuseSocket">Indicates whether the socket should be available for reuse after disconnect.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes when the socket is disconnected.</returns>
public ValueTask DisconnectAsync(bool reuseSocket, CancellationToken cancellationToken = default)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

saea.DisconnectReuseSocket = reuseSocket;
saea.WrapExceptionsForNetworkStream = false;

return saea.DisconnectAsync(this, cancellationToken);
}

/// <summary>
/// Receives data from a connected socket.
/// </summary>
Expand Down Expand Up @@ -1028,6 +1051,25 @@ public ValueTask ConnectAsync(Socket socket)
ValueTask.FromException(CreateException(error));
}

public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.DisconnectAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask(this, _token);
}

SocketError error = SocketError;

Release();

return error == SocketError.Success ?
ValueTask.CompletedTask :
ValueTask.FromException(CreateException(error));
}

/// <summary>Gets the status of the operation.</summary>
public ValueTaskSourceStatus GetStatus(short token)
{
Expand Down
119 changes: 36 additions & 83 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ private sealed class CacheSet
private int _closeTimeout = Socket.DefaultCloseTimeout;
private int _disposed; // 0 == false, anything else == true

#region Constructors
public Socket(SocketType socketType, ProtocolType protocolType)
: this(OSSupportsIPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, socketType, protocolType)
{
Expand Down Expand Up @@ -242,9 +241,10 @@ private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) =>
handle is null ? throw new ArgumentNullException(nameof(handle)) :
handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) :
handle;
#endregion

#region Properties
//
// Properties
//

// The CLR allows configuration of these properties, separately from whether the OS supports IPv4/6. We
// do not provide these config options, so SupportsIPvX === OSSupportsIPvX.
Expand Down Expand Up @@ -761,9 +761,10 @@ internal bool CanTryAddressFamily(AddressFamily family)
{
return (family == _addressFamily) || (family == AddressFamily.InterNetwork && IsDualMode);
}
#endregion

#region Public Methods
//
// Public Methods
//

// Associates a socket with an end point.
public void Bind(EndPoint localEP)
Expand Down Expand Up @@ -2116,43 +2117,14 @@ public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? req
public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) =>
TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state);

public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state)
public void EndConnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();

// Start context-flowing op. No need to lock - we don't use the context till the callback.
DisconnectOverlappedAsyncResult asyncResult = new DisconnectOverlappedAsyncResult(this, state, callback);
asyncResult.StartPostingAsyncOp(false);

// Post the disconnect.
DoBeginDisconnect(reuseSocket, asyncResult);

// Finish flowing (or call the callback), and return.
asyncResult.FinishPostingAsyncOp();
return asyncResult;
TaskToApm.End(asyncResult);
}

private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
SocketError errorCode = SocketError.Success;

errorCode = SocketPal.DisconnectAsync(this, _handle, reuseSocket, asyncResult);

if (errorCode == SocketError.Success)
{
SetToDisconnected();
_remoteEndPoint = null;
_localEndPoint = null;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}");

// If the call failed, update our status and throw
if (!CheckErrorAndUpdateStatus(errorCode))
{
throw new SocketException((int)errorCode);
}
}
public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) =>
TaskToApmBeginWithSyncExceptions(DisconnectAsync(reuseSocket).AsTask(), callback, state);

public void Disconnect(bool reuseSocket)
{
Expand All @@ -2175,47 +2147,12 @@ public void Disconnect(bool reuseSocket)
_localEndPoint = null;
}

public void EndConnect(IAsyncResult asyncResult)
public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
TaskToApm.End(asyncResult);
}

public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();

if (asyncResult == null)
{
throw new ArgumentNullException(nameof(asyncResult));
}

//get async result and check for errors
LazyAsyncResult? castedAsyncResult = asyncResult as LazyAsyncResult;
if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this)
{
throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult));
}
if (castedAsyncResult.EndCalled)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, nameof(EndDisconnect)));
}

//wait for completion if it hasn't occurred
castedAsyncResult.InternalWaitForCompletion();
castedAsyncResult.EndCalled = true;

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this);

//
// if the asynchronous native call failed asynchronously
// we'll throw a SocketException
//
if ((SocketError)castedAsyncResult.ErrorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException((SocketError)castedAsyncResult.ErrorCode);
}
}

public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
{
Expand Down Expand Up @@ -2668,7 +2605,10 @@ public void Shutdown(SocketShutdown how)
InternalSetBlocking(_willBlockInternal);
}

#region Async methods
//
// Async methods
//

public bool AcceptAsync(SocketAsyncEventArgs e)
{
ThrowIfDisposed();
Expand Down Expand Up @@ -2889,7 +2829,9 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e)
e.CancelConnectAsync();
}

public bool DisconnectAsync(SocketAsyncEventArgs e)
public bool DisconnectAsync(SocketAsyncEventArgs e) => DisconnectAsync(e, default);

private bool DisconnectAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
// Throw if socket disposed
ThrowIfDisposed();
Expand All @@ -2904,7 +2846,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
SocketError socketError = SocketError.Success;
try
{
socketError = e.DoOperationDisconnect(this, _handle);
socketError = e.DoOperationDisconnect(this, _handle, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -3155,10 +3097,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT

return socketError == SocketError.IOPending;
}
#endregion
#endregion

#region Internal and private properties
//
// Internal and private properties
//

private CacheSet Caches
{
Expand All @@ -3174,9 +3116,10 @@ private CacheSet Caches
}

internal bool Disposed => _disposed != 0;
#endregion

#region Internal and private methods
//
// Internal and private methods
//

internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6)
{
Expand Down Expand Up @@ -3889,6 +3832,16 @@ private static SocketError GetSocketErrorFromFaultedTask(Task t)
};
}

#endregion
// Helper to maintain existing behavior of Socket APM methods to throw synchronously from Begin*.
private static IAsyncResult TaskToApmBeginWithSyncExceptions(Task task, AsyncCallback? callback, object? state)
{
if (task.IsFaulted)
{
task.GetAwaiter().GetResult();
Debug.Fail("Task faulted but GetResult did not throw???");
}

return TaskToApm.Begin(task, callback, state);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle h
return socketError;
}

internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket);
FinishOperationSync(socketError, 0, SocketFlags.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,11 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
}
}

internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
// Note: CancellationToken is ignored for now.
// See https://github.com/dotnet/runtime/issues/51452

NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
Expand Down Expand Up @@ -1188,6 +1191,7 @@ private unsafe SocketError FinishOperationConnect()
private void CompleteCore()
{
_strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially

if (_singleBufferHandleState != SingleBufferHandleState.None)
{
// If the state isn't None, then either it's Set, in which case there's state to cleanup,
Expand All @@ -1213,6 +1217,8 @@ void CompleteCoreSpin()
sw.SpinOnce();
}

Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);

// Remove any cancellation registration. First dispose the registration
// to ensure that cancellation will either never fine or will have completed
// firing before we continue. Only then can we safely null out the overlapped.
Expand All @@ -1223,6 +1229,8 @@ void CompleteCoreSpin()
}

// Release any GC handles.
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);

if (_singleBufferHandleState == SingleBufferHandleState.Set)
{
_singleBufferHandleState = SingleBufferHandleState.None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1976,13 +1976,6 @@ public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, Sa
return socketError;
}

internal static SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
SocketError socketError = Disconnect(socket, handle, reuseSocket);
asyncResult.PostCompletion(socketError);
return socketError;
}

internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
{
handle.SetToDisconnected();
Expand Down
Loading