Skip to content
This repository was archived by the owner on Dec 18, 2018. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ public class FilteredStreamAdapter
private readonly Stream _filteredStream;
private readonly Stream _socketInputStream;
private readonly IKestrelTrace _log;
private readonly MemoryPool2 _memory;
private MemoryPoolBlock2 _block;

public FilteredStreamAdapter(
Stream filteredStream,
Expand All @@ -28,30 +30,35 @@ public FilteredStreamAdapter(
_log = logger;
_filteredStream = filteredStream;
_socketInputStream = new SocketInputStream(SocketInput);
_memory = memory;
}

public SocketInput SocketInput { get; private set; }

var block = memory.Lease();
public ISocketOutput SocketOutput { get; private set; }

public void ReadInput()
{
_block = _memory.Lease();
// Use pooled block for copy
_filteredStream.CopyToAsync(_socketInputStream, block).ContinueWith((task, state) =>
_filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) =>
{
var returnedBlock = task.Result;
returnedBlock.Pool.Return(returnedBlock);

((FilteredStreamAdapter)state).OnStreamClose(task);
}, this);
}

public SocketInput SocketInput { get; private set; }

public ISocketOutput SocketOutput { get; private set; }

private void OnStreamClose(Task copyAsyncTask)
{
_memory.Return(_block);

if (copyAsyncTask.IsFaulted)
{
SocketInput.AbortAwaiting();
_log.LogError(0, copyAsyncTask.Exception, "FilteredStreamAdapter.CopyToAsync");
}
else if (copyAsyncTask.IsCanceled)
{
SocketInput.AbortAwaiting();
_log.LogError("FilteredStreamAdapter.CopyToAsync canceled.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
{
public static class StreamExtensions
{
public static async Task<MemoryPoolBlock2> CopyToAsync(this Stream source, Stream destination, MemoryPoolBlock2 block)
public static async Task CopyToAsync(this Stream source, Stream destination, MemoryPoolBlock2 block)
{
int bytesRead;
while ((bytesRead = await source.ReadAsync(block.Array, block.Data.Offset, block.Data.Count)) != 0)
{
await destination.WriteAsync(block.Array, block.Data.Offset, bytesRead);
}

return block;
}
}
}
156 changes: 118 additions & 38 deletions src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Filter;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Networking;
Expand Down Expand Up @@ -31,13 +32,15 @@ public class Connection : ConnectionContext, IConnectionControl

private readonly object _stateLock = new object();
private ConnectionState _connectionState;
private TaskCompletionSource<object> _socketClosedTcs;

private IPEndPoint _remoteEndPoint;
private IPEndPoint _localEndPoint;

public Connection(ListenerContext context, UvStreamHandle socket) : base(context)
{
_socket = socket;
socket.Connection = this;
ConnectionControl = this;

_connectionId = Interlocked.Increment(ref _lastConnectionId);
Expand All @@ -46,6 +49,11 @@ public Connection(ListenerContext context, UvStreamHandle socket) : base(context
_rawSocketOutput = new SocketOutput(Thread, _socket, Memory2, this, _connectionId, Log, ThreadPool, WriteReqPool);
}

// Internal for testing
internal Connection()
{
}

public void Start()
{
Log.ConnectionStart(_connectionId);
Expand All @@ -61,13 +69,23 @@ public void Start()
}

// Don't initialize _frame until SocketInput and SocketOutput are set to their final values.
if (ConnectionFilter == null)
if (ServerInformation.ConnectionFilter == null)
{
SocketInput = _rawSocketInput;
SocketOutput = _rawSocketOutput;
lock (_stateLock)
{
if (_connectionState != ConnectionState.CreatingFrame)
{
throw new InvalidOperationException("Invalid connection state: " + _connectionState);
}

_connectionState = ConnectionState.Open;

_frame = CreateFrame();
_frame.Start();
SocketInput = _rawSocketInput;
SocketOutput = _rawSocketOutput;

_frame = CreateFrame();
_frame.Start();
}
}
else
{
Expand All @@ -81,7 +99,7 @@ public void Start()

try
{
ConnectionFilter.OnConnectionAsync(_filterContext).ContinueWith((task, state) =>
ServerInformation.ConnectionFilter.OnConnectionAsync(_filterContext).ContinueWith((task, state) =>
{
var connection = (Connection)state;

Expand Down Expand Up @@ -109,37 +127,105 @@ public void Start()
}
}

public virtual void Abort()
public Task StopAsync()
{
if (_frame != null)
lock (_stateLock)
{
// Frame.Abort calls user code while this method is always
// called from a libuv thread.
System.Threading.ThreadPool.QueueUserWorkItem(state =>
switch (_connectionState)
{
var connection = (Connection)state;
connection._frame.Abort();
}, this);
case ConnectionState.SocketClosed:
return TaskUtilities.CompletedTask;
case ConnectionState.CreatingFrame:
_connectionState = ConnectionState.ToDisconnect;
break;
case ConnectionState.Open:
_frame.Stop();
SocketInput.CompleteAwaiting();
break;
}

_socketClosedTcs = new TaskCompletionSource<object>();
return _socketClosedTcs.Task;
}
}

private void ApplyConnectionFilter()
public virtual void Abort()
{
if (_filterContext.Connection != _libuvStream)
lock (_stateLock)
{
var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory2, Log, ThreadPool);
if (_connectionState == ConnectionState.CreatingFrame)
{
_connectionState = ConnectionState.ToDisconnect;
}
else
{
// Frame.Abort calls user code while this method is always
// called from a libuv thread.
System.Threading.ThreadPool.QueueUserWorkItem(state =>
{
var connection = (Connection)state;
connection._frame.Abort();
}, this);
}
}
}

SocketInput = filteredStreamAdapter.SocketInput;
SocketOutput = filteredStreamAdapter.SocketOutput;
// Called on Libuv thread
public virtual void OnSocketClosed()
{
_rawSocketInput.Dispose();

// If a connection filter was applied there will be two SocketInputs.
// If a connection filter failed, SocketInput will be null.
if (SocketInput != null && SocketInput != _rawSocketInput)
{
SocketInput.Dispose();
}
else

lock (_stateLock)
{
SocketInput = _rawSocketInput;
SocketOutput = _rawSocketOutput;
_connectionState = ConnectionState.SocketClosed;

if (_socketClosedTcs != null)
{
// This is always waited on synchronously, so it's safe to
// call on the libuv thread.
_socketClosedTcs.TrySetResult(null);
}
}
}

private void ApplyConnectionFilter()
{
lock (_stateLock)
{
if (_connectionState == ConnectionState.CreatingFrame)
{
_connectionState = ConnectionState.Open;

if (_filterContext.Connection != _libuvStream)
{
var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory2, Log, ThreadPool);

SocketInput = filteredStreamAdapter.SocketInput;
SocketOutput = filteredStreamAdapter.SocketOutput;

_frame = CreateFrame();
_frame.Start();
filteredStreamAdapter.ReadInput();
}
else
{
SocketInput = _rawSocketInput;
SocketOutput = _rawSocketOutput;
}

_frame = CreateFrame();
_frame.Start();
}
else
{
ConnectionControl.End(ProduceEndType.SocketDisconnect);
}
}
}

private static Libuv.uv_buf_t AllocCallback(UvStreamHandle handle, int suggestedSize, object state)
Expand Down Expand Up @@ -215,16 +301,6 @@ void IConnectionControl.End(ProduceEndType endType)
{
switch (endType)
{
case ProduceEndType.SocketShutdownSend:
if (_connectionState != ConnectionState.Open)
{
return;
}
_connectionState = ConnectionState.Shutdown;

Log.ConnectionWriteFin(_connectionId);
_rawSocketOutput.End(endType);
break;
case ProduceEndType.ConnectionKeepAlive:
if (_connectionState != ConnectionState.Open)
{
Expand All @@ -233,12 +309,14 @@ void IConnectionControl.End(ProduceEndType endType)

Log.ConnectionKeepAlive(_connectionId);
break;
case ProduceEndType.SocketShutdown:
case ProduceEndType.SocketDisconnect:
if (_connectionState == ConnectionState.Disconnected)
if (_connectionState == ConnectionState.Disconnecting ||
_connectionState == ConnectionState.SocketClosed)
{
return;
}
_connectionState = ConnectionState.Disconnected;
_connectionState = ConnectionState.Disconnecting;

Log.ConnectionDisconnect(_connectionId);
_rawSocketOutput.End(endType);
Expand All @@ -249,9 +327,11 @@ void IConnectionControl.End(ProduceEndType endType)

private enum ConnectionState
{
CreatingFrame,
ToDisconnect,
Open,
Shutdown,
Disconnected
Disconnecting,
SocketClosed
}
}
}
53 changes: 53 additions & 0 deletions src/Microsoft.AspNetCore.Server.Kestrel/Http/ConnectionManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Networking;

namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
public class ConnectionManager
{
private KestrelThread _thread;
private List<Task> _connectionStopTasks;

public ConnectionManager(KestrelThread thread)
{
_thread = thread;
}

// This must be called on the libuv event loop
public void WalkConnectionsAndClose()
{
if (_connectionStopTasks != null)
{
throw new InvalidOperationException($"{nameof(WalkConnectionsAndClose)} cannot be called twice.");
}

_connectionStopTasks = new List<Task>();

_thread.Walk(ptr =>
{
var handle = UvMemory.FromIntPtr<UvHandle>(ptr);
var connection = (handle as UvStreamHandle)?.Connection;

if (connection != null)
{
_connectionStopTasks.Add(connection.StopAsync());
}
});
}

public Task WaitForConnectionCloseAsync()
{
if (_connectionStopTasks == null)
{
throw new InvalidOperationException($"{nameof(WalkConnectionsAndClose)} must be called first.");
}

return Task.WhenAll(_connectionStopTasks);
}
}
}
Loading