Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal static class MsQuicAddressHelpers
internal const ushort IPv4 = 2;
internal const ushort IPv6 = 23;

internal static unsafe IPEndPoint INetToIPEndPoint(SOCKADDR_INET inetAddress)
internal static unsafe IPEndPoint INetToIPEndPoint(ref SOCKADDR_INET inetAddress)
{
if (inetAddress.si_family == IPv4)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private unsafe MsQuicApi()

try
{
uint status = Interop.MsQuic.MsQuicOpen(version: 1, out registration);
uint status = Interop.MsQuic.MsQuicOpen(out registration);
if (!MsQuicStatusHelper.SuccessfulStatusCode(status))
{
throw new NotSupportedException(SR.net_quic_notsupported);
Expand Down Expand Up @@ -123,7 +123,13 @@ private unsafe MsQuicApi()
Marshal.GetDelegateForFunctionPointer<MsQuicNativeMethods.GetParamDelegate>(
nativeRegistration.GetParam);

RegistrationOpenDelegate(Encoding.UTF8.GetBytes("SystemNetQuic"), out IntPtr ctx);
var registrationConfig = new MsQuicNativeMethods.RegistrationConfig
{
AppName = "SystemNetQuic",
ExecutionProfile = QUIC_EXECUTION_PROFILE.QUIC_EXECUTION_PROFILE_LOW_LATENCY
};

RegistrationOpenDelegate(ref registrationConfig, out IntPtr ctx);
_registrationContext = ctx;
}

Expand Down Expand Up @@ -312,15 +318,26 @@ void SecCfgCreateCallbackHandler(
return secConfig;
}

public IntPtr SessionOpen(byte[] alpn)
public unsafe IntPtr SessionOpen(byte[] alpn)
{
IntPtr sessionPtr = IntPtr.Zero;
uint status;

uint status = SessionOpenDelegate(
_registrationContext,
alpn,
IntPtr.Zero,
ref sessionPtr);
fixed (byte* pAlpn = alpn)
{
var alpnBuffer = new MsQuicNativeMethods.QuicBuffer
{
Length = (uint)alpn.Length,
Buffer = pAlpn
};

status = SessionOpenDelegate(
_registrationContext,
&alpnBuffer,
1,
IntPtr.Zero,
ref sessionPtr);
}

QuicExceptionHelpers.ThrowIfFailed(status, "Could not open session.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public IntPtr ConnectionOpen(QuicClientConnectionOptions options)

QuicExceptionHelpers.ThrowIfFailed(MsQuicApi.Api.ConnectionOpenDelegate(
_nativeObjPtr,
MsQuicConnection.NativeCallbackHandler,
MsQuicConnection.s_connectionDelegate,
IntPtr.Zero,
out IntPtr connectionPtr),
"Could not open the connection.");
Expand Down Expand Up @@ -83,15 +83,15 @@ public void Dispose()

public void SetPeerBiDirectionalStreamCount(ushort count)
{
SetUshortParamter(QUIC_PARAM_SESSION.PEER_BIDI_STREAM_COUNT, count);
SetUshortParameter(QUIC_PARAM_SESSION.PEER_BIDI_STREAM_COUNT, count);
}

public void SetPeerUnidirectionalStreamCount(ushort count)
{
SetUshortParamter(QUIC_PARAM_SESSION.PEER_UNIDI_STREAM_COUNT, count);
SetUshortParameter(QUIC_PARAM_SESSION.PEER_UNIDI_STREAM_COUNT, count);
}

private unsafe void SetUshortParamter(QUIC_PARAM_SESSION param, ushort count)
private unsafe void SetUshortParameter(QUIC_PARAM_SESSION param, ushort count)
{
var buffer = new MsQuicNativeMethods.QuicBuffer()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.

#nullable enable
using System.Diagnostics;
using System.IO;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
Expand All @@ -16,7 +18,7 @@ namespace System.Net.Quic.Implementations.MsQuic
{
internal sealed class MsQuicConnection : QuicConnectionProvider
{
private MsQuicSession? _session;
private readonly MsQuicSession? _session;

// Pointer to the underlying connection
// TODO replace all IntPtr with SafeHandles
Expand All @@ -26,15 +28,17 @@ internal sealed class MsQuicConnection : QuicConnectionProvider
private GCHandle _handle;

// Delegate that wraps the static function that will be called when receiving an event.
// TODO investigate if the delegate can be static instead.
private ConnectionCallbackDelegate? _connectionDelegate;
internal static readonly ConnectionCallbackDelegate s_connectionDelegate = new ConnectionCallbackDelegate(NativeCallbackHandler);

// Endpoint to either connect to or the endpoint already accepted.
private IPEndPoint? _localEndPoint;
private readonly IPEndPoint _remoteEndPoint;

private readonly ResettableCompletionSource<uint> _connectTcs = new ResettableCompletionSource<uint>();
private readonly ResettableCompletionSource<uint> _shutdownTcs = new ResettableCompletionSource<uint>();
private SslApplicationProtocol _negotiatedAlpnProtocol;

// TODO: only allocate these when there is an outstanding connect/shutdown.
private readonly TaskCompletionSource<uint> _connectTcs = new TaskCompletionSource<uint>();
private readonly TaskCompletionSource<uint> _shutdownTcs = new TaskCompletionSource<uint>();

private bool _disposed;
private bool _connected;
Expand All @@ -54,6 +58,7 @@ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, Int
_localEndPoint = localEndPoint;
_remoteEndPoint = remoteEndPoint;
_ptr = nativeObjPtr;
_connected = true;

SetCallbackHandler();
SetIdleTimeout(TimeSpan.FromSeconds(120));
Expand Down Expand Up @@ -89,125 +94,96 @@ internal async ValueTask SetSecurityConfigForConnection(X509Certificate cert, st

internal override IPEndPoint RemoteEndPoint => new IPEndPoint(_remoteEndPoint.Address, _remoteEndPoint.Port);

internal override SslApplicationProtocol NegotiatedApplicationProtocol => throw new NotImplementedException();
internal override SslApplicationProtocol NegotiatedApplicationProtocol => _negotiatedAlpnProtocol;

internal override bool Connected => _connected;

internal uint HandleEvent(ref ConnectionEvent connectionEvent)
{
uint status = MsQuicStatusCodes.Success;
try
{
switch (connectionEvent.Type)
{
// Connection is connected, can start to create streams.
case QUIC_CONNECTION_EVENT.CONNECTED:
{
status = HandleEventConnected(
connectionEvent);
}
break;

// Connection is being closed by the transport
return HandleEventConnected(ref connectionEvent);
case QUIC_CONNECTION_EVENT.SHUTDOWN_INITIATED_BY_TRANSPORT:
{
status = HandleEventShutdownInitiatedByTransport(
connectionEvent);
}
break;

// Connection is being closed by the peer
return HandleEventShutdownInitiatedByTransport(ref connectionEvent);
case QUIC_CONNECTION_EVENT.SHUTDOWN_INITIATED_BY_PEER:
{
status = HandleEventShutdownInitiatedByPeer(
connectionEvent);
}
break;

// Connection has been shutdown
return HandleEventShutdownInitiatedByPeer(ref connectionEvent);
case QUIC_CONNECTION_EVENT.SHUTDOWN_COMPLETE:
{
status = HandleEventShutdownComplete(
connectionEvent);
}
break;

return HandleEventShutdownComplete(ref connectionEvent);
case QUIC_CONNECTION_EVENT.PEER_STREAM_STARTED:
{
status = HandleEventNewStream(
connectionEvent);
}
break;

return HandleEventNewStream(ref connectionEvent);
case QUIC_CONNECTION_EVENT.STREAMS_AVAILABLE:
{
status = HandleEventStreamsAvailable(
connectionEvent);
}
break;

return HandleEventStreamsAvailable(ref connectionEvent);
default:
break;
return MsQuicStatusCodes.Success;
}
}
catch (Exception)
catch (Exception ex)
{
// TODO we may want to either add a debug assert here or return specific error codes
// based on the exception caught.
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(this, $"Exception occurred during connection callback: {ex.Message}");
}

// TODO: trigger an exception on any outstanding async calls.

return MsQuicStatusCodes.InternalError;
}

return status;
}

private uint HandleEventConnected(ConnectionEvent connectionEvent)
private uint HandleEventConnected(ref ConnectionEvent connectionEvent)
{
SOCKADDR_INET inetAddress = MsQuicParameterHelpers.GetINetParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_ADDRESS);
_localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(inetAddress);
if (!_connected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this ever happen when _connected is true? Should we assert and/or throw?

Copy link
Contributor Author

@scalablecory scalablecory Aug 6, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens on the listener side -- MsQuic fires the connected event immediately after it fires the new connection accepted event.

{
// _connected will already be true for connections accepted from a listener.

_connected = true;
// I don't believe we need to lock here because
// handle event connected will not be called at the same time as
// handle event shutdown initiated by transport
_connectTcs.Complete(MsQuicStatusCodes.Success);
SOCKADDR_INET inetAddress = MsQuicParameterHelpers.GetINetParam(MsQuicApi.Api, _ptr, (uint)QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_ADDRESS);
_localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref inetAddress);

SetNegotiatedAlpn(connectionEvent.Data.Connected.NegotiatedAlpn, connectionEvent.Data.Connected.NegotiatedAlpnLength);

_connected = true;
_connectTcs.SetResult(MsQuicStatusCodes.Success);
}

return MsQuicStatusCodes.Success;
}

private uint HandleEventShutdownInitiatedByTransport(ConnectionEvent connectionEvent)
private uint HandleEventShutdownInitiatedByTransport(ref ConnectionEvent connectionEvent)
{
if (!_connected)
{
_connectTcs.CompleteException(new IOException("Connection has been shutdown."));
_connectTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException("Connection has been shutdown.")));
}

_acceptQueue.Writer.Complete();


return MsQuicStatusCodes.Success;
}

private uint HandleEventShutdownInitiatedByPeer(ConnectionEvent connectionEvent)
private uint HandleEventShutdownInitiatedByPeer(ref ConnectionEvent connectionEvent)
{
_abortErrorCode = connectionEvent.Data.ShutdownBeginPeer.ErrorCode;
_abortErrorCode = connectionEvent.Data.ShutdownInitiatedByPeer.ErrorCode;
_acceptQueue.Writer.Complete();
return MsQuicStatusCodes.Success;
}

private uint HandleEventShutdownComplete(ConnectionEvent connectionEvent)
private uint HandleEventShutdownComplete(ref ConnectionEvent connectionEvent)
{
_shutdownTcs.Complete(MsQuicStatusCodes.Success);
_shutdownTcs.SetResult(MsQuicStatusCodes.Success);
return MsQuicStatusCodes.Success;
}

private uint HandleEventNewStream(ConnectionEvent connectionEvent)
private uint HandleEventNewStream(ref ConnectionEvent connectionEvent)
{
MsQuicStream msQuicStream = new MsQuicStream(this, connectionEvent.StreamFlags, connectionEvent.Data.NewStream.Stream, inbound: true);
MsQuicStream msQuicStream = new MsQuicStream(this, connectionEvent.StreamFlags, connectionEvent.Data.StreamStarted.Stream, inbound: true);
_acceptQueue.Writer.TryWrite(msQuicStream);
return MsQuicStatusCodes.Success;
}

private uint HandleEventStreamsAvailable(ConnectionEvent connectionEvent)
private uint HandleEventStreamsAvailable(ref ConnectionEvent connectionEvent)
{
return MsQuicStatusCodes.Success;
}
Expand Down Expand Up @@ -275,7 +251,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
(ushort)_remoteEndPoint.Port),
"Failed to connect to peer.");

return _connectTcs.GetTypelessValueTask();
return new ValueTask(_connectTcs.Task);
}

private MsQuicStream StreamOpen(
Expand All @@ -286,7 +262,7 @@ private MsQuicStream StreamOpen(
MsQuicApi.Api.StreamOpenDelegate(
_ptr,
(uint)flags,
MsQuicStream.NativeCallbackHandler,
MsQuicStream.s_streamDelegate,
IntPtr.Zero,
out streamPtr),
"Failed to open stream to peer.");
Expand All @@ -296,11 +272,12 @@ private MsQuicStream StreamOpen(

private void SetCallbackHandler()
{
Debug.Assert(!_handle.IsAllocated);
_handle = GCHandle.Alloc(this);
_connectionDelegate = new ConnectionCallbackDelegate(NativeCallbackHandler);

MsQuicApi.Api.SetCallbackHandlerDelegate(
_ptr,
_connectionDelegate,
s_connectionDelegate,
GCHandle.ToIntPtr(_handle));
}

Expand All @@ -314,10 +291,20 @@ private ValueTask ShutdownAsync(
ErrorCode);
QuicExceptionHelpers.ThrowIfFailed(status, "Failed to shutdown connection.");

return _shutdownTcs.GetTypelessValueTask();
return new ValueTask(_shutdownTcs.Task);
}

internal void SetNegotiatedAlpn(IntPtr alpn, int alpnLength)
{
if (alpn != IntPtr.Zero && alpnLength != 0)
{
var buffer = new byte[alpnLength];
Marshal.Copy(alpn, buffer, 0, alpnLength);
_negotiatedAlpnProtocol = new SslApplicationProtocol(buffer);
}
}

internal static uint NativeCallbackHandler(
private static uint NativeCallbackHandler(
IntPtr connection,
IntPtr context,
ref ConnectionEvent connectionEventStruct)
Expand Down
Loading