From 76992612b7c8e60bba89722805ba61e991d210db Mon Sep 17 00:00:00 2001 From: Brennan Date: Wed, 26 Jan 2022 09:39:03 -0800 Subject: [PATCH 01/11] [SignalR] Client return results --- .../csharp/Client.Core/src/HubConnection.cs | 75 ++- .../src/HubConnectionExtensions.OnResult.cs | 455 ++++++++++++++++++ .../src/HubConnectionExtensions.cs | 20 +- .../Client.Core/src/PublicAPI.Unshipped.txt | 20 + .../HubConnectionTests.Extensions.cs | 304 ++++++++++++ .../UnitTests/HubConnectionTests.Protocol.cs | 181 +++++++ .../java/com/microsoft/signalr/Action.java | 2 +- .../com/microsoft/signalr/ActionBase.java | 4 + .../com/microsoft/signalr/CallbackMap.java | 2 +- .../signalr/ClientResultMessage.java | 47 ++ .../java/com/microsoft/signalr/Function.java | 5 + .../com/microsoft/signalr/HubConnection.java | 77 ++- .../com/microsoft/signalr/HubMessageType.java | 1 + .../microsoft/signalr/InvocationHandler.java | 14 +- .../com/microsoft/signalr/sample/Chat.java | 4 + .../clients/ts/FunctionalTests/Startup.cs | 14 + .../FunctionalTests/ts/HubConnectionTests.ts | 54 +++ .../clients/ts/signalr/src/HubConnection.ts | 51 +- .../ts/signalr/tests/HubConnection.test.ts | 284 +++++++++++ .../src/Protocol/JsonHubProtocol.cs | 36 +- .../Protocol/MessagePackHubProtocolWorker.cs | 13 +- .../src/Protocol/NewtonsoftJsonHubProtocol.cs | 31 +- .../common/Shared/ClientResultsManager.cs | 196 ++++++++ .../src/Protocol/CompletionMessage.cs | 2 +- .../SignalR.Common/src/Protocol/RawResult.cs | 35 ++ .../src/PublicAPI.Unshipped.txt | 3 + .../Protocol/JsonHubProtocolTestsBase.cs | 79 +++ .../Protocol/MessagePackHubProtocolTests.cs | 63 +++ .../DefaultHubDispatcherBenchmark.cs | 6 +- .../Microbenchmarks/RedisProtocolBenchmark.cs | 8 +- src/SignalR/samples/ClientSample/HubSample.cs | 124 +++-- .../samples/SignalRSamples/Hubs/Chat.cs | 9 +- .../samples/SignalRSamples/Hubs/GameHub.cs | 24 + src/SignalR/samples/SignalRSamples/Program.cs | 11 +- src/SignalR/samples/SignalRSamples/Startup.cs | 67 ++- .../samples/SignalRSamples/wwwroot/hubs.html | 12 +- .../server/Core/src/ClientProxyExtensions.cs | 209 ++++++++ .../Core/src/DefaultHubLifetimeManager.cs | 70 +++ .../server/Core/src/HubConnectionHandler.cs | 5 +- .../server/Core/src/HubLifetimeManager.cs | 42 ++ .../server/Core/src/IHubCallerClients.cs | 13 +- src/SignalR/server/Core/src/IHubClients.cs | 11 +- src/SignalR/server/Core/src/IHubClients`T.cs | 11 +- .../server/Core/src/ISingleClientProxy.cs | 21 + .../Core/src/Internal/DefaultHubDispatcher.cs | 24 +- .../Core/src/Internal/HubCallerClients.cs | 33 +- .../server/Core/src/Internal/HubClients.cs | 7 +- .../Core/src/Internal/HubConnectionBinder.cs | 10 +- .../server/Core/src/Internal/Proxies.cs | 22 + .../Core/src/Internal/TypedClientBuilder.cs | 48 +- .../Core/src/Internal/TypedHubClients.cs | 4 +- .../Microsoft.AspNetCore.SignalR.Core.csproj | 1 + .../server/Core/src/PublicAPI.Unshipped.txt | 22 + .../HubConnectionHandlerTestUtils/Hubs.cs | 6 + .../HubConnectionHandlerTests.ClientResult.cs | 106 ++++ .../SignalR/test/HubConnectionHandlerTests.cs | 4 +- .../test/Internal/TypedClientBuilderTests.cs | 118 ++++- .../src/HubLifetimeManagerTestBase.cs | 193 ++++++++ .../src/ScaleoutHubLifetimeManagerTests.cs | 111 +++++ .../src/Internal/RedisChannels.cs | 11 + .../src/Internal/RedisInvocation.cs | 9 +- .../src/Internal/RedisProtocol.cs | 76 ++- .../src/Internal/RedisReturnResult.cs | 38 ++ ...pNetCore.SignalR.StackExchangeRedis.csproj | 1 + .../src/PublicAPI.Unshipped.txt | 3 + .../src/RedisHubLifetimeManager.cs | 140 +++++- .../test/RedisProtocolTests.cs | 76 ++- 67 files changed, 3622 insertions(+), 156 deletions(-) create mode 100644 src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs create mode 100644 src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java create mode 100644 src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java create mode 100644 src/SignalR/common/Shared/ClientResultsManager.cs create mode 100644 src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs create mode 100644 src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs create mode 100644 src/SignalR/server/Core/src/ISingleClientProxy.cs create mode 100644 src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs create mode 100644 src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 46101e410213..42d4707e434c 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -312,6 +312,35 @@ public virtual async ValueTask DisposeAsync() } } + /// + /// + /// + /// + /// + /// + /// + /// + public virtual IDisposable On(string methodName, Type[] parameterTypes, Func> handler, object state) + { + Log.RegisteringHandler(_logger, methodName); + + CheckDisposed(); + + // It's OK to be disposed while registering a callback, we'll just never call the callback anyway (as with all the callbacks registered before disposal). + var invocationHandler = new InvocationHandler(parameterTypes, handler, state); + var invocationList = _handlers.AddOrUpdate(methodName, _ => new InvocationHandlerList(invocationHandler), + (_, invocations) => + { + lock (invocations) + { + invocations.Add(invocationHandler); + } + return invocations; + }); + + return new Subscription(invocationHandler, invocationList); + } + // If the registered callback blocks it can cause the client to stop receiving messages. If you need to block, get off the current thread first. /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. @@ -988,26 +1017,64 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess return null; } - private async Task DispatchInvocationAsync(InvocationMessage invocation) + private async Task DispatchInvocationAsync(InvocationMessage invocation, ConnectionState connectionState) { + var expectsResult = !string.IsNullOrEmpty(invocation.InvocationId); // Find the handler if (!_handlers.TryGetValue(invocation.Target, out var invocationHandlerList)) { Log.MissingHandler(_logger, invocation.Target); + if (expectsResult) + { + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + } return; } // Grabbing the current handlers var copiedHandlers = invocationHandlerList.GetHandlers(); + object? result = null; + Exception? resultException = null; + var hasResult = false; foreach (var handler in copiedHandlers) { try { - await handler.InvokeAsync(invocation.Arguments).ConfigureAwait(false); + var task = handler.InvokeAsync(invocation.Arguments); + if (handler.HasResult && task is Task resultTask) + { + hasResult = true; + result = await resultTask.ConfigureAwait(false); + // ignore previous results' exception, we prefer last .On handler for results + resultException = null; + } + else + { + await task.ConfigureAwait(false); + } } catch (Exception ex) { Log.ErrorInvokingClientSideMethod(_logger, invocation.Target, ex); + if (handler.HasResult) + { + resultException = ex; + } + } + } + if (expectsResult) + { + if (resultException is not null) + { + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false); + } + else if (!hasResult) + { + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + } + else + { + await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); } } } @@ -1178,7 +1245,7 @@ async Task StartProcessingInvocationMessages(ChannelReader in { while (invocationMessageChannelReader.TryRead(out var invocationMessage)) { - await DispatchInvocationAsync(invocationMessage).ConfigureAwait(false); + await DispatchInvocationAsync(invocationMessage, connectionState).ConfigureAwait(false); } } } @@ -1663,6 +1730,7 @@ internal void Remove(InvocationHandler handler) private readonly struct InvocationHandler { public Type[] ParameterTypes { get; } + public bool HasResult => _callback.Method.ReturnType == typeof(Task); private readonly Func _callback; private readonly object _state; @@ -1671,6 +1739,7 @@ public InvocationHandler(Type[] parameterTypes, Func ca _callback = callback; ParameterTypes = parameterTypes; _state = state; + } public Task InvokeAsync(object?[] parameters) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs new file mode 100644 index 000000000000..098fcd1bcbe3 --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs @@ -0,0 +1,455 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Client; +public static partial class HubConnectionExtensions +{ + private static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func handler) + { + return hubConnection.On(methodName, parameterTypes, static (parameters, state) => + { + var currentHandler = (Func)state; + return Task.FromResult(currentHandler(parameters)); + }, handler); + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, Type.EmptyTypes, args => handler()); + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, Type.EmptyTypes, args => handler()); + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1) }, + args => handler((T1)args[0]!)); + } + + /// + /// + /// + /// + /// /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2) }, + args => handler((T1)args[0]!, (T2)args[1]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// The eighth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!, (T8)args[7]!)); + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1) }, + args => handler((T1)args[0]!)); + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2) }, + args => handler((T1)args[0]!, (T2)args[1]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// The eighth argument type. + /// + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!, (T8)args[7]!)); + } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs index e8afe03949ef..1da69401294b 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs @@ -13,7 +13,7 @@ public static partial class HubConnectionExtensions { private static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Action handler) { - return hubConnection.On(methodName, parameterTypes, (parameters, state) => + return hubConnection.On(methodName, parameterTypes, static (parameters, state) => { var currentHandler = (Action)state; currentHandler(parameters); @@ -243,6 +243,24 @@ public static IDisposable On(this HubConnection hubConnection, string methodName }, handler); } + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func> handler) + { + return hubConnection.On(methodName, parameterTypes, async (parameters, state) => + { + var currentHandler = (Func>)state; + return await currentHandler(parameters).ConfigureAwait(false); + }, handler); + } + /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. /// diff --git a/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt b/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..40077fde5c02 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt @@ -1 +1,21 @@ #nullable enable +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Type![]! parameterTypes, System.Func!>! handler) -> System.IDisposable! +virtual Microsoft.AspNetCore.SignalR.Client.HubConnection.On(string! methodName, System.Type![]! parameterTypes, System.Func!>! handler, object! state) -> System.IDisposable! diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs index 7f60b7791161..1ae2bf44df54 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs @@ -4,6 +4,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Testing; +using Newtonsoft.Json.Linq; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests; @@ -396,5 +397,308 @@ await connection.ReceiveJsonMessage( await hubConnection.DisposeAsync().DefaultTimeout(); } } + + [Fact] + public async Task OnWithResult() + { + var returnValue = 46; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + () => + { + tcs.SetResult(new object[0]); + return returnValue; + }), + new object[0]); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnAsyncWithResult() + { + var returnValue = 1220; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + async () => + { + tcs.SetResult(new object[0]); + await Task.CompletedTask; + return returnValue; + }), + new object[0]); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT1WithResult() + { + var returnValue = "buffalo"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + r => + { + tcs.SetResult(new object[] { r }); + return returnValue; + }), + new object[] { 42 }); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT1AsyncWithResult() + { + var returnValue = 2; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + async r => + { + tcs.SetResult(new object[] { r }); + await Task.CompletedTask; + return returnValue; + }), + new object[] { 42 }); + + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT2WithResult() + { + var returnValue = "ret"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2) => + { + tcs.SetResult(new object[] { r1, r2 }); + return returnValue; + }), + new object[] { 42, "abc" }); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT2AsyncWithResult() + { + var returnResult = 928; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2) => + { + tcs.SetResult(new object[] { r1, r2 }); + return Task.FromResult(returnResult); + }), + new object[] { 42, "abc" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT3WithResult() + { + var returnValue = "bob"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3) => + { + tcs.SetResult(new object[] { r1, r2, r3 }); + return returnValue; + }), + new object[] { 42, "abc", 24.0f }); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT3AsyncWithResult() + { + var returnResult = "random"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3) => + { + tcs.SetResult(new object[] { r1, r2, r3 }); + return Task.FromResult(returnResult); + }), + new object[] { 42, "abc", 24.0f }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT4WithResult() + { + var returnResult = 233; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT4AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT5WithResult() + { + var returnResult = 3004; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT5AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT6WithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24 }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT6AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24 }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT7WithResult() + { + var returnResult = 100; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c' }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT7AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c' }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT8WithResult() + { + var returnResult = 102; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7, r8) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7, r8 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c', "XYZ" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT8AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7, r8) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7, r8 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c', "XYZ" }); + Assert.Equal(returnResult, result); + } + + private async Task InvokeOnWithResult(Action> onAction, object[] args) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + var handlerTcs = new TaskCompletionSource(); + try + { + onAction(hubConnection, handlerTcs); + await hubConnection.StartAsync(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "1", + type = 1, + target = "Foo", + arguments = args + }).DefaultTimeout(); + + await handlerTcs.Task.DefaultTimeout(); + var json = await connection.ReadSentJsonAsync(); + var result = json["result"]; + return result; + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + } + } } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index 98078fb198ef..eaed0dab2778 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -676,5 +676,186 @@ public async Task ClientWithInherentKeepAliveDoesNotPing() await connection.DisposeAsync().DefaultTimeout(); } } + + [Fact] + public async Task ClientCanReturnResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", () => 10); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":10}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientReturnResultUsesLastResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", () => 10); + hubConnection.On("Result", () => 11); + hubConnection.On("Result", () => 14); + hubConnection.On("Result", () => 3); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":3}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientReturnHandlerCanMixWithNonReturnHandler() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + hubConnection.On("Result", () => 40); + hubConnection.On("Result", tcs.SetResult); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":40}", invokeMessage); + await tcs.Task.DefaultTimeout(); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientCanThrowErrorResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", int () => + { + throw new Exception("error from client"); + }); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"error\":\"error from client\"}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultIgnoresErrorWhenLastHandlerSuccessful() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", int () => + { + throw new Exception("error from client"); + }); + + hubConnection.On("Result", () => 20); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":20}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultReturnsErrorIfNoHandlerFromClient() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"error\":\"Client didn't provide a result.\"}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultReturnsErrorIfNoResultFromClient() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + // No result provided + hubConnection.On("Result", () => { }); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"error\":\"Client didn't provide a result.\"}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } } } diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java index fd1216e15431..653ab10e69e1 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java @@ -10,4 +10,4 @@ public interface Action { // We can't use the @FunctionalInterface annotation because it's only // available on Android API Level 24 and above. void invoke(); -} +} \ No newline at end of file diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java index e24630b53d23..4e5fcf1b5d9f 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java @@ -8,3 +8,7 @@ interface ActionBase { // available on Android API Level 24 and above. void invoke(Object ... params); } + +interface FunctionBase { + Object invoke(Object ... params); +} \ No newline at end of file diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java index 2a7013cc5dfb..6b6fd69c467c 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java @@ -14,7 +14,7 @@ class CallbackMap { private final Map> handlers = new HashMap<>(); private final ReentrantLock lock = new ReentrantLock(); - public InvocationHandler put(String target, ActionBase action, Type... types) { + public InvocationHandler put(String target, FunctionBase action, Type... types) { try { lock.lock(); InvocationHandler handler = new InvocationHandler(action, types); diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java new file mode 100644 index 000000000000..2046fb54863d --- /dev/null +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +package com.microsoft.signalr; + +import java.util.Map; + +public final class ClientResultMessage extends HubMessage { + private final int type = HubMessageType.CLIENT_RESULT.value; + private Map headers; + private final String invocationId; + private final Object result; + private final String error; + + public ClientResultMessage(Map headers, String invocationId, Object result, String error) { + if (headers != null && !headers.isEmpty()) { + this.headers = headers; + } + if (error != null && result != null) { + throw new IllegalArgumentException("Expected either 'error' or 'result' to be provided, but not both."); + } + this.invocationId = invocationId; + this.result = result; + this.error = error; + } + + public Map getHeaders() { + return headers; + } + + public Object getResult() { + return result; + } + + public String getError() { + return error; + } + + public String getInvocationId() { + return invocationId; + } + + @Override + public HubMessageType getMessageType() { + return HubMessageType.values()[type - 1]; + } +} diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java new file mode 100644 index 000000000000..dd56f1e0f4c1 --- /dev/null +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java @@ -0,0 +1,5 @@ +package com.microsoft.signalr; + +public interface Function { + Object invoke(); +} diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java index 1e2ce1bc5a5d..4d31b0a0a496 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java @@ -470,14 +470,21 @@ private void ReceiveLoop(ByteBuffer payload) InvocationBindingFailureMessage msg = (InvocationBindingFailureMessage)message; logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException()); break; + case STREAM_BINDING_FAILURE: + StreamBindingFailureMessage streamBindingFailure = (StreamBindingFailureMessage)message; + logger.error("Failed to bind arguments received in invocation '{}'.", streamBindingFailure.getInvocationId(), streamBindingFailure.getException()); + break; case INVOCATION: - InvocationMessage invocationMessage = (InvocationMessage) message; List handlers = this.handlers.get(invocationMessage.getTarget()); if (handlers != null) { for (InvocationHandler handler : handlers) { try { - handler.getAction().invoke(invocationMessage.getArguments()); + Object result = handler.getAction().invoke(invocationMessage.getArguments()); + logger.error("{}", result); + if (result != null) { + this.sendHubMessageWithLock(new ClientResultMessage(null, invocationMessage.getInvocationId(), new CompletionMessage(null, invocationMessage.getInvocationId(), result, null), null)); + } } catch (Exception e) { logger.error("Invoking client side method '{}' failed:", invocationMessage.getTarget(), e); } @@ -513,6 +520,7 @@ private void ReceiveLoop(ByteBuffer payload) streamInvocationRequest.addItem(streamItem); break; + case CLIENT_RESULT: case STREAM_INVOCATION: case CANCEL_INVOCATION: logger.error("This client does not support {} messages.", message.getMessageType()); @@ -868,7 +876,17 @@ public void onClosed(OnClosedCallback callback) { * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action callback) { - ActionBase action = args -> callback.invoke(); + FunctionBase action = args -> { + callback.invoke(); + return null; + }; + return registerHandler(target, action); + } + + public Subscription on(String target, Function callback) { + FunctionBase action = args -> { + return callback.invoke(); + }; return registerHandler(target, action); } @@ -883,9 +901,11 @@ public Subscription on(String target, Action callback) { * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action1 callback, Class param1) { - ActionBase action = params -> callback.invoke(Utils.cast(param1, params[0])); + FunctionBase action = params -> { + callback.invoke(Utils.cast(param1, params[0])); + return null; + }; return registerHandler(target, action, param1); - } /** @@ -901,8 +921,9 @@ public Subscription on(String target, Action1 callback, Class param * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action2 callback, Class param1, Class param2) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1])); + return null; }; return registerHandler(target, action, param1, param2); } @@ -923,8 +944,9 @@ public Subscription on(String target, Action2 callback, Class Subscription on(String target, Action3 callback, Class param1, Class param2, Class param3) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2])); + return null; }; return registerHandler(target, action, param1, param2, param3); } @@ -947,9 +969,10 @@ public Subscription on(String target, Action3 callback, */ public Subscription on(String target, Action4 callback, Class param1, Class param2, Class param3, Class param4) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4); } @@ -974,9 +997,10 @@ public Subscription on(String target, Action4 c */ public Subscription on(String target, Action5 callback, Class param1, Class param2, Class param3, Class param4, Class param5) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5); } @@ -1003,9 +1027,10 @@ public Subscription on(String target, Action5 Subscription on(String target, Action6 callback, Class param1, Class param2, Class param3, Class param4, Class param5, Class param6) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6); } @@ -1034,9 +1059,10 @@ public Subscription on(String target, Action6 Subscription on(String target, Action7 callback, Class param1, Class param2, Class param3, Class param4, Class param5, Class param6, Class param7) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7); } @@ -1067,10 +1093,11 @@ public Subscription on(String target, Action7 Subscription on(String target, Action8 callback, Class param1, Class param2, Class param3, Class param4, Class param5, Class param6, Class param7, Class param8) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6]), Utils.cast(param8, params[7])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7, param8); } @@ -1087,8 +1114,9 @@ public Subscription on(String target, Action8 Subscription on(String target, Action1 callback, Type param1) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0])); + return null; }; return registerHandler(target, action, param1); } @@ -1107,8 +1135,9 @@ public Subscription on(String target, Action1 callback, Type param1) { * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action2 callback, Type param1, Type param2) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1])); + return null; }; return registerHandler(target, action, param1, param2); } @@ -1130,8 +1159,9 @@ public Subscription on(String target, Action2 callback, Type pa */ public Subscription on(String target, Action3 callback, Type param1, Type param2, Type param3) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2])); + return null; }; return registerHandler(target, action, param1, param2, param3); } @@ -1155,9 +1185,10 @@ public Subscription on(String target, Action3 callback, */ public Subscription on(String target, Action4 callback, Type param1, Type param2, Type param3, Type param4) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4); } @@ -1183,9 +1214,10 @@ public Subscription on(String target, Action4 c */ public Subscription on(String target, Action5 callback, Type param1, Type param2, Type param3, Type param4, Type param5) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5); } @@ -1213,9 +1245,10 @@ public Subscription on(String target, Action5 Subscription on(String target, Action6 callback, Type param1, Type param2, Type param3, Type param4, Type param5, Type param6) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6); } @@ -1245,9 +1278,10 @@ public Subscription on(String target, Action6 Subscription on(String target, Action7 callback, Type param1, Type param2, Type param3, Type param4, Type param5, Type param6, Type param7) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7); } @@ -1280,15 +1314,16 @@ public Subscription on(String target, Action7 Subscription on(String target, Action8 callback, Type param1, Type param2, Type param3, Type param4, Type param5, Type param6, Type param7, Type param8) { - ActionBase action = params -> { + FunctionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6]), Utils.cast(param8, params[7])); + return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7, param8); } - private Subscription registerHandler(String target, ActionBase action, Type... types) { + private Subscription registerHandler(String target, FunctionBase action, Type... types) { InvocationHandler handler = handlers.put(target, action, types); logger.debug("Registering handler for client method: '{}'.", target); return new Subscription(handlers, handler, target); diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java index 23201c0c0d8a..d191f9f7917c 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java @@ -11,6 +11,7 @@ public enum HubMessageType { CANCEL_INVOCATION(5), PING(6), CLOSE(7), + CLIENT_RESULT(8), INVOCATION_BINDING_FAILURE(-1), STREAM_BINDING_FAILURE(-2); diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java index 98fa53aaf365..a4632c85cced 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java @@ -9,18 +9,26 @@ class InvocationHandler { private final List types; - private final ActionBase action; + private final Object action; + private final Boolean hasResult; + + InvocationHandler(FunctionBase action, Type... types) { + this.action = action; + this.types = Arrays.asList(types); + this.hasResult = false; + } InvocationHandler(ActionBase action, Type... types) { this.action = action; this.types = Arrays.asList(types); + this.hasResult = true; } public List getTypes() { return types; } - public ActionBase getAction() { - return action; + public FunctionBase getAction() { + return (FunctionBase)action; } } \ No newline at end of file diff --git a/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java b/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java index f4c0850c0fb9..cdb411427eb8 100644 --- a/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java +++ b/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java @@ -15,6 +15,10 @@ public static void main(final String[] args) throws Exception { final String input = reader.nextLine(); try (HubConnection hubConnection = HubConnectionBuilder.create(input).build()) { + hubConnection.on("F", () -> { + return 2; + }); + hubConnection.on("Send", (message) -> { System.out.println(message); }, String.class); diff --git a/src/SignalR/clients/ts/FunctionalTests/Startup.cs b/src/SignalR/clients/ts/FunctionalTests/Startup.cs index 509c7378da60..ef67be4b91a6 100644 --- a/src/SignalR/clients/ts/FunctionalTests/Startup.cs +++ b/src/SignalR/clients/ts/FunctionalTests/Startup.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Globalization; using System.IdentityModel.Tokens.Jwt; using System.Reflection; using System.Security.Claims; @@ -223,6 +224,19 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILogger< return context.Response.WriteAsync(GenerateJwtToken()); }); + endpoints.MapGet("/clientresult/{id}", async (IHubContext hubContext, string id) => + { + try + { + var result = await hubContext.Clients.Single(id).InvokeAsync("Result"); + return result.ToString(CultureInfo.InvariantCulture); + } + catch (Exception ex) + { + return ex.Message; + } + }); + endpoints.MapGet("/deployment", context => { var attributes = Assembly.GetAssembly(typeof(Startup)).GetCustomAttributes(); diff --git a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index c463fb7096b6..2b9d6efacbd9 100644 --- a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -513,6 +513,60 @@ describe("hubConnection", () => { await hubConnection.stop(); } }); + + it("can return result to server", async () => { + const hubConnection = getConnectionBuilder(transportType, undefined, { httpClient }) + .withHubProtocol(protocol) + .build(); + + hubConnection.on("Result", () => { + return 10; + }); + + await hubConnection.start(); + + const response = await httpClient.get(ENDPOINT_BASE_URL + `/clientresult/${hubConnection.connectionId}`); + + expect(response.content).toEqual("10"); + + await hubConnection.stop(); + }); + + it("can throw result to server", async () => { + const hubConnection = getConnectionBuilder(transportType, undefined, { httpClient }) + .withHubProtocol(protocol) + .build(); + + hubConnection.on("Result", () => { + throw new Error("from callback"); + }); + + try { + await hubConnection.start(); + + const response = await httpClient.get(ENDPOINT_BASE_URL + `/clientresult/${hubConnection.connectionId}`); + + expect(response.content).toEqual("Error: from callback"); + } finally { + await hubConnection.stop(); + } + }); + + it("returns result error to server when no result given", async () => { + const hubConnection = getConnectionBuilder(transportType, undefined, { httpClient }) + .withHubProtocol(protocol) + .build(); + + try { + await hubConnection.start(); + + const response = await httpClient.get(ENDPOINT_BASE_URL + `/clientresult/${hubConnection.connectionId}`); + + expect(response.content).toEqual("Client didn't provide a result."); + } finally { + await hubConnection.stop(); + } + }); }); }); diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 0338521956c5..4dce394014ec 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -39,7 +39,7 @@ export class HubConnection { private _protocol: IHubProtocol; private _handshakeProtocol: HandshakeProtocol; private _callbacks: { [invocationId: string]: (invocationEvent: StreamItemMessage | CompletionMessage | null, error?: Error) => void }; - private _methods: { [name: string]: ((...args: any[]) => void)[] }; + private _methods: { [name: string]: (((...args: any[]) => void) | ((...args: any[]) => any))[] }; private _invocationId: number; private _closedCallbacks: ((error?: Error) => void)[]; @@ -443,6 +443,7 @@ export class HubConnection { * @param {string} methodName The name of the hub method to define. * @param {Function} newMethod The handler that will be raised when the hub method is invoked. */ + public on(methodName: string, newMethod: (...args: any[]) => any): void public on(methodName: string, newMethod: (...args: any[]) => void): void { if (!methodName || !newMethod) { return; @@ -546,6 +547,7 @@ export class HubConnection { for (const message of messages) { switch (message.type) { case MessageType.Invocation: + // eslint-disable-next-line @typescript-eslint/no-floating-promises this._invokeClientMethod(message); break; case MessageType.StreamItem: @@ -672,26 +674,47 @@ export class HubConnection { this.connection.stop(new Error("Server timeout elapsed without receiving a message from the server.")); } - private _invokeClientMethod(invocationMessage: InvocationMessage) { + private async _invokeClientMethod(invocationMessage: InvocationMessage) { const methods = this._methods[invocationMessage.target.toLowerCase()]; if (methods) { const methodsCopy = methods.slice(); - try { - methodsCopy.forEach((m) => m.apply(this, invocationMessage.arguments)); - } catch (e) { - this._logger.log(LogLevel.Error, `A callback for the method ${invocationMessage.target.toLowerCase()} threw error '${e}'.`); - } if (invocationMessage.invocationId) { - // This is not supported in v1. So we return an error to avoid blocking the server waiting for the response. - const message = "Server requested a response, which is not supported in this version of the client."; - this._logger.log(LogLevel.Error, message); - - // We don't want to wait on the stop itself. - this._stopPromise = this._stopInternal(new Error(message)); + let res; + let exception; + for (const m of methodsCopy) { + try { + if (res) { + this._logger.log(LogLevel.Warning, `Result already provided for '${invocationMessage.target.toLowerCase()}' only last one will be sent.`); + } + res = await m.apply(this, invocationMessage.arguments); + exception = undefined; + } catch (e) { + exception = e; + this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); + } + } + if (exception) { + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, `${exception}`, null)); + } + else if (res !== undefined) { + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, null, res)); + } else { + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); + } + } else { + try { + methodsCopy.forEach((m) => m.apply(this, invocationMessage.arguments)); + } catch (e) { + this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); + } } } else { - this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target}' found.`); + this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target.toLowerCase()}' found.`); + + if (invocationMessage.invocationId) { + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); + } } } diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index 18b5a6de3c09..bc8b54aed788 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -1048,6 +1048,290 @@ describe("HubConnection", () => { } }); }); + + it("can return result from callback", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 10); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(10); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("can return null result from callback", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => null); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toBeNull(); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("can return task result from callback", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + const p = new PromiseSource(); + hubConnection.on("message", () => p); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + p.resolve(13); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(13); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("can throw from callback when expecting result", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => { throw new Error("from callback"); }); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Error: from callback"); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "A callback for the method 'message' threw error 'Error: from callback'."); + }); + + it("multiple results only sends last one", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 3); + hubConnection.on("message", () => 4); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(4); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("multiple result handlers error from last one sent", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 3); + hubConnection.on("message", () => { throw new Error("from callback"); }); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Error: from callback"); + expect(connection.parsedSentData[1].result).toBeUndefined(); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "A callback for the method 'message' threw error 'Error: from callback'."); + }); + + it("multiple result handlers ignore error if last one has result", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => { throw new Error("from callback"); }); + hubConnection.on("message", () => 3); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(3); + expect(connection.parsedSentData[1].error).toBeUndefined(); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "A callback for the method 'message' threw error 'Error: from callback'."); + }); + + it("sends completion error if return result expected but not returned", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => {}); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Client didn't provide a result."); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("sends completion error if return result expected but no handlers", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Client didn't provide a result."); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); }); describe("stream", () => { diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index f35cba60c611..b21a008ffadc 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -200,8 +200,6 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) { hasResult = true; - reader.CheckRead(); - if (string.IsNullOrEmpty(invocationId)) { // If we don't have an invocation id then we need to value copy the reader so we can parse it later @@ -211,9 +209,8 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else { - // If we have an invocation id already we can parse the end result var returnType = binder.GetReturnType(invocationId); - result = BindType(ref reader, returnType); + result = BindType(ref reader, input, returnType); } } else if (reader.ValueTextEquals(ItemPropertyNameBytes.EncodedUtf8Bytes)) @@ -391,7 +388,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (hasResultToken) { var returnType = binder.GetReturnType(invocationId); - result = BindType(ref resultToken, returnType); + result = BindType(ref resultToken, input, returnType); } message = BindCompletionMessage(invocationId, error, result, hasResult); @@ -537,7 +534,16 @@ private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter wr } else { - JsonSerializer.Serialize(writer, message.Result, message.Result.GetType(), _payloadSerializerOptions); + var resultType = message.Result.GetType(); + if (resultType == typeof(RawResult)) + { + Debug.Assert(((RawResult)message.Result).RawSerializedData.IsSingleSegment); + writer.WriteRawValue(((RawResult)message.Result).RawSerializedData.First.Span, skipInputValidation: true); + } + else + { + JsonSerializer.Serialize(writer, message.Result, resultType, _payloadSerializerOptions); + } } } } @@ -724,6 +730,24 @@ private static HubMessage BindInvocationMessage(string? invocationId, string tar return new InvocationMessage(invocationId, target, arguments, streamIds); } + private object? BindType(ref Utf8JsonReader reader, ReadOnlySequence input, Type type) + { + if (type == typeof(RawResult)) + { + var start = reader.BytesConsumed; + reader.Skip(); + var end = reader.BytesConsumed; + var sequence = input.Slice(start, end - start); + // Technically we could pass the sequence without copying into a new array + // but in the future we could break this if we dispatched the CompletionMessage and the underlying Pipe read would be advanced + var arr = new byte[sequence.Length]; + sequence.CopyTo(arr); + // REVIEW: We can make this type do the copying which would allow us to rent from the ArrayPool + return new RawResult(new ReadOnlySequence(arr)); + } + return BindType(ref reader, type); + } + private object? BindType(ref Utf8JsonReader reader, Type type) { return JsonSerializer.Deserialize(ref reader, type, _payloadSerializerOptions); diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs index b12baeb3b14c..a38cf1851611 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs @@ -162,7 +162,14 @@ private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, break; case NonVoidResult: var itemType = binder.GetReturnType(invocationId); - result = DeserializeObject(ref reader, itemType, "argument"); + if (itemType == typeof(RawResult)) + { + result = new RawResult(reader.ReadRaw()); + } + else + { + result = DeserializeObject(ref reader, itemType, "argument"); + } hasResult = true; break; case VoidResult: @@ -434,6 +441,10 @@ private void WriteArgument(object? argument, ref MessagePackWriter writer) { writer.WriteNil(); } + else if (argument.GetType() == typeof(RawResult)) + { + writer.WriteRaw(((RawResult)argument).RawSerializedData); + } else { Serialize(ref writer, argument.GetType(), argument); diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index 694831ae5e8c..c0516dba9cb4 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -8,6 +8,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO; using System.Runtime.ExceptionServices; +using System.Text; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; @@ -214,7 +215,16 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) throw new JsonReaderException("Unexpected end when reading JSON"); } - result = PayloadSerializer.Deserialize(reader, returnType); + if (returnType == typeof(RawResult)) + { + var token = JToken.Load(reader); + var str = token.ToString(Formatting.None); + result = new RawResult(new ReadOnlySequence(Encoding.UTF8.GetBytes(str))); + } + else + { + result = PayloadSerializer.Deserialize(reader, returnType); + } } break; case ItemPropertyName: @@ -388,7 +398,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (resultToken != null) { var returnType = binder.GetReturnType(invocationId); - result = resultToken.ToObject(returnType, PayloadSerializer); + if (returnType == typeof(RawResult)) + { + var str = resultToken.ToString(Formatting.None); + result = new RawResult(new ReadOnlySequence(Encoding.UTF8.GetBytes(str))); + } + else + { + result = resultToken.ToObject(returnType, PayloadSerializer); + } } message = BindCompletionMessage(invocationId, error, result, hasResult); @@ -531,7 +549,14 @@ private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter wr else if (message.HasResult) { writer.WritePropertyName(ResultPropertyName); - PayloadSerializer.Serialize(writer, message.Result); + if (message.Result?.GetType() == typeof(RawResult)) + { + writer.WriteRawValue(Encoding.UTF8.GetString(((RawResult)message.Result).RawSerializedData.ToArray())); + } + else + { + PayloadSerializer.Serialize(writer, message.Result); + } } } diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs new file mode 100644 index 000000000000..352acd308124 --- /dev/null +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -0,0 +1,196 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal class ClientResultsManager : IInvocationBinder +{ + private readonly ConcurrentDictionary Completion)> _pendingInvocations = new(); + + public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSourceWithCancellation(this, connectionId, invocationId, cancellationToken); + _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, completionMessage => + { + if (completionMessage.HasResult) + { + tcs.SetResult((T)completionMessage.Result); + } + else + { + tcs.SetException(new Exception(completionMessage.Error)); + } + return Task.CompletedTask; + } + )); + + return tcs.Task; + } + + public void AddInvocation(string invocationId, (Type Type, string ConnectionId, Func Completion) invocationInfo) + { + _pendingInvocations.TryAdd(invocationId, invocationInfo); + } + + public Task TryCompleteResult(string connectionId, CompletionMessage message) + { + if (_pendingInvocations.TryGetValue(message.InvocationId!, out var item)) + { + if (item.ConnectionId != connectionId) + { + throw new Exception("wrong ID"); + } + + // if false the connection disconnected right after the above TryGetValue + // or someone else completed the invocation (likely a bad client) + // we'll ignore both cases + if (_pendingInvocations.Remove(message.InvocationId!, out _)) + { + return item.Completion(message); + } + } + else + { + // connection was disconnected or someone else completed the invocation + } + return Task.CompletedTask; + } + + public (Type Type, string ConnectionId, Func Completion) RemoveInvocation(string invocationId) + { + _pendingInvocations.Remove(invocationId, out var item); + return item; + } + + public void CleanupConnection(string connectionId, List? pendingTasks) + { + var invocationIds = _pendingInvocations.Where(x => x.Value.ConnectionId == connectionId).Select(x => x.Key); + foreach (var id in invocationIds) + { + if (_pendingInvocations.Remove(id, out var item)) + { + var task = item.Completion(CompletionMessage.WithError(id, "Connection disconnected")); + if (!task.IsCompletedSuccessfully) + { + if (pendingTasks is null) + { + pendingTasks = new List(); + } + pendingTasks.Add(task); + } + } + } + } + + public bool TryGetType(string invocationId, [NotNullWhen(true)] out Type? type) + { + if (_pendingInvocations.TryGetValue(invocationId, out var item)) + { + type = item.Type; + return true; + } + type = null; + return false; + } + + public Type GetReturnType(string invocationId) + { + if (TryGetType(invocationId, out var type)) + { + return type; + } + throw new InvalidOperationException(); + } + + public IReadOnlyList GetParameterTypes(string methodName) + { + throw new NotImplementedException(); + } + + public Type GetStreamItemType(string streamId) + { + throw new NotImplementedException(); + } + + private sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource + { + private readonly ClientResultsManager _clientResultsManager; + private readonly string _connectionId; + private readonly string _invocationId; + private readonly CancellationTokenRegistration _tokenRegistration; + + public TaskCompletionSourceWithCancellation(ClientResultsManager clientResultsManager, string connectionId, string invocationId, + CancellationToken cancellationToken) + : base(TaskCreationOptions.RunContinuationsAsynchronously) + { + _clientResultsManager = clientResultsManager; + _connectionId = connectionId; + _invocationId = invocationId; + + if (cancellationToken.CanBeCanceled) + { + _tokenRegistration = cancellationToken.UnsafeRegister(static o => + { + var tcs = (TaskCompletionSourceWithCancellation)o!; + tcs.SetCanceled(); + }, this); + } + } + + public new void SetCanceled() + { + // TODO: RedisHubLifetimeManager will want to notify the other server (if there is one) about the cancellation + // so it can clean up state and potentially forward that info to the connection + _ = _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Canceled")); + } + + public new void SetResult(T result) + { + base.SetResult(result); + _tokenRegistration.Dispose(); + } + + public new void SetException(Exception exception) + { + base.SetException(exception); + _tokenRegistration.Dispose(); + } + +#pragma warning disable IDE0060 // Remove unused parameter + // Just making sure we don't accidentally call one of these without knowing + public static new void SetCanceled(CancellationToken cancellationToken) => Debug.Assert(false); + public static new void SetException(IEnumerable exceptions) => Debug.Assert(false); + public static new bool TrySetCanceled() + { + Debug.Assert(false); + return false; + } + public static new bool TrySetCanceled(CancellationToken cancellationToken) + { + Debug.Assert(false); + return false; + } + public static new bool TrySetException(IEnumerable exceptions) + { + Debug.Assert(false); + return false; + } + public static new bool TrySetException(Exception exception) + { + Debug.Assert(false); + return false; + } + public static new bool TrySetResult(T result) + { + Debug.Assert(false); + return false; + } +#pragma warning restore IDE0060 // Remove unused parameter + } +} diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs index 1645e7f3ed5e..440e431c6bb8 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs @@ -36,7 +36,7 @@ public class CompletionMessage : HubInvocationMessage public CompletionMessage(string invocationId, string? error, object? result, bool hasResult) : base(invocationId) { - if (error != null && result != null) + if (error is not null && hasResult) { throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both"); } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs new file mode 100644 index 000000000000..8aecffb30967 --- /dev/null +++ b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +/// +/// Type returned to implementations to let them know the object being deserialized should be +/// stored as raw serialized bytes in the format of the protocol being used. +/// +/// +/// In Json that would mean storing the bytes of {"prop":10} as an example. +/// +public class RawResult +{ + /// + /// + /// + /// + public RawResult(ReadOnlySequence rawBytes) + { + RawSerializedData = rawBytes; + } + + /// + /// + /// + public ReadOnlySequence RawSerializedData { get; private set; } +} diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..76fd06c813c2 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt @@ -1 +1,4 @@ #nullable enable +Microsoft.AspNetCore.SignalR.Protocol.RawResult +Microsoft.AspNetCore.SignalR.Protocol.RawResult.RawResult(System.Buffers.ReadOnlySequence rawBytes) -> void +Microsoft.AspNetCore.SignalR.Protocol.RawResult.RawSerializedData.get -> System.Buffers.ReadOnlySequence diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs index 66fd2de266e1..deccbb188e14 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs @@ -387,6 +387,67 @@ public void VerifyMessageSize(string testDataName) } } + public static IDictionary ClientResultData => new[] + { + new ClientResultTestData("SimpleResult", "{\"type\":3,\"invocationId\":\"1\",\"result\":45}", typeof(int), 45), + new ClientResultTestData("SimpleResult_InvocationIdLast", "{\"type\":3,\"result\":45,\"invocationId\":\"1\"}", typeof(int), 45), + new ClientResultTestData("MissingResult", "{\"type\":3,\"invocationId\":\"1\"}", typeof(int), null), + + new ClientResultTestData("ComplexResult", "{\"type\":3,\"invocationId\":\"1\",\"result\":{\"stringProp\":\"test\",\"doubleProp\":1.1,\"intProp\":0,\"dateTimeProp\":\"0001-01-01T00:00:00\",\"nullProp\":null,\"byteArrProp\":\"AgQG\"}}", typeof(CustomObject), + new CustomObject() + { + ByteArrProp = new byte[] { 2, 4, 6 }, + IntProp = default, + DoubleProp = 1.1, + StringProp = "test", + DateTimeProp = default + }), + new ClientResultTestData("ComplexResult_InvocationIdLast", "{\"type\":3,\"result\":{\"stringProp\":\"test\",\"doubleProp\":1.1,\"intProp\":0,\"dateTimeProp\":\"0001-01-01T00:00:00\",\"nullProp\":null,\"byteArrProp\":\"AgQG\"},\"invocationId\":\"1\"}", typeof(CustomObject), + new CustomObject() + { + ByteArrProp = new byte[] { 2, 4, 6 }, + IntProp = default, + DoubleProp = 1.1, + StringProp = "test", + DateTimeProp = default + }), + }.ToDictionary(t => t.Name); + + public static IEnumerable ClientResultDataNames => ClientResultData.Keys.Select(name => new object[] { name }); + + [Theory] + [MemberData(nameof(ClientResultDataNames))] + public void RawResultRoundTripsProperly(string testDataName) + { + var testData = ClientResultData[testDataName]; + + var binder = new TestBinder(null, typeof(RawResult)); + var input = Frame(testData.Message); + var data = new ReadOnlySequence(Encoding.UTF8.GetBytes(input)); + Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out var message)); + var completion = Assert.IsType(message); + + var writer = MemoryBufferWriter.Get(); + try + { + // WriteMessage should handle RawResult as Raw Json and write it properly + JsonHubProtocol.WriteMessage(completion, writer); + + // Now we check if the Raw Json was written properly and can be read using the expected type + binder = new TestBinder(null, testData.ResultType); + var written = writer.ToArray(); + data = new ReadOnlySequence(written); + Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out message)); + + completion = Assert.IsType(message); + Assert.Equal(testData.Result, completion.Result); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + public static string Frame(string input) { var data = Encoding.UTF8.GetBytes(input); @@ -436,4 +497,22 @@ public MessageSizeTestData(string name, HubMessage message, int size) public override string ToString() => Name; } + + public class ClientResultTestData + { + public string Name { get; } + public string Message { get; } + public Type ResultType { get; } + public object Result { get; } + + public ClientResultTestData(string name, string message, Type resultType, object result) + { + Name = name; + Message = message; + ResultType = resultType; + Result = result; + } + + public override string ToString() => Name; + } } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs index b1f6acab656b..978d729cc4d5 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -203,4 +203,67 @@ public void WriteMessages(string testDataName) TestWriteMessages(testData); } + + public static IDictionary ClientResultData => new[] + { + new ClientResultTestData("SimpleResult", "lQOAo3h5egMq", typeof(int), 42), + new ClientResultTestData("NullResult", "lQOAo3h5egPA", typeof(CustomObject), null), + + new ClientResultTestData("ComplexResult", "lQOAo3h5egOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM=", typeof(CustomObject), + new CustomObject()), + }.ToDictionary(t => t.Name); + + public static IEnumerable ClientResultDataNames => ClientResultData.Keys.Select(name => new object[] { name }); + + [Theory] + [MemberData(nameof(ClientResultDataNames))] + public void RawResultRoundTripsProperly(string testDataName) + { + var testData = ClientResultData[testDataName]; + var bytes = Convert.FromBase64String(testData.Message); + + var binder = new TestBinder(null, typeof(RawResult)); + var input = Frame(bytes); + var data = new ReadOnlySequence(input); + Assert.True(HubProtocol.TryParseMessage(ref data, binder, out var message)); + var completion = Assert.IsType(message); + + var writer = MemoryBufferWriter.Get(); + try + { + // WriteMessage should handle RawResult as Raw Json and write it properly + HubProtocol.WriteMessage(completion, writer); + + // Now we check if the Raw Json was written properly and can be read using the expected type + binder = new TestBinder(null, testData.ResultType); + var written = writer.ToArray(); + data = new ReadOnlySequence(written); + Assert.True(HubProtocol.TryParseMessage(ref data, binder, out message)); + + completion = Assert.IsType(message); + Assert.Equal(testData.Result, completion.Result); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + + public class ClientResultTestData + { + public string Name { get; } + public string Message { get; } + public Type ResultType { get; } + public object Result { get; } + + public ClientResultTestData(string name, string message, Type resultType, object result) + { + Name = name; + Message = message; + ResultType = resultType; + Result = result; + } + + public override string ToString() => Name; + } } diff --git a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 896088fe63c5..bdbf8443d752 100644 --- a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -29,13 +29,15 @@ public void GlobalSetup() var serviceScopeFactory = provider.GetService(); + var hubLifetimeManager = new DefaultHubLifetimeManager(NullLogger>.Instance); _dispatcher = new DefaultHubDispatcher( serviceScopeFactory, - new HubContext(new DefaultHubLifetimeManager(NullLogger>.Instance)), + new HubContext(hubLifetimeManager), enableDetailedErrors: false, disableImplicitFromServiceParameters: true, new Logger>(NullLoggerFactory.Instance), - hubFilters: null); + hubFilters: null, + hubLifetimeManager); var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); diff --git a/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs b/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs index 5f92a6dddb53..98c38dbd390b 100644 --- a/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs @@ -45,8 +45,8 @@ public void GlobalSetup() _writtenAck = RedisProtocol.WriteAck(42); _writtenGroupCommand = RedisProtocol.WriteGroupCommand(_groupCommand); _writtenInvocationNoExclusions = _protocol.WriteInvocation(_methodName, _args, null); - _writtenInvocationSmallExclusions = _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsSmall); - _writtenInvocationLargeExclusions = _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsLarge); + _writtenInvocationSmallExclusions = _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsSmall); + _writtenInvocationLargeExclusions = _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsLarge); } [Benchmark] @@ -70,13 +70,13 @@ public void WriteInvocationNoExclusions() [Benchmark] public void WriteInvocationSmallExclusions() { - _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsSmall); + _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsSmall); } [Benchmark] public void WriteInvocationLargeExclusions() { - _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsLarge); + _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsLarge); } [Benchmark] diff --git a/src/SignalR/samples/ClientSample/HubSample.cs b/src/SignalR/samples/ClientSample/HubSample.cs index f77b9c3a9a8d..5b69699c2f64 100644 --- a/src/SignalR/samples/ClientSample/HubSample.cs +++ b/src/SignalR/samples/ClientSample/HubSample.cs @@ -37,10 +37,10 @@ public static async Task ExecuteAsync(string baseUrl) logging.AddConsole(); }); - connectionBuilder.Services.Configure(options => - { - options.MinLevel = LogLevel.Trace; - }); + //connectionBuilder.Services.Configure(options => + //{ + // options.MinLevel = LogLevel.Trace; + //}); if (uri.Scheme == "net.tcp") { @@ -66,7 +66,51 @@ public static async Task ExecuteAsync(string baseUrl) }; // Set up handler - connection.On("Send", Console.WriteLine); + connection.On("GetNumber", () => + { + Console.WriteLine("Provide an integer:"); + return Task.FromResult(int.Parse(Console.ReadLine(), System.Globalization.NumberFormatInfo.InvariantInfo)); + }); + + connection.On("g", (string s, int r) => + { + return Task.FromResult(1); + }); + + connection.On("g", () => + { + return 1; + }); + + connection.On("g", (string s) => + { + return 1; + }); + + connection.On("g", async (string s) => + { + await Task.CompletedTask; + return 1; + }); + + connection.On("g", async () => + { + await Task.CompletedTask; + return 1; + }); + + connection.On("g", async (string s, int r) => + { + await Task.CompletedTask; + return 1; + }); + + connection.On("g", (string s, int r) => + { + return Task.FromResult(1); + }); + + connection.On("Result", r => Console.WriteLine($"Result: {r}")); connection.Closed += e => { @@ -81,39 +125,49 @@ public static async Task ExecuteAsync(string baseUrl) return 0; } + await connection.SendAsync("AddPlayer"); + Console.WriteLine("Connected to {0}", uri); + Console.WriteLine(connection.ConnectionId); + + var wait = new TaskCompletionSource(); + closedTokenSource.Token.Register(() => + { + wait.SetResult(null); + }); + await wait.Task; // Handle the connected connection - while (true) - { - // If the underlying connection closes while waiting for user input, the user will not observe - // the connection close aside from "Connection closed..." being printed to the console. That's - // because cancelling Console.ReadLine() is a royal pain. - var line = Console.ReadLine(); - - if (line == null || closedTokenSource.Token.IsCancellationRequested) - { - Console.WriteLine("Exiting..."); - break; - } - - try - { - await connection.InvokeAsync("Send", line); - } - catch when (closedTokenSource.IsCancellationRequested) - { - // We're shutting down the client - Console.WriteLine("Failed to send '{0}' because the CancelKeyPress event fired first. Exiting...", line); - break; - } - catch (Exception ex) - { - // Send could have failed because the connection closed - // Continue to loop because we should be reconnecting. - Console.WriteLine(ex); - } - } + //while (true) + //{ + // // If the underlying connection closes while waiting for user input, the user will not observe + // // the connection close aside from "Connection closed..." being printed to the console. That's + // // because cancelling Console.ReadLine() is a royal pain. + // var line = Console.ReadLine(); + + // if (line == null || closedTokenSource.Token.IsCancellationRequested) + // { + // Console.WriteLine("Exiting..."); + // break; + // } + + // try + // { + // await connection.InvokeAsync("Send", line); + // } + // catch when (closedTokenSource.IsCancellationRequested) + // { + // // We're shutting down the client + // Console.WriteLine("Failed to send '{0}' because the CancelKeyPress event fired first. Exiting...", line); + // break; + // } + // catch (Exception ex) + // { + // // Send could have failed because the connection closed + // // Continue to loop because we should be reconnecting. + // Console.WriteLine(ex); + // } + //} } finally { diff --git a/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs b/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs index 47484e2cf3ef..d1b7aa157075 100644 --- a/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs +++ b/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs @@ -19,9 +19,14 @@ public override Task OnDisconnectedAsync(Exception exception) return Clients.All.SendAsync("Send", $"{name} left the chat"); } - public Task Send(string name, string message) + public async Task Send(string name, string message) { - return Clients.All.SendAsync("Send", $"{name}: {message}"); + var c = Clients.Single(Context.ConnectionId); + _ = Task.Run(async () => + { + var i = await c.InvokeAsync("F"); + }); + await Clients.All.SendAsync("Send", $"{name}: {message}"); } public Task SendToOthers(string name, string message) diff --git a/src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs b/src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs new file mode 100644 index 000000000000..9d6f100d0b0b --- /dev/null +++ b/src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.SignalR; + +namespace SignalRSamples.Hubs; + +public class GameHub : Hub +{ + private readonly Game _game; + + public GameHub(Game game) + { + _game = game; + } + + public Task AddPlayer() + { + //_ = await Clients.Caller.InvokeAsync("GetNumber"); + //Clients.Caller.InvokeClientAsync(); + _game.AddPlayer(Context.ConnectionId); + return Task.CompletedTask; + } +} diff --git a/src/SignalR/samples/SignalRSamples/Program.cs b/src/SignalR/samples/SignalRSamples/Program.cs index 92e56f306369..3675e5c32b37 100644 --- a/src/SignalR/samples/SignalRSamples/Program.cs +++ b/src/SignalR/samples/SignalRSamples/Program.cs @@ -25,17 +25,18 @@ public static Task Main(string[] args) { factory.AddConfiguration(c.Configuration.GetSection("Logging")); factory.AddConsole(); + //factory.SetMinimumLevel(LogLevel.Trace); }) .UseKestrel(options => { // Default port - options.ListenLocalhost(5000); + options.ListenAnyIP(0); // Hub bound to TCP end point - options.Listen(IPAddress.Any, 9001, builder => - { - builder.UseHub(); - }); + //options.Listen(IPAddress.Any, 9001, builder => + //{ + // builder.UseHub(); + //}); }) .UseContentRoot(Directory.GetCurrentDirectory()) .UseIISIntegration() diff --git a/src/SignalR/samples/SignalRSamples/Startup.cs b/src/SignalR/samples/SignalRSamples/Startup.cs index 5a3d67e481c3..fa48bbce1251 100644 --- a/src/SignalR/samples/SignalRSamples/Startup.cs +++ b/src/SignalR/samples/SignalRSamples/Startup.cs @@ -3,6 +3,7 @@ using System.Reflection; using System.Text.Json; +using Microsoft.AspNetCore.SignalR; using SignalRSamples.ConnectionHandlers; using SignalRSamples.Hubs; @@ -18,9 +19,11 @@ public void ConfigureServices(IServiceCollection services) { services.AddConnections(); - services.AddSignalR() - .AddMessagePackProtocol(); - //.AddStackExchangeRedis(); + services.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2) + .AddMessagePackProtocol() + .AddStackExchangeRedis(); + + services.AddSingleton(); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. @@ -39,11 +42,17 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) app.UseEndpoints(endpoints => { + endpoints.MapGet("/start", (Game game, string connection1, string connection2) => + { + _ = game.GameLoop(connection1, connection2); + }); + endpoints.MapHub("/dynamic"); endpoints.MapHub("/default"); endpoints.MapHub("/streaming"); endpoints.MapHub("/uploading"); endpoints.MapHub("/hubT"); + endpoints.MapHub("/game"); endpoints.MapConnectionHandler("/chat"); @@ -79,3 +88,55 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) }); } } + +// Nurdle +public class Game +{ + public string Player1Id { get; set; } + public string Player2Id { get; set; } + + private readonly IHubContext _hubContext; + + public Game(IHubContext hubContext) + { + _hubContext = hubContext; + } + + public void AddPlayer(string Id) + { + if (string.IsNullOrEmpty(Player1Id)) + { + Player1Id = Id; + } + else + { + Player2Id = Id; + } + } + + public async Task GameLoop(string connection1, string connection2) + { + var randomAnswer = Random.Shared.Next(2, 10); + var res = 0; + + do + { + await Task.Delay(1000); + var task1 = _hubContext.Clients.Single(connection1).InvokeAsync("GetNumber"); + var task2 = _hubContext.Clients.Single(connection2).InvokeAsync("GetNumber"); + res = (await task1) + (await task2); + + if (res < randomAnswer) + { + await _hubContext.Clients.Clients(connection1, connection2).SendAsync("Result", $"Guessed {res} which is too low"); + } + else if (res > randomAnswer) + { + await _hubContext.Clients.Clients(connection1, connection2).SendAsync("Result", $"Guessed {res} which is too high"); + } + } + while (res != randomAnswer); + + await _hubContext.Clients.Clients(connection1, connection2).SendAsync("Result", $"Guessed {res} which is correct!"); + } +} diff --git a/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html b/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html index a7a18a02450c..efdca8d3d367 100644 --- a/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html +++ b/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html @@ -146,7 +146,7 @@

Group Actions

return; } - let hubRoute = hubTypeDropdown.value || "default"; + let hubRoute = "game"; let protocol = protocolDropdown.value === "msgpack" ? new signalR.protocols.msgpack.MessagePackHubProtocol() : new signalR.JsonHubProtocol(); @@ -174,6 +174,16 @@

Group Actions

addLine('message-list', msg); }); + connection.on('F', function () { + return new Promise((resolve, reject) => { + setTimeout(() => resolve(2), 5000); + }); + }); + + connection.on('GetNumber', function () { + return 2; + }); + connection.onclose(function (e) { if (e) { addLine('message-list', 'Connection closed with error: ' + e, 'red'); diff --git a/src/SignalR/server/Core/src/ClientProxyExtensions.cs b/src/SignalR/server/Core/src/ClientProxyExtensions.cs index 6b23cf59b3c2..fe01852358d3 100644 --- a/src/SignalR/server/Core/src/ClientProxyExtensions.cs +++ b/src/SignalR/server/Core/src/ClientProxyExtensions.cs @@ -218,4 +218,213 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec { return clientProxy.SendCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 }, cancellationToken); } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, Array.Empty(), cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The eigth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The eigth argument. + /// The ninth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9 }, cancellationToken); + } + + /// + /// Invokes a method on the connection(s) represented by the instance. + /// Does not wait for a response from the receiver. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The eigth argument. + /// The ninth argument. + /// The tenth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 }, cancellationToken); + } } diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 18e4891a16b9..53474856632c 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; @@ -17,6 +20,8 @@ public class DefaultHubLifetimeManager : HubLifetimeManager where TH private readonly HubConnectionStore _connections = new HubConnectionStore(); private readonly HubGroupList _groups = new HubGroupList(); private readonly ILogger _logger; + private readonly ClientResultsManager _clientResultsManager = new(); + private ulong _lastInvocationId; /// /// Initializes a new instance of the class. @@ -294,6 +299,12 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) { _connections.Remove(connection); _groups.RemoveDisconnectedConnection(connection.ConnectionId); + + List? pendingTasks = null; + _clientResultsManager.CleanupConnection(connection.ConnectionId, pendingTasks); + // Completions should be synchronous for DefaultHubLifetimeManager + Debug.Assert(pendingTasks is null); + return Task.CompletedTask; } @@ -314,4 +325,63 @@ public override Task SendUsersAsync(IReadOnlyList userIds, string method { return SendToAllConnections(methodName, args, (connection, state) => ((IReadOnlyList)state!).Contains(connection.UserIdentifier), userIds, cancellationToken); } + + /// + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + var connection = _connections[connectionId]; + + if (connection == null) + { + throw new InvalidOperationException("Connection does not exist."); + } + + var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); + var task = _clientResultsManager.AddInvocation(connectionId, invocationId, cancellationToken); + // Connection disconnected while adding invocation + // we need to try to remove it here to avoid the task hanging if the add happened after the connection cleanup + if (connection.ConnectionAborted.IsCancellationRequested) + { + await _clientResultsManager.TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Connection disconnected")); + return await task; + } + + try + { + // We're sending to a single connection + // Write message directly to connection without caching it in memory + var message = new InvocationMessage(invocationId, methodName, args); + + await connection.WriteAsync(message, cancellationToken).AsTask(); + } + catch + { + _clientResultsManager.RemoveInvocation(invocationId); + throw; + } + + return await task; + } + + /// + public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + return _clientResultsManager.TryCompleteResult(connectionId, result); + } + + /// + public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) + { + if (_clientResultsManager.TryGetType(invocationId, out type)) + { + return true; + } + type = null; + return false; + } } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index dfdcbe3ce822..3bb5566e6a7a 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -96,7 +96,8 @@ IServiceScopeFactory serviceScopeFactory _enableDetailedErrors, disableImplicitFromServiceParameters, new Logger>(loggerFactory), - hubFilters); + hubFilters, + lifetimeManager); } /// @@ -240,7 +241,7 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection) var protocol = connection.Protocol; connection.BeginClientTimeout(); - var binder = new HubConnectionBinder(_dispatcher, connection); + var binder = new HubConnectionBinder(_dispatcher, _lifetimeManager, connection); while (true) { diff --git a/src/SignalR/server/Core/src/HubLifetimeManager.cs b/src/SignalR/server/Core/src/HubLifetimeManager.cs index 140e0aea1672..c5050dbfde57 100644 --- a/src/SignalR/server/Core/src/HubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/HubLifetimeManager.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; +using Microsoft.AspNetCore.SignalR.Protocol; + namespace Microsoft.AspNetCore.SignalR; /// @@ -131,4 +134,43 @@ public abstract class HubLifetimeManager where THub : Hub /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous remove. public abstract Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public virtual Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + /// + /// + /// + /// + /// + /// + public virtual Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + throw new NotImplementedException(); + } + + /// + /// + /// + /// + /// + /// + public virtual bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) + { + type = null; + return false; + } } diff --git a/src/SignalR/server/Core/src/IHubCallerClients.cs b/src/SignalR/server/Core/src/IHubCallerClients.cs index 0a9fdaa6b404..2fb87c207552 100644 --- a/src/SignalR/server/Core/src/IHubCallerClients.cs +++ b/src/SignalR/server/Core/src/IHubCallerClients.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR; @@ -6,4 +6,13 @@ namespace Microsoft.AspNetCore.SignalR; /// /// A clients caller abstraction for a hub. /// -public interface IHubCallerClients : IHubCallerClients { } +public interface IHubCallerClients : IHubCallerClients +{ + /// + /// + /// + /// + /// + /// + new ISingleClientProxy Single(string connectionId) => throw new NotImplementedException(); +} diff --git a/src/SignalR/server/Core/src/IHubClients.cs b/src/SignalR/server/Core/src/IHubClients.cs index 1f4299a83d2e..06ecf8a606cb 100644 --- a/src/SignalR/server/Core/src/IHubClients.cs +++ b/src/SignalR/server/Core/src/IHubClients.cs @@ -6,4 +6,13 @@ namespace Microsoft.AspNetCore.SignalR; /// /// An abstraction that provides access to client connections. /// -public interface IHubClients : IHubClients { } +public interface IHubClients : IHubClients +{ + /// + /// + /// + /// + /// + /// + new ISingleClientProxy Single(string connectionId) => throw new NotImplementedException(); +} diff --git a/src/SignalR/server/Core/src/IHubClients`T.cs b/src/SignalR/server/Core/src/IHubClients`T.cs index 9f792d51a224..0479d9fb87e6 100644 --- a/src/SignalR/server/Core/src/IHubClients`T.cs +++ b/src/SignalR/server/Core/src/IHubClients`T.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR; @@ -9,6 +9,14 @@ namespace Microsoft.AspNetCore.SignalR; /// The client invoker type. public interface IHubClients { + /// + /// + /// + /// + /// + /// + T Single(string connectionId) => throw new NotImplementedException(); + /// /// Gets a that can be used to invoke methods on all clients connected to the hub. /// @@ -72,4 +80,3 @@ public interface IHubClients /// A client caller. T Users(IReadOnlyList userIds); } - diff --git a/src/SignalR/server/Core/src/ISingleClientProxy.cs b/src/SignalR/server/Core/src/ISingleClientProxy.cs new file mode 100644 index 000000000000..f077baeb08a7 --- /dev/null +++ b/src/SignalR/server/Core/src/ISingleClientProxy.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.SignalR; + +/// +/// +/// +public interface ISingleClientProxy : IClientProxy +{ + /// + /// + /// + /// + /// + /// + /// + /// + /// + Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default); +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index b3f51f7fe4bb..15eb40a80aa0 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -26,14 +26,16 @@ internal partial class DefaultHubDispatcher : HubDispatcher where TH private readonly Func>? _invokeMiddleware; private readonly Func? _onConnectedMiddleware; private readonly Func? _onDisconnectedMiddleware; + private readonly HubLifetimeManager _hubLifetimeManager; public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext hubContext, bool enableDetailedErrors, - bool disableImplicitFromServiceParameters, ILogger> logger, List? hubFilters) + bool disableImplicitFromServiceParameters, ILogger> logger, List? hubFilters, HubLifetimeManager lifetimeManager) { _serviceScopeFactory = serviceScopeFactory; _hubContext = hubContext; _enableDetailedErrors = enableDetailedErrors; _logger = logger; + _hubLifetimeManager = lifetimeManager; DiscoverHubMethods(disableImplicitFromServiceParameters); var count = hubFilters?.Count ?? 0; @@ -70,7 +72,7 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { await using var scope = _serviceScopeFactory.CreateAsyncScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); + connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit is not null); var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); @@ -165,19 +167,29 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe case StreamItemMessage streamItem: return ProcessStreamItem(connection, streamItem); - case CompletionMessage streamCompleteMessage: + case CompletionMessage completionMessage: // closes channels, removes from Lookup dict // user's method can see the channel is complete and begin wrapping up - if (connection.StreamTracker.TryComplete(streamCompleteMessage)) + if (connection.StreamTracker.TryComplete(completionMessage)) { - Log.CompletingStream(_logger, streamCompleteMessage); + Log.CompletingStream(_logger, completionMessage); + } + // TODO: this relies on the lifetime manager keeping state for the return type after deserializing the message, is that ok? + // InvocationId is always required on CompletionMessage, it's nullable because of the base type + else if (_hubLifetimeManager.TryGetReturnType(completionMessage.InvocationId!, out _)) + { + _hubLifetimeManager.SetConnectionResultAsync(connection.ConnectionId, completionMessage); } else { + // TODO: Retire this log and replace with a more generic one Log.UnexpectedStreamCompletion(_logger); } break; + //case ClientResultMessage clientResultMessage: + // return _hubLifetimeManager.SetClientResult(clientResultMessage); + // Other kind of message we weren't expecting default: Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName!); @@ -247,7 +259,7 @@ private Task ProcessInvocation(HubConnectionContext connection, bool isStreamCall = descriptor.StreamingParameters != null; if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) { - return connection.ActiveInvocationLimit.RunAsync(state => + return connection.ActiveInvocationLimit.RunAsync(static state => { var (dispatcher, descriptor, connection, invocationMessage) = state; return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index a07c83c373af..6d3a31332f30 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -8,12 +8,43 @@ internal class HubCallerClients : IHubCallerClients private readonly string _connectionId; private readonly IHubClients _hubClients; private readonly string[] _currentConnectionId; + private readonly bool _parallelEnabled; - public HubCallerClients(IHubClients hubClients, string connectionId) + public HubCallerClients(IHubClients hubClients, string connectionId, bool parallelEnabled) { _connectionId = connectionId; _hubClients = hubClients; _currentConnectionId = new[] { _connectionId }; + _parallelEnabled = parallelEnabled; + } + + private class NotParallelSingleClientProxy : ISingleClientProxy + { + private readonly ISingleClientProxy _proxy; + + public NotParallelSingleClientProxy(ISingleClientProxy hubClients) + { + _proxy = hubClients; + } + + public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1."); + } + + public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _proxy.SendCoreAsync(method, args, cancellationToken); + } + } + + public ISingleClientProxy Single(string connectionId) + { + if (!_parallelEnabled) + { + return new NotParallelSingleClientProxy(_hubClients.Single(connectionId)); + } + return _hubClients.Single(connectionId); } public IClientProxy Caller => _hubClients.Client(_connectionId); diff --git a/src/SignalR/server/Core/src/Internal/HubClients.cs b/src/SignalR/server/Core/src/Internal/HubClients.cs index 0920e40ad6cf..203b3fe4ff92 100644 --- a/src/SignalR/server/Core/src/Internal/HubClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubClients.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR.Internal; @@ -13,6 +13,11 @@ public HubClients(HubLifetimeManager lifetimeManager) All = new AllClientProxy(_lifetimeManager); } + public ISingleClientProxy Single(string connectionId) + { + return new SingleClientProxyWithInvoke(_lifetimeManager, connectionId); + } + public IClientProxy All { get; } public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) diff --git a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs index 2f69c7827329..dfde6a11d070 100644 --- a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs +++ b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs @@ -7,11 +7,13 @@ internal class HubConnectionBinder : IInvocationBinder where THub : Hub { private readonly HubDispatcher _dispatcher; private readonly HubConnectionContext _connection; + private readonly HubLifetimeManager _hubLifetimeManager; - public HubConnectionBinder(HubDispatcher dispatcher, HubConnectionContext connection) + public HubConnectionBinder(HubDispatcher dispatcher, HubLifetimeManager lifetimeManager, HubConnectionContext connection) { _dispatcher = dispatcher; _connection = connection; + _hubLifetimeManager = lifetimeManager; } public IReadOnlyList GetParameterTypes(string methodName) @@ -21,7 +23,11 @@ public IReadOnlyList GetParameterTypes(string methodName) public Type GetReturnType(string invocationId) { - return typeof(object); + if (_hubLifetimeManager.TryGetReturnType(invocationId, out var type)) + { + return type; + } + throw new InvalidOperationException("Unknown invocation ID."); } public Type GetStreamItemType(string streamId) diff --git a/src/SignalR/server/Core/src/Internal/Proxies.cs b/src/SignalR/server/Core/src/Internal/Proxies.cs index 951b8c7c8941..7dead0caf95a 100644 --- a/src/SignalR/server/Core/src/Internal/Proxies.cs +++ b/src/SignalR/server/Core/src/Internal/Proxies.cs @@ -155,3 +155,25 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance return _lifetimeManager.SendConnectionsAsync(_connectionIds, method, args, cancellationToken); } } + +internal class SingleClientProxyWithInvoke : ISingleClientProxy where THub : Hub +{ + private readonly string _connectionId; + private readonly HubLifetimeManager _lifetimeManager; + + public SingleClientProxyWithInvoke(HubLifetimeManager lifetimeManager, string connectionId) + { + _lifetimeManager = lifetimeManager; + _connectionId = connectionId; + } + + public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _lifetimeManager.SendConnectionAsync(_connectionId, method, args, cancellationToken); + } + + public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _lifetimeManager.InvokeConnectionAsync(_connectionId, method, args ?? Array.Empty(), cancellationToken); + } +} diff --git a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs index 182153f31f47..ec8971c4134d 100644 --- a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs +++ b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs @@ -113,12 +113,25 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo var parameters = interfaceMethodInfo.GetParameters(); var paramTypes = parameters.Select(param => param.ParameterType).ToArray(); + var returnType = interfaceMethodInfo.ReturnType; + bool isInvoke = returnType != typeof(Task); var methodBuilder = type.DefineMethod(interfaceMethodInfo.Name, methodAttributes); - var invokeMethod = typeof(IClientProxy).GetMethod( - nameof(IClientProxy.SendCoreAsync), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, - new[] { typeof(string), typeof(object[]), typeof(CancellationToken) }, null)!; + MethodInfo invokeMethod; + if (isInvoke) + { + invokeMethod = typeof(ISingleClientProxy).GetMethod( + nameof(ISingleClientProxy.InvokeCoreAsync), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, + new[] { typeof(string), typeof(object[]), typeof(CancellationToken) }, null)! + .MakeGenericMethod(returnType.GenericTypeArguments); + } + else + { + invokeMethod = typeof(IClientProxy).GetMethod( + nameof(IClientProxy.SendCoreAsync), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, + new[] { typeof(string), typeof(object[]), typeof(CancellationToken) }, null)!; + } methodBuilder.SetReturnType(interfaceMethodInfo.ReturnType); methodBuilder.SetParameters(paramTypes); @@ -153,6 +166,25 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Ldfld, proxyField); + var notTypeLabel = generator.DefineLabel(); + if (isInvoke) + { + var singleClientProxyType = typeof(ISingleClientProxy); + /* + if (_proxy is ISingleClientProxy singleClientProxy) + { + return ((ISingleClientProxy)_proxy).InvokeAsync(methodName, args, cancellationToken); + } + throw new InvalidOperationException("InvokeAsync only works with Single clients."); + */ + generator.Emit(OpCodes.Isinst, singleClientProxyType); + generator.Emit(OpCodes.Brfalse_S, notTypeLabel); + + generator.Emit(OpCodes.Ldarg_0); + generator.Emit(OpCodes.Ldfld, proxyField); + generator.Emit(OpCodes.Castclass, singleClientProxyType); + } + // The first argument to IClientProxy.SendCoreAsync is this method's name generator.Emit(OpCodes.Ldstr, methodName); @@ -189,6 +221,12 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo generator.Emit(OpCodes.Callvirt, invokeMethod); generator.Emit(OpCodes.Ret); // Return the Task returned by 'invokeMethod' + + // Used by InvokeAsync to check if it's being called with ISingleClientProxy otherwise throws + generator.MarkLabel(notTypeLabel); + generator.Emit(OpCodes.Ldstr, "InvokeAsync only works with Single clients."); + generator.Emit(OpCodes.Newobj, typeof(InvalidOperationException).GetConstructor(new Type[] { typeof(string) })!); + generator.Emit(OpCodes.Throw); } private static void BuildFactoryMethod(TypeBuilder type, ConstructorInfo ctor) @@ -232,10 +270,10 @@ private static void VerifyInterface(Type interfaceType) private static void VerifyMethod(MethodInfo interfaceMethod) { - if (interfaceMethod.ReturnType != typeof(Task)) + if (interfaceMethod.ReturnType != typeof(Task) && interfaceMethod.ReturnType?.BaseType != typeof(Task)) { throw new InvalidOperationException( - $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. All client proxy methods must return '{typeof(Task).FullName}'."); + $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'."); } foreach (var parameter in interfaceMethod.GetParameters()) diff --git a/src/SignalR/server/Core/src/Internal/TypedHubClients.cs b/src/SignalR/server/Core/src/Internal/TypedHubClients.cs index 5bc8ce693ff7..0bffb3d83abb 100644 --- a/src/SignalR/server/Core/src/Internal/TypedHubClients.cs +++ b/src/SignalR/server/Core/src/Internal/TypedHubClients.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR.Internal; @@ -12,6 +12,8 @@ public TypedHubClients(IHubCallerClients dynamicContext) _hubClients = dynamicContext; } + public T Single(string connectionId) => TypedClientBuilder.Build(_hubClients.Single(connectionId)); + public T All => TypedClientBuilder.Build(_hubClients.All); public T Caller => TypedClientBuilder.Build(_hubClients.Caller); diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index c5ae1cb3718f..cf269845f355 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -17,6 +17,7 @@ + diff --git a/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt b/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt index eb9071344d2a..a3686554ff72 100644 --- a/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt @@ -5,3 +5,25 @@ Microsoft.AspNetCore.SignalR.HubOptions.DisableImplicitFromServicesParameters.se Microsoft.AspNetCore.SignalR.HubConnectionHandler.HubConnectionHandler(Microsoft.AspNetCore.SignalR.HubLifetimeManager! lifetimeManager, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! protocolResolver, Microsoft.Extensions.Options.IOptions! globalHubOptions, Microsoft.Extensions.Options.IOptions!>! hubOptions, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory, Microsoft.AspNetCore.SignalR.IUserIdProvider! userIdProvider, Microsoft.Extensions.DependencyInjection.IServiceScopeFactory! serviceScopeFactory) -> void *REMOVED*~Microsoft.AspNetCore.SignalR.HubOptionsSetup.HubOptionsSetup(Microsoft.Extensions.Options.IOptions! options) -> void Microsoft.AspNetCore.SignalR.HubOptionsSetup.HubOptionsSetup(Microsoft.Extensions.Options.IOptions! options) -> void +Microsoft.AspNetCore.SignalR.IHubCallerClients.Single(string! connectionId) -> Microsoft.AspNetCore.SignalR.ISingleClientProxy! +Microsoft.AspNetCore.SignalR.IHubClients.Single(string! connectionId) -> Microsoft.AspNetCore.SignalR.ISingleClientProxy! +Microsoft.AspNetCore.SignalR.IHubClients.Single(string! connectionId) -> T +Microsoft.AspNetCore.SignalR.ISingleClientProxy +Microsoft.AspNetCore.SignalR.ISingleClientProxy.InvokeCoreAsync(string! method, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index d4b6a7df5f35..db7359370a33 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -332,6 +332,12 @@ public async Task BlockingMethod() await tcs.Task; } + + public async Task GetClientResult(int num) + { + var sum = await Clients.Single(Context.ConnectionId).InvokeAsync("Sum", num); + return sum; + } } internal class SelfRef diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs new file mode 100644 index 000000000000..b179f92acbc5 --- /dev/null +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -0,0 +1,106 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.SignalR.Tests; + +public partial class HubConnectionHandlerTests +{ + [Fact] + public async Task CanReturnClientResultToHub() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(9L, completion.Result); + Assert.Equal(invocationId, completion.InvocationId); + } + } + } + + [Fact] + public async Task CanReturnClientResultErrorToHub() + { + using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + builder.AddSignalR(o => + { + o.MaximumParallelInvocationsPerClient = 2; + o.EnableDetailedErrors = true; + }); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + await client.SendHubMessageAsync(CompletionMessage.WithError(invocationMessage.InvocationId, "Client error")).DefaultTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. Exception: Client error", completion.Error); + Assert.Equal(invocationId, completion.InvocationId); + } + } + } + + [Fact] + public async Task ThrowsWhenParallelHubInvokesNotEnabled() + { + using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSignalR(o => + { + o.MaximumParallelInvocationsPerClient = 1; + o.EnableDetailedErrors = true; + }); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(invocationId, completionMessage.InvocationId); + Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. InvalidOperationException: Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1.", + completionMessage.Error); + } + } + } +} diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index a8b675bdb3dc..f09d5cb7d416 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -30,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests; -public class HubConnectionHandlerTests : VerifiableLoggedTest +public partial class HubConnectionHandlerTests : VerifiableLoggedTest { [Fact] [LogLevel(LogLevel.Trace)] @@ -240,7 +240,7 @@ public void FailsToLoadInvalidTypedHubClient() var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); var ex = Assert.Throws(() => serviceProvider.GetRequiredService>()); - Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidReturningTypedHubClient).FullName}.{nameof(IVoidReturningTypedHubClient.Send)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidReturningTypedHubClient).FullName}.{nameof(IVoidReturningTypedHubClient.Send)}'. All client proxy methods must return '{typeof(Task).FullName}' or 'System.Threading.Tasks.Task'.", ex.Message); } } diff --git a/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs b/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs index 95688baadb57..93c730c8ffa7 100644 --- a/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs +++ b/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs @@ -1,13 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.Testing; -using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests.Internal; @@ -162,7 +157,7 @@ public void ThrowsIfInterfaceHasVoidReturningMethod() { var clientProxy = new MockProxy(); var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); - Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidMethodClient).FullName}.{nameof(IVoidMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidMethodClient).FullName}.{nameof(IVoidMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'.", ex.Message); } [Fact] @@ -170,7 +165,7 @@ public void ThrowsIfInterfaceHasNonTaskReturns() { var clientProxy = new MockProxy(); var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); - Assert.Equal($"Cannot generate proxy implementation for '{typeof(IStringMethodClient).FullName}.{nameof(IStringMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IStringMethodClient).FullName}.{nameof(IStringMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'.", ex.Message); } [Fact] @@ -207,9 +202,86 @@ public void ThrowsIfInterfaceHasEvents() Assert.Equal("Type must not contain events.", ex.Message); } + [Fact] + public async Task ProducesImplementationThatProxiesMethodsToISingleClientProxyAsync() + { + var clientProxy = new MockSingleClientProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + + var objArg = new object(); + var task = typedProxy.GetValue(1008, objArg, "test"); + Assert.False(task.IsCompleted); + + Assert.Collection(clientProxy.Sends, + send => + { + Assert.Equal("GetValue", send.Method); + Assert.Collection(send.Arguments, + arg1 => Assert.Equal(1008, arg1), + arg2 => Assert.Same(objArg, arg2), + arg3 => Assert.Same("test", arg3)); + Assert.Equal(CancellationToken.None, send.CancellationToken); + send.Complete(); + }); + + var result = await task.DefaultTimeout(); + Assert.Equal(default(int), result); + } + + [Fact] + public async Task ThrowsIfReturnMethodUsedWithoutSingleClientProxy() + { + var clientProxy = new MockProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + + var objArg = new object(); + var ex = await Assert.ThrowsAsync(() => typedProxy.GetValue(102, objArg, "test")).DefaultTimeout(); + Assert.Equal("InvokeAsync only works with Single clients.", ex.Message); + + Assert.Empty(clientProxy.Sends); + } + + [Fact] + public async Task ResultMethodSupportsCancellationToken() + { + var clientProxy = new MockSingleClientProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + CancellationTokenSource cts1 = new CancellationTokenSource(); + var task1 = typedProxy.MethodReturning("foo", cts1.Token); + Assert.False(task1.IsCompleted); + + CancellationTokenSource cts2 = new CancellationTokenSource(); + var task2 = typedProxy.NoArgumentMethodReturning(cts2.Token); + Assert.False(task2.IsCompleted); + + Assert.Collection(clientProxy.Sends, + send1 => + { + Assert.Equal("MethodReturning", send1.Method); + Assert.Single(send1.Arguments); + Assert.Collection(send1.Arguments, + arg1 => Assert.Equal("foo", arg1)); + Assert.Equal(cts1.Token, send1.CancellationToken); + send1.Complete(); + }, + send2 => + { + Assert.Equal("NoArgumentMethodReturning", send2.Method); + Assert.Empty(send2.Arguments); + Assert.Equal(cts2.Token, send2.CancellationToken); + send2.Complete(); + }); + + var result = await task1.DefaultTimeout(); + Assert.Equal(default(string), result); + var result2 = await task2.DefaultTimeout(); + Assert.Equal(default(int), result2); + } + public interface ITestClient { Task Method(string arg1, int arg2, object arg3); + Task GetValue(int arg1, object arg2, string arg3); } public interface IRenamedTestClient @@ -247,6 +319,9 @@ public interface ICancellationTokenMethod { Task Method(string foo, CancellationToken cancellationToken); Task NoArgumentMethod(CancellationToken cancellationToken); + + Task NoArgumentMethodReturning(CancellationToken cancellationToken); + Task MethodReturning(string foo, CancellationToken cancellationToken); } public interface IPropertiesClient @@ -259,6 +334,11 @@ public interface IEventsClient event EventHandler Event; } + public interface ITestReturnValueClient + { + Task GetValue(); + } + private class MockProxy : IClientProxy { public IList Sends { get; } = new List(); @@ -273,6 +353,30 @@ public Task SendCoreAsync(string method, object[] args, CancellationToken cancel } } + private class MockSingleClientProxy : ISingleClientProxy + { + public IList Sends { get; } = new List(); + + public Task SendCoreAsync(string method, object[] args, CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSource(); + + Sends.Add(new SendContext(method, args, cancellationToken, tcs)); + + return tcs.Task; + } + + public async Task InvokeCoreAsync(string method, object[] args, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(); + + Sends.Add(new SendContext(method, args, cancellationToken, tcs)); + + await tcs.Task; + return default(T); + } + } + private struct SendContext { private readonly TaskCompletionSource _tcs; diff --git a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs index 5a64beb0e792..b66595b6aa60 100644 --- a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs +++ b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs @@ -170,4 +170,197 @@ public async Task SendConnectionAsyncWritesToConnectionOutput() Assert.Equal("World", (string)message.Arguments[0]); } } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnResult() + { + var manager = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).DefaultTimeout(); + + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); + + var res = await resultTask.DefaultTimeout(); + Assert.Equal(10L, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnErrorResult() + { + var manager = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).DefaultTimeout(); + + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client")).DefaultTimeout(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.Equal("Error from client", ex.Message); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ExceptionWhenIncorrectClientCompletesClientResult() + { + var manager = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).DefaultTimeout(); + await manager.OnConnectedAsync(connection2).DefaultTimeout(); + + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + var ex = await Assert.ThrowsAsync(() => + manager.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client"))).DefaultTimeout(); + + Assert.Equal("wrong ID", ex.Message); + + // Internal state for invocation isn't affected by wrong client, check that we can still complete the invocation + await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); + + var res = await resultTask.DefaultTimeout(); + Assert.Equal(10L, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ConnectionIDNotPresentWhenInvokingClientResult() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // No client with this ID + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task InvokesForMultipleClientsDoNotCollide() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + await manager1.OnConnectedAsync(connection2).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invoke2 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }); + + var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + var invocation2 = Assert.IsType(await client2.ReadAsync().DefaultTimeout()); + await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithError(invocation2.InvocationId, "error")); + + await Assert.ThrowsAnyAsync(() => invoke2).DefaultTimeout(); + Assert.False(invoke1.IsCompleted); + + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 3)); + Assert.Equal(3, await invoke1.DefaultTimeout()); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ClientDisconnectsWithoutCompletingClientResult() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + + await manager1.OnDisconnectedAsync(connection1).DefaultTimeout(); + + await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanCancelClientResult() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + var cts = new CancellationTokenSource(); + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cts.Token); + var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + cts.Cancel(); + + await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + + // Noop, just checking that it doesn't throw. This could be caused by an inflight response from a client while the server cancels the token + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 1)); + } + } } diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 509d09d477e2..d7dae9b73ad1 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -463,4 +463,115 @@ public async Task StillSubscribedToUserAfterOneOfMultipleConnectionsAssociatedWi await AssertMessageAsync(client2); } } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnResultAcrossServers() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // Server2 asks for a result from client1 on Server1 + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + // Server1 gets the result from client1 and forwards to Server2 + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); + + var res = await resultTask.DefaultTimeout(); + Assert.Equal(10L, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnErrorResultAcrossServers() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // Server2 asks for a result from client1 on Server1 + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + // Server1 gets the result from client1 and forwards to Server2 + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client")).DefaultTimeout(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.Equal("Error from client", ex.Message); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ConnectionIDNotPresentMultiServerWhenInvokingClientResult() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // No client on any backplanes with this ID + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ClientDisconnectsWithoutCompletingClientResultOnSecondServer() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager2.OnConnectedAsync(connection1).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + + await manager2.OnDisconnectedAsync(connection1).DefaultTimeout(); + + // Server should propogate connection closure so task isn't blocked + await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + } + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs index 8f306f3082b6..eea3d9499247 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs @@ -71,4 +71,15 @@ public string Ack(string serverName) { return _prefix + ":internal:ack:" + serverName; } + + /// + /// Gets the name of the client return results channel for the specified server. + /// + /// The name of the server to get the client return results channel for. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string ReturnResults(string serverName) + { + return _prefix + ":internal:return:" + serverName; + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs index 8474a8ed780f..1f441873d6f1 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs @@ -16,9 +16,16 @@ internal readonly struct RedisInvocation /// public SerializedHubMessage Message { get; } - public RedisInvocation(SerializedHubMessage message, IReadOnlyList? excludedConnectionIds) + public string? ReturnChannel { get; } + + public string? InvocationId { get; } + + public RedisInvocation(SerializedHubMessage message, IReadOnlyList? excludedConnectionIds, + string? invocationId = null, string? returnChannel = null) { Message = message; ExcludedConnectionIds = excludedConnectionIds; + ReturnChannel = returnChannel; + InvocationId = invocationId; } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 719413a08136..557be073b5b9 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -25,20 +25,25 @@ public RedisProtocol(DefaultHubMessageSerializer messageSerializer) // * Invocations are sent to the All, Group, Connection and User channels // * Group Commands are sent to the GroupManagement channel // * Acks are sent to the Acknowledgement channel. + // * Completion messages (client results) are sent to the server specific Result channel // * See the Write[type] methods for a description of the protocol for each in-depth. // * The "Variable length integer" is the length-prefixing format used by BinaryReader/BinaryWriter: // * https://docs.microsoft.com/dotnet/api/system.io.binarywriter.write?view=netcore-2.2 // * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter: // * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8. - public byte[] WriteInvocation(string methodName, object?[] args) => - WriteInvocation(methodName, args, excludedConnectionIds: null); + //public byte[] WriteInvocation(string methodName, object?[] args, string? invocationId = null) => + // WriteInvocation(methodName, args, invocationId, excludedConnectionIds: null); - public byte[] WriteInvocation(string methodName, object?[] args, IReadOnlyList? excludedConnectionIds) + public byte[] WriteInvocation(string methodName, object?[] args, string? invocationId = null, + IReadOnlyList? excludedConnectionIds = null, string? returnChannel = null) { // Written as a MessagePack 'arr' containing at least these items: // * A MessagePack 'arr' of 'str's representing the excluded ids // * [The output of WriteSerializedHubMessage, which is an 'arr'] + // For invocations expecting a result + // * InvocationID + // * Redis return channel // Any additional items are discarded. var memoryBufferWriter = MemoryBufferWriter.Get(); @@ -46,7 +51,16 @@ public byte[] WriteInvocation(string methodName, object?[] args, IReadOnlyList 0) { writer.WriteArrayHeader(excludedConnectionIds.Count); @@ -60,7 +74,7 @@ public byte[] WriteInvocation(string methodName, object?[] args, IReadOnlyList completionMessage, string protocolName) + { + // Written as a MessagePack 'arr' containing at least these items: + // * A 'str': The name of the HubProtocol used for the serialization of the Completion Message + // * [A serialized Completion Message which is a 'bin'] + // Any additional items are discarded. + + var memoryBufferWriter = MemoryBufferWriter.Get(); + try + { + var writer = new MessagePackWriter(memoryBufferWriter); + + writer.WriteArrayHeader(2); + writer.Write(protocolName); + writer.Write(completionMessage); + + writer.Flush(); + + return memoryBufferWriter.ToArray(); + } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } + } + public static RedisInvocation ReadInvocation(ReadOnlyMemory data) { // See WriteInvocation for the format var reader = new MessagePackReader(data); - ValidateArraySize(ref reader, 2, "Invocation"); + var length = ValidateArraySize(ref reader, 2, "Invocation"); + + string? returnChannel = null; + string? invocationId = null; + if (length > 3) + { + invocationId = reader.ReadString(); + returnChannel = reader.ReadString(); + } // Read excluded Ids IReadOnlyList? excludedConnectionIds = null; @@ -147,7 +195,7 @@ public static RedisInvocation ReadInvocation(ReadOnlyMemory data) // Read payload var message = ReadSerializedHubMessage(ref reader); - return new RedisInvocation(message, excludedConnectionIds); + return new RedisInvocation(message, excludedConnectionIds, invocationId, returnChannel); } public static RedisGroupCommand ReadGroupCommand(ReadOnlyMemory data) @@ -209,7 +257,18 @@ public static SerializedHubMessage ReadSerializedHubMessage(ref MessagePackReade return new SerializedHubMessage(serializations); } - private static void ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType) + public static RedisCompletion ReadCompletion(ReadOnlyMemory data) + { + // See WriteCompletionMessage for the format + var reader = new MessagePackReader(data); + ValidateArraySize(ref reader, 2, "CompletionMessage"); + + var protocolName = reader.ReadString(); + var ros = reader.ReadBytes(); + return new RedisCompletion(protocolName, ros ?? new ReadOnlySequence()); + } + + private static int ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType) { var length = reader.ReadArrayHeader(); @@ -217,5 +276,6 @@ private static void ValidateArraySize(ref MessagePackReader reader, int expected { throw new InvalidDataException($"Insufficient items in {messageType} array."); } + return length; } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs new file mode 100644 index 000000000000..85a718fc58a4 --- /dev/null +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; + +internal readonly struct RedisReturnResult +{ + /// + /// Gets the message serialization cache containing serialized payloads for the message. + /// + public object? Result { get; } + + public string InvocationId { get; } + + public RedisReturnResult(string invocationId, object? result) + { + InvocationId = invocationId; + Result = result; + } +} + +internal readonly struct RedisCompletion +{ + /// + /// Gets the message serialization cache containing serialized payloads for the message. + /// + public ReadOnlySequence CompletionMessage { get; } + + public string ProtocolName { get; } + + public RedisCompletion(string protocolName, ReadOnlySequence completionMessage) + { + ProtocolName = protocolName; + CompletionMessage = completionMessage; + } +} diff --git a/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj b/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj index ae0ff95ad8cb..a9c070a090fc 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj +++ b/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj @@ -9,6 +9,7 @@ + diff --git a/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt b/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt index 68e96ee62e38..bdac4aa9fd92 100644 --- a/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt @@ -3,3 +3,6 @@ *REMOVED*~Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger!>! logger, Microsoft.Extensions.Options.IOptions! options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! hubProtocolResolver, Microsoft.Extensions.Options.IOptions? globalHubOptions, Microsoft.Extensions.Options.IOptions!>? hubOptions) -> void Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger!>! logger, Microsoft.Extensions.Options.IOptions! options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! hubProtocolResolver) -> void Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger!>! logger, Microsoft.Extensions.Options.IOptions! options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! hubProtocolResolver, Microsoft.Extensions.Options.IOptions? globalHubOptions, Microsoft.Extensions.Options.IOptions!>? hubOptions) -> void +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index f3533789a336..eb916a810c7a 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; using System.Text; using Microsoft.AspNetCore.Http.Features; @@ -30,9 +33,12 @@ public class RedisHubLifetimeManager : HubLifetimeManager, IDisposab private readonly string _serverName = GenerateServerName(); private readonly RedisProtocol _protocol; private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1); + private readonly IHubProtocolResolver _hubProtocolResolver; + private readonly ClientResultsManager _clientResultsManager = new(); private readonly AckHandler _ackHandler; - private int _internalId; + private int _internalAckId; + private ulong _lastInvocationId; /// /// Constructs the with types from Dependency Injection. @@ -61,6 +67,7 @@ public RedisHubLifetimeManager(ILogger> logger, IOptions? globalHubOptions, IOptions>? hubOptions) { + _hubProtocolResolver = hubProtocolResolver; _logger = logger; _options = options.Value; _ackHandler = new AckHandler(); @@ -131,6 +138,8 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) tasks.Add(RemoveUserAsync(connection)); } + _clientResultsManager.CleanupConnection(connection.ConnectionId, tasks); + return Task.WhenAll(tasks); } @@ -144,7 +153,7 @@ public override Task SendAllAsync(string methodName, object?[] args, Cancellatio /// public override Task SendAllExceptAsync(string methodName, object?[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) { - var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds); + var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds: excludedConnectionIds); return PublishAsync(_channels.All, message); } @@ -188,7 +197,7 @@ public override Task SendGroupExceptAsync(string groupName, string methodName, o throw new ArgumentNullException(nameof(groupName)); } - var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds); + var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds: excludedConnectionIds); return PublishAsync(_channels.Group(groupName), message); } @@ -306,11 +315,11 @@ public override Task SendUsersAsync(IReadOnlyList userIds, string method return Task.CompletedTask; } - private async Task PublishAsync(string channel, byte[] payload) + private async Task PublishAsync(string channel, byte[] payload) { await EnsureRedisServerConnection(); RedisLog.PublishToChannel(_logger, channel); - await _bus!.PublishAsync(channel, payload); + return await _bus!.PublishAsync(channel, payload); } private Task AddGroupAsyncCore(HubConnectionContext connection, string groupName) @@ -358,7 +367,7 @@ await _groups.RemoveSubscriptionAsync(groupChannel, connection, channelName => private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action) { - var id = Interlocked.Increment(ref _internalId); + var id = Interlocked.Increment(ref _internalAckId); var ack = _ackHandler.CreateAck(id); // Send Add/Remove Group to other servers and wait for an ack or timeout var message = RedisProtocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId)); @@ -388,6 +397,74 @@ public void Dispose() _ackHandler.Dispose(); } + /// + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + // send thing + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + var connection = _connections[connectionId]; + + var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); + var task = _clientResultsManager.AddInvocation(connectionId, invocationId, cancellationToken); + + try + { + if (connection == null) + { + // TODO: Need to handle other server going away while waiting for connection result + var m = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName)); + var received = await PublishAsync(_channels.Connection(connectionId), m); + if (received < 1) + { + throw new InvalidOperationException("Connection does not exist."); + } + } + else + { + // Connection disconnected while adding invocation + // we need to try to remove it here to avoid the task hanging if the add happened after the connection cleanup + if (connection.ConnectionAborted.IsCancellationRequested) + { + await _clientResultsManager.TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Connection disconnected")); + return await task; + } + + // We're sending to a single connection + // Write message directly to connection without caching it in memory + var message = new InvocationMessage(invocationId, methodName, args); + + await connection.WriteAsync(message, cancellationToken).AsTask(); + } + } + catch + { + _clientResultsManager.RemoveInvocation(invocationId); + throw; + } + + return await task; + } + + /// + public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + return _clientResultsManager.TryCompleteResult(connectionId, result); + } + + /// + public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) + { + if (_clientResultsManager.TryGetType(invocationId, out type)) + { + return true; + } + return false; + } + private async Task SubscribeToAll() { RedisLog.Subscribing(_logger, _channels.All); @@ -476,6 +553,32 @@ private async Task SubscribeToConnection(HubConnectionContext connection) channel.OnMessage(channelMessage => { var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message); + if (!string.IsNullOrEmpty(invocation.InvocationId)) + { + _clientResultsManager.AddInvocation(invocation.InvocationId, (typeof(RawResult), connection.ConnectionId, completionMessage => + { + var protocolName = connection.Protocol.Name; + var memoryBufferWriter = AspNetCore.Internal.MemoryBufferWriter.Get(); + try + { + connection.Protocol.WriteMessage(completionMessage, memoryBufferWriter); + // TODO: we can avoid this ToArray call + var message = RedisProtocol.WriteCompletionMessage(new ReadOnlySequence(memoryBufferWriter.ToArray()), protocolName); + return PublishAsync(invocation.ReturnChannel!, message); + } + finally + { + memoryBufferWriter.Dispose(); + } + } + )); + // Connection disconnected while adding invocation + // we need to try to remove it here to avoid the task hanging if the add happened after the connection cleanup + if (connection.ConnectionAborted.IsCancellationRequested) + { + return _clientResultsManager.TryCompleteResult(connection.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Connection disconnected")); + } + } return connection.WriteAsync(invocation.Message).AsTask(); }); } @@ -540,6 +643,30 @@ private async Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore }); } + private async Task SubscribeToReturnResultsAsync() + { + var channel = await _bus!.SubscribeAsync(_channels.ReturnResults(_serverName)); + channel.OnMessage((channelMessage) => + { + var completion = RedisProtocol.ReadCompletion(channelMessage.Message); + var protocol = _hubProtocolResolver.AllProtocols.Where(p => p.Name.Equals(completion.ProtocolName)).First(); + var ros = completion.CompletionMessage; + protocol.TryParseMessage(ref ros, _clientResultsManager, out var hubMessage); + + var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!); + invocationInfo.Completion((CompletionMessage)hubMessage); + }); + } + + //private CompletionMessage GetCompletionMessage(ReadOnlyMemory rawMessage, string protocolName) + //{ + // var protocol = _hubProtocolResolver.AllProtocols.First(p => p.Name.Equals(protocolName)); + // var serialized = new System.Buffers.ReadOnlySequence(rawMessage); + // protocol.TryParseMessage(ref serialized, this, out var message); + // Debug.Assert(message is CompletionMessage); + // return (CompletionMessage)message; + //} + private async Task EnsureRedisServerConnection() { if (_redisServerConnection == null) @@ -589,6 +716,7 @@ private async Task EnsureRedisServerConnection() await SubscribeToAll(); await SubscribeToGroupManagementChannel(); await SubscribeToAckChannel(); + await SubscribeToReturnResultsAsync(); } } finally diff --git a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs index dac4a79894c3..b87ee91e564d 100644 --- a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs +++ b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Generic; using System.Linq; +using System.Text; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; @@ -176,7 +178,7 @@ public void WriteInvocation(string testName) // Actual invocation doesn't matter because we're using a dummy hub protocol. // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. var expected = testData.Decoded(); - var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds); + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, excludedConnectionIds: expected.ExcludedConnectionIds); Assert.Equal(testData.Encoded, encoded); } @@ -192,7 +194,77 @@ public void WriteInvocationWithHubMessageSerializer(string testName) // Actual invocation doesn't matter because we're using a dummy hub protocol. // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. var expected = testData.Decoded(); - var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds); + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, excludedConnectionIds: expected.ExcludedConnectionIds); + + Assert.Equal(testData.Encoded, encoded); + } + + private static readonly Dictionary> _completionMessageTestData = new[] + { + CreateTestData( + "JsonMessageForwarded", + new RedisCompletion("json", new ReadOnlySequence(Encoding.UTF8.GetBytes("{\"type\":3,\"invocationId\":\"1\",\"result\":1}"))), + 0x92, + 0xa4, + (byte)'j', + (byte)'s', + (byte)'o', + (byte)'n', + 0xc4, // 'bin' + 0x28, // length + 0x7b, 0x22, 0x74, 0x79, 0x70, 0x65, 0x22, 0x3a, 0x33, 0x2c, 0x22, 0x69, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x3a, 0x22, 0x31, 0x22, 0x2c, 0x22, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x3a, + 0x31, 0x7d), + CreateTestData( + "MsgPackMessageForwarded", + new RedisCompletion("messagepack", new ReadOnlySequence(new byte[] { 0x95, 0x03, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x03, 0x2a })), + 0x92, + 0xab, + (byte)'m', + (byte)'e', + (byte)'s', + (byte)'s', + (byte)'a', + (byte)'g', + (byte)'e', + (byte)'p', + (byte)'a', + (byte)'c', + (byte)'k', + 0xc4, // 'bin' + 0x09, // 'bin' length + 0x95, // 5 array elements + 0x03, // type: 3 + 0x80, // empty headers + 0xa3, // 'str' + (byte)'x', + (byte)'y', + (byte)'z', + 0x03, // has result + 0x2a), // 42 + }.ToDictionary(t => t.Name); + + public static IEnumerable CompletionMessageTestData = _completionMessageTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(CompletionMessageTestData))] + public void ParseCompletionMessage(string testName) + { + var testData = _completionMessageTestData[testName]; + + var completionMessage = RedisProtocol.ReadCompletion(testData.Encoded); + + Assert.Equal(testData.Decoded.ProtocolName, completionMessage.ProtocolName); + Assert.Equal(testData.Decoded.CompletionMessage.ToArray(), completionMessage.CompletionMessage.ToArray()); + } + + [Theory] + [MemberData(nameof(CompletionMessageTestData))] + public void WriteCompletionMessage(string testName) + { + var testData = _completionMessageTestData[testName]; + + var encoded = RedisProtocol.WriteCompletionMessage(testData.Decoded.CompletionMessage, testData.Decoded.ProtocolName); Assert.Equal(testData.Encoded, encoded); } From cdc25125c0b38033cabdbc57fdfdbce9cff1bb40 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 21 Mar 2022 09:32:03 -0700 Subject: [PATCH 02/11] some cleanup --- .../csharp/Client.Core/src/HubConnection.cs | 21 ++- .../src/HubConnectionExtensions.OnResult.cs | 151 +++++++++++------- .../src/HubConnectionExtensions.cs | 18 --- .../src/Protocol/JsonHubProtocol.cs | 1 + .../SignalR.Common/src/Protocol/RawResult.cs | 8 +- .../server/Core/src/ClientProxyExtensions.cs | 65 ++++---- .../server/Core/src/HubLifetimeManager.cs | 30 ++-- src/SignalR/server/Core/src/IHubClients.cs | 7 +- src/SignalR/server/Core/src/IHubClients`T.cs | 7 +- .../server/Core/src/ISingleClientProxy.cs | 17 +- .../Core/src/Internal/DefaultHubDispatcher.cs | 5 +- ...edisReturnResult.cs => RedisCompletion.cs} | 19 --- .../src/Internal/RedisProtocol.cs | 3 - .../src/RedisHubLifetimeManager.cs | 15 +- 14 files changed, 169 insertions(+), 198 deletions(-) rename src/SignalR/server/StackExchangeRedis/src/Internal/{RedisReturnResult.cs => RedisCompletion.cs} (51%) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 42d4707e434c..15bbf7648f38 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -313,13 +313,17 @@ public virtual async ValueTask DisposeAsync() } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// - /// - /// - /// + /// The name of the hub method to define. + /// The parameters types expected by the hub method. + /// The handler that will be raised when the hub method is invoked. + /// A state object that will be passed to the handler. + /// A subscription that can be disposed to unsubscribe from the hub method. + /// + /// This is a low level method for registering a handler. Using an On extension method is recommended. + /// public virtual IDisposable On(string methodName, Type[] parameterTypes, Func> handler, object state) { Log.RegisteringHandler(_logger, methodName); @@ -1077,6 +1081,10 @@ private async Task DispatchInvocationAsync(InvocationMessage invocation, Connect await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); } } + else if (hasResult) + { + // Log: result given but server didn't ask for one. + } } private async Task DispatchInvocationStreamItemAsync(StreamItemMessage streamItem, InvocationRequest irq) @@ -1739,7 +1747,6 @@ public InvocationHandler(Type[] parameterTypes, Func ca _callback = callback; ParameterTypes = parameterTypes; _state = state; - } public Task InvokeAsync(object?[] parameters) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs index 098fcd1bcbe3..878ac788681e 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs @@ -20,14 +20,33 @@ private static IDisposable On(this HubConnection hubConnection, string } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The parameters types expected by the hub method. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func> handler) + { + return hubConnection.On(methodName, parameterTypes, async (parameters, state) => + { + var currentHandler = (Func>)state; + return await currentHandler(parameters).ConfigureAwait(false); + }, handler); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// - /// - /// - /// - /// + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) { if (hubConnection == null) @@ -39,14 +58,14 @@ public static IDisposable On(this HubConnection hubConnection, string m } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// - /// - /// - /// - /// + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) { if (hubConnection == null) @@ -58,15 +77,15 @@ public static IDisposable On(this HubConnection hubConnection, string m } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// - /// - /// - /// - /// - /// + /// The first argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) { if (hubConnection == null) @@ -80,16 +99,16 @@ public static IDisposable On(this HubConnection hubConnection, stri } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// /// - /// - /// - /// - /// - /// - /// + /// The first argument type. + /// The second argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) { if (hubConnection == null) @@ -104,11 +123,12 @@ public static IDisposable On(this HubConnection hubConnection, /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. /// The third argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -127,12 +147,13 @@ public static IDisposable On(this HubConnection hubConnecti /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. /// The third argument type. /// The fourth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -151,13 +172,14 @@ public static IDisposable On(this HubConnection hubConn /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. /// The third argument type. /// The fourth argument type. /// The fifth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -176,6 +198,7 @@ public static IDisposable On(this HubConnection hub /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. @@ -183,7 +206,7 @@ public static IDisposable On(this HubConnection hub /// The fourth argument type. /// The fifth argument type. /// The sixth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -202,6 +225,7 @@ public static IDisposable On(this HubConnection /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. @@ -210,7 +234,7 @@ public static IDisposable On(this HubConnection /// The fifth argument type. /// The sixth argument type. /// The seventh argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -229,6 +253,7 @@ public static IDisposable On(this HubConnec /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. @@ -238,7 +263,7 @@ public static IDisposable On(this HubConnec /// The sixth argument type. /// The seventh argument type. /// The eighth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -256,15 +281,15 @@ public static IDisposable On(this HubCo } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// - /// - /// - /// - /// - /// + /// The first argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) { if (hubConnection == null) @@ -278,16 +303,16 @@ public static IDisposable On(this HubConnection hubConnection, stri } /// - /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// - /// - /// - /// - /// - /// - /// - /// - /// + /// The first argument type. + /// The second argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) { if (hubConnection == null) @@ -302,11 +327,12 @@ public static IDisposable On(this HubConnection hubConnection, /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. /// The third argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -325,12 +351,13 @@ public static IDisposable On(this HubConnection hubConnecti /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. /// The third argument type. /// The fourth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -349,13 +376,14 @@ public static IDisposable On(this HubConnection hubConn /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. /// The third argument type. /// The fourth argument type. /// The fifth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -374,6 +402,7 @@ public static IDisposable On(this HubConnection hub /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. @@ -381,7 +410,7 @@ public static IDisposable On(this HubConnection hub /// The fourth argument type. /// The fifth argument type. /// The sixth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -400,6 +429,7 @@ public static IDisposable On(this HubConnection /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. @@ -408,7 +438,7 @@ public static IDisposable On(this HubConnection /// The fifth argument type. /// The sixth argument type. /// The seventh argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. @@ -427,6 +457,7 @@ public static IDisposable On(this HubConnec /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. /// /// The first argument type. /// The second argument type. @@ -436,7 +467,7 @@ public static IDisposable On(this HubConnec /// The sixth argument type. /// The seventh argument type. /// The eighth argument type. - /// + /// The return type the handler returns. /// The hub connection. /// The name of the hub method to define. /// The handler that will be raised when the hub method is invoked. diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs index 1da69401294b..ea131eff6b35 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs @@ -243,24 +243,6 @@ public static IDisposable On(this HubConnection hubConnection, string methodName }, handler); } - /// - /// - /// - /// - /// - /// - /// - /// - /// - public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func> handler) - { - return hubConnection.On(methodName, parameterTypes, async (parameters, state) => - { - var currentHandler = (Func>)state; - return await currentHandler(parameters).ConfigureAwait(false); - }, handler); - } - /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. /// diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index b21a008ffadc..33e9a74fcfb9 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -209,6 +209,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else { + // If we have an invocation id already we can parse the end result var returnType = binder.GetReturnType(invocationId); result = BindType(ref reader, input, returnType); } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs index 8aecffb30967..6a56ba0c5c98 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs @@ -15,21 +15,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol; /// stored as raw serialized bytes in the format of the protocol being used. /// /// -/// In Json that would mean storing the bytes of {"prop":10} as an example. +/// In Json that would mean storing the byte representation of ascii {"prop":10} as an example. /// public class RawResult { /// - /// + /// Stores the raw serialized bytes of a for forwarding to another server. /// - /// + /// The raw bytes from the client. public RawResult(ReadOnlySequence rawBytes) { RawSerializedData = rawBytes; } /// - /// + /// The raw serialized bytes from the client. /// public ReadOnlySequence RawSerializedData { get; private set; } } diff --git a/src/SignalR/server/Core/src/ClientProxyExtensions.cs b/src/SignalR/server/Core/src/ClientProxyExtensions.cs index fe01852358d3..7b2f0c8c8b8f 100644 --- a/src/SignalR/server/Core/src/ClientProxyExtensions.cs +++ b/src/SignalR/server/Core/src/ClientProxyExtensions.cs @@ -163,7 +163,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] @@ -185,7 +185,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The ninth argument. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. @@ -208,7 +208,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The ninth argument. /// The tenth argument. /// The token to monitor for cancellation requests. The default value is . @@ -220,10 +220,9 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. @@ -234,10 +233,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The token to monitor for cancellation requests. The default value is . @@ -249,10 +247,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -265,10 +262,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -282,10 +278,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -300,10 +295,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -319,10 +313,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -339,8 +332,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// /// The /// The name of the method to invoke. @@ -360,10 +352,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -372,7 +363,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] @@ -382,10 +373,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -394,7 +384,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The ninth argument. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. @@ -405,10 +395,9 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string } /// - /// Invokes a method on the connection(s) represented by the instance. - /// Does not wait for a response from the receiver. + /// Invokes a method on the connection represented by the instance and waits for a response. /// - /// The + /// The . /// The name of the method to invoke. /// The first argument. /// The second argument. @@ -417,7 +406,7 @@ public static Task InvokeAsync(this ISingleClientProxy clientProxy, string /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The ninth argument. /// The tenth argument. /// The token to monitor for cancellation requests. The default value is . diff --git a/src/SignalR/server/Core/src/HubLifetimeManager.cs b/src/SignalR/server/Core/src/HubLifetimeManager.cs index c5050dbfde57..897993d3235e 100644 --- a/src/SignalR/server/Core/src/HubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/HubLifetimeManager.cs @@ -136,37 +136,35 @@ public abstract class HubLifetimeManager where THub : Hub public abstract Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default); /// - /// + /// Sends an invocation message to the specified connection and waits for a response. /// - /// - /// - /// - /// - /// - /// - /// + /// The type of the response expected. + /// The connection ID. + /// The invocation method name. + /// The invocation arguments. + /// The token to monitor for cancellation requests. The default value is . + /// The response from the connection. public virtual Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } /// - /// + /// Sets the connection result for an in progress call. /// - /// - /// - /// - /// + /// The connection ID. + /// The result from the connection. + /// A that represents the result being set or being forwarded to another server. public virtual Task SetConnectionResultAsync(string connectionId, CompletionMessage result) { throw new NotImplementedException(); } /// - /// + /// Tells implementations what the expected type from a connection result is. /// - /// - /// + /// The ID of the in progress invocation. + /// The type the connection is expected to send. Or if the result is intended for another server. /// public virtual bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) { diff --git a/src/SignalR/server/Core/src/IHubClients.cs b/src/SignalR/server/Core/src/IHubClients.cs index 06ecf8a606cb..3646d4bc8258 100644 --- a/src/SignalR/server/Core/src/IHubClients.cs +++ b/src/SignalR/server/Core/src/IHubClients.cs @@ -9,10 +9,9 @@ namespace Microsoft.AspNetCore.SignalR; public interface IHubClients : IHubClients { /// - /// + /// Gets a proxy that can be used to invoke methods on a single client connected to the hub and receive results. /// - /// - /// - /// + /// The connection ID. + /// A client caller. new ISingleClientProxy Single(string connectionId) => throw new NotImplementedException(); } diff --git a/src/SignalR/server/Core/src/IHubClients`T.cs b/src/SignalR/server/Core/src/IHubClients`T.cs index 0479d9fb87e6..0dee6f33b19a 100644 --- a/src/SignalR/server/Core/src/IHubClients`T.cs +++ b/src/SignalR/server/Core/src/IHubClients`T.cs @@ -10,11 +10,10 @@ namespace Microsoft.AspNetCore.SignalR; public interface IHubClients { /// - /// + /// Gets a that can be used to invoke methods on a single client connected to the hub and receive results. /// - /// - /// - /// + /// The connection ID. + /// A client caller. T Single(string connectionId) => throw new NotImplementedException(); /// diff --git a/src/SignalR/server/Core/src/ISingleClientProxy.cs b/src/SignalR/server/Core/src/ISingleClientProxy.cs index f077baeb08a7..f400b13e6acc 100644 --- a/src/SignalR/server/Core/src/ISingleClientProxy.cs +++ b/src/SignalR/server/Core/src/ISingleClientProxy.cs @@ -4,18 +4,21 @@ namespace Microsoft.AspNetCore.SignalR; /// -/// +/// A proxy abstraction for invoking hub methods on the client and getting a result. /// public interface ISingleClientProxy : IClientProxy { + // client proxy method is called InvokeCoreAsync instead of InvokeAsync so that arrays of references + // like string[], e.g. InvokeAsync(string, string[]), do not choose InvokeAsync(string, object[]) + // over InvokeAsync(string, object) overload + /// - /// + /// Invokes a method on the connection represented by the instance and waits for a result. /// /// - /// - /// - /// - /// - /// + /// Name of the method to invoke. + /// A collection of arguments to pass to the client. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke and wait for a client result. Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default); } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 15eb40a80aa0..bd173ee7079f 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -178,7 +178,7 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe // InvocationId is always required on CompletionMessage, it's nullable because of the base type else if (_hubLifetimeManager.TryGetReturnType(completionMessage.InvocationId!, out _)) { - _hubLifetimeManager.SetConnectionResultAsync(connection.ConnectionId, completionMessage); + return _hubLifetimeManager.SetConnectionResultAsync(connection.ConnectionId, completionMessage); } else { @@ -187,9 +187,6 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe } break; - //case ClientResultMessage clientResultMessage: - // return _hubLifetimeManager.SetClientResult(clientResultMessage); - // Other kind of message we weren't expecting default: Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName!); diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisCompletion.cs similarity index 51% rename from src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs rename to src/SignalR/server/StackExchangeRedis/src/Internal/RedisCompletion.cs index 85a718fc58a4..c950d2f2bbfd 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisReturnResult.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisCompletion.cs @@ -5,27 +5,8 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; -internal readonly struct RedisReturnResult -{ - /// - /// Gets the message serialization cache containing serialized payloads for the message. - /// - public object? Result { get; } - - public string InvocationId { get; } - - public RedisReturnResult(string invocationId, object? result) - { - InvocationId = invocationId; - Result = result; - } -} - internal readonly struct RedisCompletion { - /// - /// Gets the message serialization cache containing serialized payloads for the message. - /// public ReadOnlySequence CompletionMessage { get; } public string ProtocolName { get; } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 557be073b5b9..00d18b80aee7 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -32,9 +32,6 @@ public RedisProtocol(DefaultHubMessageSerializer messageSerializer) // * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter: // * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8. - //public byte[] WriteInvocation(string methodName, object?[] args, string? invocationId = null) => - // WriteInvocation(methodName, args, invocationId, excludedConnectionIds: null); - public byte[] WriteInvocation(string methodName, object?[] args, string? invocationId = null, IReadOnlyList? excludedConnectionIds = null, string? returnChannel = null) { diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index eb916a810c7a..b46fb6e73f19 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -458,11 +458,7 @@ public override Task SetConnectionResultAsync(string connectionId, CompletionMes /// public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) { - if (_clientResultsManager.TryGetType(invocationId, out type)) - { - return true; - } - return false; + return _clientResultsManager.TryGetType(invocationId, out type); } private async Task SubscribeToAll() @@ -658,15 +654,6 @@ private async Task SubscribeToReturnResultsAsync() }); } - //private CompletionMessage GetCompletionMessage(ReadOnlyMemory rawMessage, string protocolName) - //{ - // var protocol = _hubProtocolResolver.AllProtocols.First(p => p.Name.Equals(protocolName)); - // var serialized = new System.Buffers.ReadOnlySequence(rawMessage); - // protocol.TryParseMessage(ref serialized, this, out var message); - // Debug.Assert(message is CompletionMessage); - // return (CompletionMessage)message; - //} - private async Task EnsureRedisServerConnection() { if (_redisServerConnection == null) From 58299a897d2947a487afe1aea24151dffe35d675 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 21 Mar 2022 16:36:53 -0700 Subject: [PATCH 03/11] cleanup --- .../java/com/microsoft/signalr/Action.java | 2 +- .../com/microsoft/signalr/ActionBase.java | 4 - .../com/microsoft/signalr/CallbackMap.java | 2 +- .../signalr/ClientResultMessage.java | 47 ------- .../java/com/microsoft/signalr/Function.java | 5 - .../com/microsoft/signalr/HubConnection.java | 77 +++-------- .../com/microsoft/signalr/HubMessageType.java | 1 - .../microsoft/signalr/InvocationHandler.java | 14 +- .../com/microsoft/signalr/sample/Chat.java | 4 - .../src/Protocol/JsonHubProtocol.cs | 7 +- .../common/Shared/ClientResultsManager.cs | 2 + .../SignalR.Common/src/Protocol/RawResult.cs | 4 +- src/SignalR/samples/ClientSample/HubSample.cs | 124 +++++------------- .../samples/SignalRSamples/Hubs/Chat.cs | 9 +- .../samples/SignalRSamples/Hubs/GameHub.cs | 24 ---- src/SignalR/samples/SignalRSamples/Startup.cs | 67 +--------- .../samples/SignalRSamples/wwwroot/hubs.html | 12 +- .../Core/src/DefaultHubLifetimeManager.cs | 2 +- .../server/Core/src/IHubCallerClients.cs | 7 +- 19 files changed, 78 insertions(+), 336 deletions(-) delete mode 100644 src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java delete mode 100644 src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java delete mode 100644 src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java index 653ab10e69e1..fd1216e15431 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Action.java @@ -10,4 +10,4 @@ public interface Action { // We can't use the @FunctionalInterface annotation because it's only // available on Android API Level 24 and above. void invoke(); -} \ No newline at end of file +} diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java index 4e5fcf1b5d9f..e24630b53d23 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ActionBase.java @@ -8,7 +8,3 @@ interface ActionBase { // available on Android API Level 24 and above. void invoke(Object ... params); } - -interface FunctionBase { - Object invoke(Object ... params); -} \ No newline at end of file diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java index 6b6fd69c467c..2a7013cc5dfb 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/CallbackMap.java @@ -14,7 +14,7 @@ class CallbackMap { private final Map> handlers = new HashMap<>(); private final ReentrantLock lock = new ReentrantLock(); - public InvocationHandler put(String target, FunctionBase action, Type... types) { + public InvocationHandler put(String target, ActionBase action, Type... types) { try { lock.lock(); InvocationHandler handler = new InvocationHandler(action, types); diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java deleted file mode 100644 index 2046fb54863d..000000000000 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/ClientResultMessage.java +++ /dev/null @@ -1,47 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -package com.microsoft.signalr; - -import java.util.Map; - -public final class ClientResultMessage extends HubMessage { - private final int type = HubMessageType.CLIENT_RESULT.value; - private Map headers; - private final String invocationId; - private final Object result; - private final String error; - - public ClientResultMessage(Map headers, String invocationId, Object result, String error) { - if (headers != null && !headers.isEmpty()) { - this.headers = headers; - } - if (error != null && result != null) { - throw new IllegalArgumentException("Expected either 'error' or 'result' to be provided, but not both."); - } - this.invocationId = invocationId; - this.result = result; - this.error = error; - } - - public Map getHeaders() { - return headers; - } - - public Object getResult() { - return result; - } - - public String getError() { - return error; - } - - public String getInvocationId() { - return invocationId; - } - - @Override - public HubMessageType getMessageType() { - return HubMessageType.values()[type - 1]; - } -} diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java deleted file mode 100644 index dd56f1e0f4c1..000000000000 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/Function.java +++ /dev/null @@ -1,5 +0,0 @@ -package com.microsoft.signalr; - -public interface Function { - Object invoke(); -} diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java index 4d31b0a0a496..1e2ce1bc5a5d 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubConnection.java @@ -470,21 +470,14 @@ private void ReceiveLoop(ByteBuffer payload) InvocationBindingFailureMessage msg = (InvocationBindingFailureMessage)message; logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException()); break; - case STREAM_BINDING_FAILURE: - StreamBindingFailureMessage streamBindingFailure = (StreamBindingFailureMessage)message; - logger.error("Failed to bind arguments received in invocation '{}'.", streamBindingFailure.getInvocationId(), streamBindingFailure.getException()); - break; case INVOCATION: + InvocationMessage invocationMessage = (InvocationMessage) message; List handlers = this.handlers.get(invocationMessage.getTarget()); if (handlers != null) { for (InvocationHandler handler : handlers) { try { - Object result = handler.getAction().invoke(invocationMessage.getArguments()); - logger.error("{}", result); - if (result != null) { - this.sendHubMessageWithLock(new ClientResultMessage(null, invocationMessage.getInvocationId(), new CompletionMessage(null, invocationMessage.getInvocationId(), result, null), null)); - } + handler.getAction().invoke(invocationMessage.getArguments()); } catch (Exception e) { logger.error("Invoking client side method '{}' failed:", invocationMessage.getTarget(), e); } @@ -520,7 +513,6 @@ private void ReceiveLoop(ByteBuffer payload) streamInvocationRequest.addItem(streamItem); break; - case CLIENT_RESULT: case STREAM_INVOCATION: case CANCEL_INVOCATION: logger.error("This client does not support {} messages.", message.getMessageType()); @@ -876,17 +868,7 @@ public void onClosed(OnClosedCallback callback) { * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action callback) { - FunctionBase action = args -> { - callback.invoke(); - return null; - }; - return registerHandler(target, action); - } - - public Subscription on(String target, Function callback) { - FunctionBase action = args -> { - return callback.invoke(); - }; + ActionBase action = args -> callback.invoke(); return registerHandler(target, action); } @@ -901,11 +883,9 @@ public Subscription on(String target, Function callback) { * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action1 callback, Class param1) { - FunctionBase action = params -> { - callback.invoke(Utils.cast(param1, params[0])); - return null; - }; + ActionBase action = params -> callback.invoke(Utils.cast(param1, params[0])); return registerHandler(target, action, param1); + } /** @@ -921,9 +901,8 @@ public Subscription on(String target, Action1 callback, Class param * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action2 callback, Class param1, Class param2) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1])); - return null; }; return registerHandler(target, action, param1, param2); } @@ -944,9 +923,8 @@ public Subscription on(String target, Action2 callback, Class Subscription on(String target, Action3 callback, Class param1, Class param2, Class param3) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2])); - return null; }; return registerHandler(target, action, param1, param2, param3); } @@ -969,10 +947,9 @@ public Subscription on(String target, Action3 callback, */ public Subscription on(String target, Action4 callback, Class param1, Class param2, Class param3, Class param4) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4); } @@ -997,10 +974,9 @@ public Subscription on(String target, Action4 c */ public Subscription on(String target, Action5 callback, Class param1, Class param2, Class param3, Class param4, Class param5) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5); } @@ -1027,10 +1003,9 @@ public Subscription on(String target, Action5 Subscription on(String target, Action6 callback, Class param1, Class param2, Class param3, Class param4, Class param5, Class param6) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6); } @@ -1059,10 +1034,9 @@ public Subscription on(String target, Action6 Subscription on(String target, Action7 callback, Class param1, Class param2, Class param3, Class param4, Class param5, Class param6, Class param7) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7); } @@ -1093,11 +1067,10 @@ public Subscription on(String target, Action7 Subscription on(String target, Action8 callback, Class param1, Class param2, Class param3, Class param4, Class param5, Class param6, Class param7, Class param8) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6]), Utils.cast(param8, params[7])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7, param8); } @@ -1114,9 +1087,8 @@ public Subscription on(String target, Action8 Subscription on(String target, Action1 callback, Type param1) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0])); - return null; }; return registerHandler(target, action, param1); } @@ -1135,9 +1107,8 @@ public Subscription on(String target, Action1 callback, Type param1) { * @return A {@link Subscription} that can be disposed to unsubscribe from the hub method. */ public Subscription on(String target, Action2 callback, Type param1, Type param2) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1])); - return null; }; return registerHandler(target, action, param1, param2); } @@ -1159,9 +1130,8 @@ public Subscription on(String target, Action2 callback, Type pa */ public Subscription on(String target, Action3 callback, Type param1, Type param2, Type param3) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2])); - return null; }; return registerHandler(target, action, param1, param2, param3); } @@ -1185,10 +1155,9 @@ public Subscription on(String target, Action3 callback, */ public Subscription on(String target, Action4 callback, Type param1, Type param2, Type param3, Type param4) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4); } @@ -1214,10 +1183,9 @@ public Subscription on(String target, Action4 c */ public Subscription on(String target, Action5 callback, Type param1, Type param2, Type param3, Type param4, Type param5) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5); } @@ -1245,10 +1213,9 @@ public Subscription on(String target, Action5 Subscription on(String target, Action6 callback, Type param1, Type param2, Type param3, Type param4, Type param5, Type param6) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6); } @@ -1278,10 +1245,9 @@ public Subscription on(String target, Action6 Subscription on(String target, Action7 callback, Type param1, Type param2, Type param3, Type param4, Type param5, Type param6, Type param7) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7); } @@ -1314,16 +1280,15 @@ public Subscription on(String target, Action7 Subscription on(String target, Action8 callback, Type param1, Type param2, Type param3, Type param4, Type param5, Type param6, Type param7, Type param8) { - FunctionBase action = params -> { + ActionBase action = params -> { callback.invoke(Utils.cast(param1, params[0]), Utils.cast(param2, params[1]), Utils.cast(param3, params[2]), Utils.cast(param4, params[3]), Utils.cast(param5, params[4]), Utils.cast(param6, params[5]), Utils.cast(param7, params[6]), Utils.cast(param8, params[7])); - return null; }; return registerHandler(target, action, param1, param2, param3, param4, param5, param6, param7, param8); } - private Subscription registerHandler(String target, FunctionBase action, Type... types) { + private Subscription registerHandler(String target, ActionBase action, Type... types) { InvocationHandler handler = handlers.put(target, action, types); logger.debug("Registering handler for client method: '{}'.", target); return new Subscription(handlers, handler, target); diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java index d191f9f7917c..23201c0c0d8a 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/HubMessageType.java @@ -11,7 +11,6 @@ public enum HubMessageType { CANCEL_INVOCATION(5), PING(6), CLOSE(7), - CLIENT_RESULT(8), INVOCATION_BINDING_FAILURE(-1), STREAM_BINDING_FAILURE(-2); diff --git a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java index a4632c85cced..98fa53aaf365 100644 --- a/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java +++ b/src/SignalR/clients/java/signalr/core/src/main/java/com/microsoft/signalr/InvocationHandler.java @@ -9,26 +9,18 @@ class InvocationHandler { private final List types; - private final Object action; - private final Boolean hasResult; - - InvocationHandler(FunctionBase action, Type... types) { - this.action = action; - this.types = Arrays.asList(types); - this.hasResult = false; - } + private final ActionBase action; InvocationHandler(ActionBase action, Type... types) { this.action = action; this.types = Arrays.asList(types); - this.hasResult = true; } public List getTypes() { return types; } - public FunctionBase getAction() { - return (FunctionBase)action; + public ActionBase getAction() { + return action; } } \ No newline at end of file diff --git a/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java b/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java index cdb411427eb8..f4c0850c0fb9 100644 --- a/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java +++ b/src/SignalR/clients/java/signalr/test/src/main/java/com/microsoft/signalr/sample/Chat.java @@ -15,10 +15,6 @@ public static void main(final String[] args) throws Exception { final String input = reader.nextLine(); try (HubConnection hubConnection = HubConnectionBuilder.create(input).build()) { - hubConnection.on("F", () -> { - return 2; - }); - hubConnection.on("Send", (message) -> { System.out.println(message); }, String.class); diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 33e9a74fcfb9..087d912c914d 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -739,12 +739,9 @@ private static HubMessage BindInvocationMessage(string? invocationId, string tar reader.Skip(); var end = reader.BytesConsumed; var sequence = input.Slice(start, end - start); - // Technically we could pass the sequence without copying into a new array + // Review: Technically we could pass the sequence without copying into a new array // but in the future we could break this if we dispatched the CompletionMessage and the underlying Pipe read would be advanced - var arr = new byte[sequence.Length]; - sequence.CopyTo(arr); - // REVIEW: We can make this type do the copying which would allow us to rent from the ArrayPool - return new RawResult(new ReadOnlySequence(arr)); + return new RawResult(sequence); } return BindType(ref reader, type); } diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index 352acd308124..4ab975267ed9 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -9,6 +9,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal; +// Common type used by our HubLifetimeManager implementations to manage client results. +// Handles cancellation, cleanup, and completion, so any bugs or improvements can be made in a single place internal class ClientResultsManager : IInvocationBinder { private readonly ConcurrentDictionary Completion)> _pendingInvocations = new(); diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs index 6a56ba0c5c98..54acfa3909d4 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs @@ -21,11 +21,13 @@ public class RawResult { /// /// Stores the raw serialized bytes of a for forwarding to another server. + /// Will copy the passed in bytes to internal storage. /// /// The raw bytes from the client. public RawResult(ReadOnlySequence rawBytes) { - RawSerializedData = rawBytes; + // Review: If we want to use an ArrayPool we would need some sort of release mechanism + RawSerializedData = new ReadOnlySequence(rawBytes.ToArray()); } /// diff --git a/src/SignalR/samples/ClientSample/HubSample.cs b/src/SignalR/samples/ClientSample/HubSample.cs index 5b69699c2f64..f77b9c3a9a8d 100644 --- a/src/SignalR/samples/ClientSample/HubSample.cs +++ b/src/SignalR/samples/ClientSample/HubSample.cs @@ -37,10 +37,10 @@ public static async Task ExecuteAsync(string baseUrl) logging.AddConsole(); }); - //connectionBuilder.Services.Configure(options => - //{ - // options.MinLevel = LogLevel.Trace; - //}); + connectionBuilder.Services.Configure(options => + { + options.MinLevel = LogLevel.Trace; + }); if (uri.Scheme == "net.tcp") { @@ -66,51 +66,7 @@ public static async Task ExecuteAsync(string baseUrl) }; // Set up handler - connection.On("GetNumber", () => - { - Console.WriteLine("Provide an integer:"); - return Task.FromResult(int.Parse(Console.ReadLine(), System.Globalization.NumberFormatInfo.InvariantInfo)); - }); - - connection.On("g", (string s, int r) => - { - return Task.FromResult(1); - }); - - connection.On("g", () => - { - return 1; - }); - - connection.On("g", (string s) => - { - return 1; - }); - - connection.On("g", async (string s) => - { - await Task.CompletedTask; - return 1; - }); - - connection.On("g", async () => - { - await Task.CompletedTask; - return 1; - }); - - connection.On("g", async (string s, int r) => - { - await Task.CompletedTask; - return 1; - }); - - connection.On("g", (string s, int r) => - { - return Task.FromResult(1); - }); - - connection.On("Result", r => Console.WriteLine($"Result: {r}")); + connection.On("Send", Console.WriteLine); connection.Closed += e => { @@ -125,49 +81,39 @@ public static async Task ExecuteAsync(string baseUrl) return 0; } - await connection.SendAsync("AddPlayer"); - Console.WriteLine("Connected to {0}", uri); - Console.WriteLine(connection.ConnectionId); - - var wait = new TaskCompletionSource(); - closedTokenSource.Token.Register(() => - { - wait.SetResult(null); - }); - await wait.Task; // Handle the connected connection - //while (true) - //{ - // // If the underlying connection closes while waiting for user input, the user will not observe - // // the connection close aside from "Connection closed..." being printed to the console. That's - // // because cancelling Console.ReadLine() is a royal pain. - // var line = Console.ReadLine(); - - // if (line == null || closedTokenSource.Token.IsCancellationRequested) - // { - // Console.WriteLine("Exiting..."); - // break; - // } - - // try - // { - // await connection.InvokeAsync("Send", line); - // } - // catch when (closedTokenSource.IsCancellationRequested) - // { - // // We're shutting down the client - // Console.WriteLine("Failed to send '{0}' because the CancelKeyPress event fired first. Exiting...", line); - // break; - // } - // catch (Exception ex) - // { - // // Send could have failed because the connection closed - // // Continue to loop because we should be reconnecting. - // Console.WriteLine(ex); - // } - //} + while (true) + { + // If the underlying connection closes while waiting for user input, the user will not observe + // the connection close aside from "Connection closed..." being printed to the console. That's + // because cancelling Console.ReadLine() is a royal pain. + var line = Console.ReadLine(); + + if (line == null || closedTokenSource.Token.IsCancellationRequested) + { + Console.WriteLine("Exiting..."); + break; + } + + try + { + await connection.InvokeAsync("Send", line); + } + catch when (closedTokenSource.IsCancellationRequested) + { + // We're shutting down the client + Console.WriteLine("Failed to send '{0}' because the CancelKeyPress event fired first. Exiting...", line); + break; + } + catch (Exception ex) + { + // Send could have failed because the connection closed + // Continue to loop because we should be reconnecting. + Console.WriteLine(ex); + } + } } finally { diff --git a/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs b/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs index d1b7aa157075..47484e2cf3ef 100644 --- a/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs +++ b/src/SignalR/samples/SignalRSamples/Hubs/Chat.cs @@ -19,14 +19,9 @@ public override Task OnDisconnectedAsync(Exception exception) return Clients.All.SendAsync("Send", $"{name} left the chat"); } - public async Task Send(string name, string message) + public Task Send(string name, string message) { - var c = Clients.Single(Context.ConnectionId); - _ = Task.Run(async () => - { - var i = await c.InvokeAsync("F"); - }); - await Clients.All.SendAsync("Send", $"{name}: {message}"); + return Clients.All.SendAsync("Send", $"{name}: {message}"); } public Task SendToOthers(string name, string message) diff --git a/src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs b/src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs deleted file mode 100644 index 9d6f100d0b0b..000000000000 --- a/src/SignalR/samples/SignalRSamples/Hubs/GameHub.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.AspNetCore.SignalR; - -namespace SignalRSamples.Hubs; - -public class GameHub : Hub -{ - private readonly Game _game; - - public GameHub(Game game) - { - _game = game; - } - - public Task AddPlayer() - { - //_ = await Clients.Caller.InvokeAsync("GetNumber"); - //Clients.Caller.InvokeClientAsync(); - _game.AddPlayer(Context.ConnectionId); - return Task.CompletedTask; - } -} diff --git a/src/SignalR/samples/SignalRSamples/Startup.cs b/src/SignalR/samples/SignalRSamples/Startup.cs index fa48bbce1251..5a3d67e481c3 100644 --- a/src/SignalR/samples/SignalRSamples/Startup.cs +++ b/src/SignalR/samples/SignalRSamples/Startup.cs @@ -3,7 +3,6 @@ using System.Reflection; using System.Text.Json; -using Microsoft.AspNetCore.SignalR; using SignalRSamples.ConnectionHandlers; using SignalRSamples.Hubs; @@ -19,11 +18,9 @@ public void ConfigureServices(IServiceCollection services) { services.AddConnections(); - services.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2) - .AddMessagePackProtocol() - .AddStackExchangeRedis(); - - services.AddSingleton(); + services.AddSignalR() + .AddMessagePackProtocol(); + //.AddStackExchangeRedis(); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. @@ -42,17 +39,11 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) app.UseEndpoints(endpoints => { - endpoints.MapGet("/start", (Game game, string connection1, string connection2) => - { - _ = game.GameLoop(connection1, connection2); - }); - endpoints.MapHub("/dynamic"); endpoints.MapHub("/default"); endpoints.MapHub("/streaming"); endpoints.MapHub("/uploading"); endpoints.MapHub("/hubT"); - endpoints.MapHub("/game"); endpoints.MapConnectionHandler("/chat"); @@ -88,55 +79,3 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) }); } } - -// Nurdle -public class Game -{ - public string Player1Id { get; set; } - public string Player2Id { get; set; } - - private readonly IHubContext _hubContext; - - public Game(IHubContext hubContext) - { - _hubContext = hubContext; - } - - public void AddPlayer(string Id) - { - if (string.IsNullOrEmpty(Player1Id)) - { - Player1Id = Id; - } - else - { - Player2Id = Id; - } - } - - public async Task GameLoop(string connection1, string connection2) - { - var randomAnswer = Random.Shared.Next(2, 10); - var res = 0; - - do - { - await Task.Delay(1000); - var task1 = _hubContext.Clients.Single(connection1).InvokeAsync("GetNumber"); - var task2 = _hubContext.Clients.Single(connection2).InvokeAsync("GetNumber"); - res = (await task1) + (await task2); - - if (res < randomAnswer) - { - await _hubContext.Clients.Clients(connection1, connection2).SendAsync("Result", $"Guessed {res} which is too low"); - } - else if (res > randomAnswer) - { - await _hubContext.Clients.Clients(connection1, connection2).SendAsync("Result", $"Guessed {res} which is too high"); - } - } - while (res != randomAnswer); - - await _hubContext.Clients.Clients(connection1, connection2).SendAsync("Result", $"Guessed {res} which is correct!"); - } -} diff --git a/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html b/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html index efdca8d3d367..a7a18a02450c 100644 --- a/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html +++ b/src/SignalR/samples/SignalRSamples/wwwroot/hubs.html @@ -146,7 +146,7 @@

Group Actions

return; } - let hubRoute = "game"; + let hubRoute = hubTypeDropdown.value || "default"; let protocol = protocolDropdown.value === "msgpack" ? new signalR.protocols.msgpack.MessagePackHubProtocol() : new signalR.JsonHubProtocol(); @@ -174,16 +174,6 @@

Group Actions

addLine('message-list', msg); }); - connection.on('F', function () { - return new Promise((resolve, reject) => { - setTimeout(() => resolve(2), 5000); - }); - }); - - connection.on('GetNumber', function () { - return 2; - }); - connection.onclose(function (e) { if (e) { addLine('message-list', 'Connection closed with error: ' + e, 'red'); diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 53474856632c..de84fd623080 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -357,7 +357,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri // Write message directly to connection without caching it in memory var message = new InvocationMessage(invocationId, methodName, args); - await connection.WriteAsync(message, cancellationToken).AsTask(); + await connection.WriteAsync(message, cancellationToken); } catch { diff --git a/src/SignalR/server/Core/src/IHubCallerClients.cs b/src/SignalR/server/Core/src/IHubCallerClients.cs index 2fb87c207552..82968013d23f 100644 --- a/src/SignalR/server/Core/src/IHubCallerClients.cs +++ b/src/SignalR/server/Core/src/IHubCallerClients.cs @@ -9,10 +9,9 @@ namespace Microsoft.AspNetCore.SignalR; public interface IHubCallerClients : IHubCallerClients { /// - /// + /// Gets a proxy that can be used to invoke methods on a single client connected to the hub and receive results. /// - /// - /// - /// + /// The connection ID. + /// A client caller. new ISingleClientProxy Single(string connectionId) => throw new NotImplementedException(); } From ae4e42ef3c2b51efe6e31b438238af05b79f2b3b Mon Sep 17 00:00:00 2001 From: Brennan Date: Wed, 23 Mar 2022 16:42:17 -0700 Subject: [PATCH 04/11] some fb --- .../csharp/Client.Core/src/HubConnection.cs | 24 +------ ...soft.AspNetCore.SignalR.Client.Core.csproj | 1 + .../clients/ts/signalr/src/HubConnection.ts | 13 +++- .../src/Protocol/JsonHubProtocol.cs | 9 ++- .../Protocol/MessagePackHubProtocolWorker.cs | 4 +- .../src/Protocol/NewtonsoftJsonHubProtocol.cs | 8 ++- .../common/Shared/ClientResultsManager.cs | 51 +++++++------- .../common/Shared/CreateLinkedToken.cs | 33 +++++++++ .../SignalR.Common/src/Protocol/RawResult.cs | 2 +- .../Core/src/DefaultHubLifetimeManager.cs | 31 +++++---- .../Core/src/Internal/HubConnectionBinder.cs | 2 +- .../Microsoft.AspNetCore.SignalR.Core.csproj | 1 + .../src/HubLifetimeManagerTestBase.cs | 4 +- .../src/ScaleoutHubLifetimeManagerTests.cs | 4 +- ...pNetCore.SignalR.StackExchangeRedis.csproj | 1 + .../src/RedisHubLifetimeManager.cs | 68 +++++++++++++------ 16 files changed, 154 insertions(+), 102 deletions(-) create mode 100644 src/SignalR/common/Shared/CreateLinkedToken.cs diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 15bbf7648f38..46d009e997d1 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -259,7 +259,7 @@ private async Task StartAsyncInner(CancellationToken cancellationToken = default throw new InvalidOperationException($"The {nameof(HubConnection)} cannot be started while {nameof(StopAsync)} is running."); } - using (CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken)) + using (CancellationTokenUtils.CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken)) { await StartAsyncCore(linkedToken).ConfigureAwait(false); } @@ -1148,7 +1148,7 @@ private async Task HandshakeAsync(ConnectionState startingConnectionState, Cance try { // cancellationToken already contains _state.StopCts.Token, so we don't have to link it again - using (CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken)) + using (CancellationTokenUtils.CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken)) { while (true) { @@ -1637,26 +1637,6 @@ async Task RunReconnectedEventAsync() } } - private static IDisposable? CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken) - { - if (!token1.CanBeCanceled) - { - linkedToken = token2; - return null; - } - else if (!token2.CanBeCanceled) - { - linkedToken = token1; - return null; - } - else - { - var cts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); - linkedToken = cts.Token; - return cts; - } - } - // Debug.Assert plays havoc with Unit Tests. But I want something that I can "assert" only in Debug builds. [Conditional("DEBUG")] private static void SafeAssert(bool condition, string message, [CallerMemberName] string? memberName = null, [CallerFilePath] string? fileName = null, [CallerLineNumber] int lineNumber = 0) diff --git a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj index 7c0548133757..0fe14fda5c80 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj +++ b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj @@ -13,6 +13,7 @@ + diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 4dce394014ec..e03d369b62b6 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -677,29 +677,35 @@ export class HubConnection { private async _invokeClientMethod(invocationMessage: InvocationMessage) { const methods = this._methods[invocationMessage.target.toLowerCase()]; if (methods) { + // Avoid issues with handlers removing themselves thus modifying the list while iterating through it const methodsCopy = methods.slice(); + // Server expects a response if (invocationMessage.invocationId) { + // We preserve the last result or exception but still call all handlers let res; let exception; for (const m of methodsCopy) { try { if (res) { - this._logger.log(LogLevel.Warning, `Result already provided for '${invocationMessage.target.toLowerCase()}' only last one will be sent.`); + this._logger.log(LogLevel.Warning, `Result already provided for '${invocationMessage.target.toLowerCase()}' only the last one will be sent.`); } res = await m.apply(this, invocationMessage.arguments); + // Ignore exception if we got a result after, the exception will be logged exception = undefined; } catch (e) { exception = e; this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); } } + // If there is an exception that means either no result was given or a handler after a result threw + // And since we prefer handlers registered later we'll use the exception to return to the server. if (exception) { await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, `${exception}`, null)); - } - else if (res !== undefined) { + } else if (res !== undefined) { await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, null, res)); } else { + // Client didn't provide a result or throw from a handler, server expects a response so we send an error await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); } } else { @@ -712,6 +718,7 @@ export class HubConnection { } else { this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target.toLowerCase()}' found.`); + // No handlers provided by client but the server is expecting a response still, so we send an error if (invocationMessage.invocationId) { await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); } diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 087d912c914d..4fa2f313945d 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -535,15 +535,14 @@ private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter wr } else { - var resultType = message.Result.GetType(); - if (resultType == typeof(RawResult)) + if (message.Result is RawResult result) { - Debug.Assert(((RawResult)message.Result).RawSerializedData.IsSingleSegment); - writer.WriteRawValue(((RawResult)message.Result).RawSerializedData.First.Span, skipInputValidation: true); + Debug.Assert(result.RawSerializedData.IsSingleSegment); + writer.WriteRawValue(result.RawSerializedData.First.Span, skipInputValidation: true); } else { - JsonSerializer.Serialize(writer, message.Result, resultType, _payloadSerializerOptions); + JsonSerializer.Serialize(writer, message.Result, message.Result.GetType(), _payloadSerializerOptions); } } } diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs index a38cf1851611..e665b67aa629 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs @@ -441,9 +441,9 @@ private void WriteArgument(object? argument, ref MessagePackWriter writer) { writer.WriteNil(); } - else if (argument.GetType() == typeof(RawResult)) + else if (argument is RawResult result) { - writer.WriteRaw(((RawResult)argument).RawSerializedData); + writer.WriteRaw(result.RawSerializedData); } else { diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index c0516dba9cb4..d3b2c6196d07 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -549,9 +549,13 @@ private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter wr else if (message.HasResult) { writer.WritePropertyName(ResultPropertyName); - if (message.Result?.GetType() == typeof(RawResult)) + if (message.Result is RawResult result) { - writer.WriteRawValue(Encoding.UTF8.GetString(((RawResult)message.Result).RawSerializedData.ToArray())); +#if NETCOREAPP2_1_OR_GREATER + writer.WriteRawValue(Encoding.UTF8.GetString(result.RawSerializedData)); +#else + writer.WriteRawValue(Encoding.UTF8.GetString(result.RawSerializedData.ToArray())); +#endif } else { diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index 4ab975267ed9..00d15af93f38 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -13,13 +13,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal; // Handles cancellation, cleanup, and completion, so any bugs or improvements can be made in a single place internal class ClientResultsManager : IInvocationBinder { - private readonly ConcurrentDictionary Completion)> _pendingInvocations = new(); + private readonly ConcurrentDictionary Completion)> _pendingInvocations = new(); public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) { var tcs = new TaskCompletionSourceWithCancellation(this, connectionId, invocationId, cancellationToken); - _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, completionMessage => + _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, tcs, static (state, completionMessage) => { + var tcs = (TaskCompletionSourceWithCancellation)state; if (completionMessage.HasResult) { tcs.SetResult((T)completionMessage.Result); @@ -32,10 +33,12 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel } )); + tcs.RegisterCancellation(); + return tcs.Task; } - public void AddInvocation(string invocationId, (Type Type, string ConnectionId, Func Completion) invocationInfo) + public void AddInvocation(string invocationId, (Type Type, string ConnectionId, object Tcs, Func Completion) invocationInfo) { _pendingInvocations.TryAdd(invocationId, invocationInfo); } @@ -54,7 +57,7 @@ public Task TryCompleteResult(string connectionId, CompletionMessage message) // we'll ignore both cases if (_pendingInvocations.Remove(message.InvocationId!, out _)) { - return item.Completion(message); + return item.Completion(item.Tcs, message); } } else @@ -64,32 +67,12 @@ public Task TryCompleteResult(string connectionId, CompletionMessage message) return Task.CompletedTask; } - public (Type Type, string ConnectionId, Func Completion) RemoveInvocation(string invocationId) + public (Type Type, string ConnectionId, object Tcs, Func Completion)? RemoveInvocation(string invocationId) { _pendingInvocations.Remove(invocationId, out var item); return item; } - public void CleanupConnection(string connectionId, List? pendingTasks) - { - var invocationIds = _pendingInvocations.Where(x => x.Value.ConnectionId == connectionId).Select(x => x.Key); - foreach (var id in invocationIds) - { - if (_pendingInvocations.Remove(id, out var item)) - { - var task = item.Completion(CompletionMessage.WithError(id, "Connection disconnected")); - if (!task.IsCompletedSuccessfully) - { - if (pendingTasks is null) - { - pendingTasks = new List(); - } - pendingTasks.Add(task); - } - } - } - } - public bool TryGetType(string invocationId, [NotNullWhen(true)] out Type? type) { if (_pendingInvocations.TryGetValue(invocationId, out var item)) @@ -110,22 +93,28 @@ public Type GetReturnType(string invocationId) throw new InvalidOperationException(); } + // Unused, here to honor the IInvocationBinder interface but should never be called public IReadOnlyList GetParameterTypes(string methodName) { throw new NotImplementedException(); } + // Unused, here to honor the IInvocationBinder interface but should never be called public Type GetStreamItemType(string streamId) { throw new NotImplementedException(); } + // Custom TCS type to avoid the extra allocation that would be introduced if we managed the cancellation separately + // Also makes it easier to keep track of the CancellationTokenRegistration for disposal private sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource { private readonly ClientResultsManager _clientResultsManager; private readonly string _connectionId; private readonly string _invocationId; - private readonly CancellationTokenRegistration _tokenRegistration; + private readonly CancellationToken _token; + + private CancellationTokenRegistration _tokenRegistration; public TaskCompletionSourceWithCancellation(ClientResultsManager clientResultsManager, string connectionId, string invocationId, CancellationToken cancellationToken) @@ -134,10 +123,16 @@ public TaskCompletionSourceWithCancellation(ClientResultsManager clientResultsMa _clientResultsManager = clientResultsManager; _connectionId = connectionId; _invocationId = invocationId; + _token = cancellationToken; + } - if (cancellationToken.CanBeCanceled) + // Needs to be called after adding the completion to the dictionary in order to avoid synchronous completions of the token registration + // not canceling when the dictionary hasn't been updated yet. + public void RegisterCancellation() + { + if (_token.CanBeCanceled) { - _tokenRegistration = cancellationToken.UnsafeRegister(static o => + _tokenRegistration = _token.UnsafeRegister(static o => { var tcs = (TaskCompletionSourceWithCancellation)o!; tcs.SetCanceled(); diff --git a/src/SignalR/common/Shared/CreateLinkedToken.cs b/src/SignalR/common/Shared/CreateLinkedToken.cs new file mode 100644 index 000000000000..a5e4170f975c --- /dev/null +++ b/src/SignalR/common/Shared/CreateLinkedToken.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal static class CancellationTokenUtils +{ + // Similar to CreateLinkedTokenSource except it will not allocate a new internal LinkedCancellationTokenSource in the case where + // one of the tokens passed in isn't cancellable. + // Returns a disposable only when an actual LinkkedTokenSource is created. + internal static IDisposable? CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken) + { + if (!token1.CanBeCanceled) + { + linkedToken = token2; + return null; + } + else if (!token2.CanBeCanceled) + { + linkedToken = token1; + return null; + } + else + { + var cts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); + linkedToken = cts.Token; + return cts; + } + } +} diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs index 54acfa3909d4..7431df26cef7 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol; /// /// In Json that would mean storing the byte representation of ascii {"prop":10} as an example. /// -public class RawResult +public sealed class RawResult { /// /// Stores the raw serialized bytes of a for forwarding to another server. diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index de84fd623080..d2990fad4b11 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Concurrent; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; @@ -300,11 +299,6 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) _connections.Remove(connection); _groups.RemoveDisconnectedConnection(connection.ConnectionId); - List? pendingTasks = null; - _clientResultsManager.CleanupConnection(connection.ConnectionId, pendingTasks); - // Completions should be synchronous for DefaultHubLifetimeManager - Debug.Assert(pendingTasks is null); - return Task.CompletedTask; } @@ -342,14 +336,9 @@ public override async Task InvokeConnectionAsync(string connectionId, stri } var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); - var task = _clientResultsManager.AddInvocation(connectionId, invocationId, cancellationToken); - // Connection disconnected while adding invocation - // we need to try to remove it here to avoid the task hanging if the add happened after the connection cleanup - if (connection.ConnectionAborted.IsCancellationRequested) - { - await _clientResultsManager.TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Connection disconnected")); - return await task; - } + using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, + connection.ConnectionAborted, out var linkedToken); + var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); try { @@ -365,7 +354,19 @@ public override async Task InvokeConnectionAsync(string connectionId, stri throw; } - return await task; + try + { + return await task; + } + catch + { + // ConnectionAborted will trigger a generic "Canceled" exception from the task, let's convert it into a more specific message. + if (connection.ConnectionAborted.IsCancellationRequested) + { + throw new Exception("Connection disconnected."); + } + throw; + } } /// diff --git a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs index dfde6a11d070..086624a81d23 100644 --- a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs +++ b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs @@ -27,7 +27,7 @@ public Type GetReturnType(string invocationId) { return type; } - throw new InvalidOperationException("Unknown invocation ID."); + throw new InvalidOperationException($"Unknown invocation ID '{invocationId}'."); } public Type GetStreamItemType(string streamId) diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index cf269845f355..fd607f3d832b 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -18,6 +18,7 @@ + diff --git a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs index b66595b6aa60..331015d90268 100644 --- a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs +++ b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs @@ -331,9 +331,11 @@ public async Task ClientDisconnectsWithoutCompletingClientResult() var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + connection1.Abort(); await manager1.OnDisconnectedAsync(connection1).DefaultTimeout(); - await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + var ex = await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + Assert.Equal("Connection disconnected.", ex.Message); } } diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index d7dae9b73ad1..10bfaae1b307 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -568,10 +568,12 @@ public async Task ClientDisconnectsWithoutCompletingClientResultOnSecondServer() var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + connection1.Abort(); await manager2.OnDisconnectedAsync(connection1).DefaultTimeout(); // Server should propogate connection closure so task isn't blocked - await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + var ex = await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + Assert.Equal("Connection disconnected.", ex.Message); } } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj b/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj index a9c070a090fc..e3ed270e72ef 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj +++ b/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj @@ -10,6 +10,7 @@ + diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index b46fb6e73f19..5e494375d924 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; @@ -138,8 +139,6 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) tasks.Add(RemoveUserAsync(connection)); } - _clientResultsManager.CleanupConnection(connection.ConnectionId, tasks); - return Task.WhenAll(tasks); } @@ -409,7 +408,9 @@ public override async Task InvokeConnectionAsync(string connectionId, stri var connection = _connections[connectionId]; var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); - var task = _clientResultsManager.AddInvocation(connectionId, invocationId, cancellationToken); + using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, + connection?.ConnectionAborted ?? default, out var linkedToken); + var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); try { @@ -425,14 +426,6 @@ public override async Task InvokeConnectionAsync(string connectionId, stri } else { - // Connection disconnected while adding invocation - // we need to try to remove it here to avoid the task hanging if the add happened after the connection cleanup - if (connection.ConnectionAborted.IsCancellationRequested) - { - await _clientResultsManager.TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Connection disconnected")); - return await task; - } - // We're sending to a single connection // Write message directly to connection without caching it in memory var message = new InvocationMessage(invocationId, methodName, args); @@ -446,7 +439,19 @@ public override async Task InvokeConnectionAsync(string connectionId, stri throw; } - return await task; + try + { + return await task; + } + catch + { + // ConnectionAborted will trigger a generic "Canceled" exception from the task, let's convert it into a more specific message. + if (connection?.ConnectionAborted.IsCancellationRequested == true) + { + throw new Exception("Connection disconnected."); + } + throw; + } } /// @@ -549,11 +554,19 @@ private async Task SubscribeToConnection(HubConnectionContext connection) channel.OnMessage(channelMessage => { var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message); + // This is a Client result we need to setup state for the completion and send the message to the client if (!string.IsNullOrEmpty(invocation.InvocationId)) { - _clientResultsManager.AddInvocation(invocation.InvocationId, (typeof(RawResult), connection.ConnectionId, completionMessage => + object? tokenRegistration = null; + _clientResultsManager.AddInvocation(invocation.InvocationId, + (typeof(RawResult), connection.ConnectionId, null!, (_, completionMessage) => { var protocolName = connection.Protocol.Name; + if (tokenRegistration is not null) + { + ((CancellationTokenRegistration)tokenRegistration).Dispose(); + } + // TODO: acquiring this and then calling RedisProtocol.WriteCompletionMessage will allocate a new MemoryBufferWriter, we can avoid this var memoryBufferWriter = AspNetCore.Internal.MemoryBufferWriter.Get(); try { @@ -568,13 +581,16 @@ private async Task SubscribeToConnection(HubConnectionContext connection) } } )); - // Connection disconnected while adding invocation - // we need to try to remove it here to avoid the task hanging if the add happened after the connection cleanup - if (connection.ConnectionAborted.IsCancellationRequested) + + // TODO: this isn't great + tokenRegistration = connection.ConnectionAborted.UnsafeRegister(_ => { - return _clientResultsManager.TryCompleteResult(connection.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Connection disconnected")); - } + var invocationInfo = _clientResultsManager.RemoveInvocation(invocation.InvocationId); + invocationInfo?.Completion(null!, CompletionMessage.WithError(invocation.InvocationId, "Connection disconnected.")); + }, null); } + + // Normal client method invokes and client result invokes use the same message return connection.WriteAsync(invocation.Message).AsTask(); }); } @@ -645,12 +661,22 @@ private async Task SubscribeToReturnResultsAsync() channel.OnMessage((channelMessage) => { var completion = RedisProtocol.ReadCompletion(channelMessage.Message); - var protocol = _hubProtocolResolver.AllProtocols.Where(p => p.Name.Equals(completion.ProtocolName)).First(); + IHubProtocol? protocol = null; + foreach (var hubProtocol in _hubProtocolResolver.AllProtocols) + { + if (hubProtocol.Name.Equals(completion.ProtocolName)) + { + protocol = hubProtocol; + break; + } + } + Debug.Assert(protocol is not null); var ros = completion.CompletionMessage; - protocol.TryParseMessage(ref ros, _clientResultsManager, out var hubMessage); + var parseSuccess = protocol.TryParseMessage(ref ros, _clientResultsManager, out var hubMessage); + Debug.Assert(parseSuccess); var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!); - invocationInfo.Completion((CompletionMessage)hubMessage); + invocationInfo?.Completion(invocationInfo?.Tcs!, (CompletionMessage)hubMessage!); }); } From 476d28e51dbc11408bfd05078f93d49783688e90 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 28 Mar 2022 09:25:47 -0700 Subject: [PATCH 05/11] small fix --- .../src/Protocol/NewtonsoftJsonHubProtocol.cs | 36 ++++++++++++++++-- .../src/ScaleoutHubLifetimeManagerTests.cs | 38 +++++++++++++++++++ .../src/RedisHubLifetimeManager.cs | 3 +- 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index d3b2c6196d07..89e4b03edaa4 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -218,8 +218,22 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (returnType == typeof(RawResult)) { var token = JToken.Load(reader); - var str = token.ToString(Formatting.None); - result = new RawResult(new ReadOnlySequence(Encoding.UTF8.GetBytes(str))); + using var strm = new MemoryStream(); + using var writer = new StreamWriter(strm); + using var jsonTextWriter = new JsonTextWriter(writer); + token.WriteTo(jsonTextWriter); + jsonTextWriter.Flush(); + writer.Flush(); + Memory buf; + if (strm.TryGetBuffer(out var segment)) + { + buf = segment.Array.AsMemory(segment.Offset, segment.Count); + } + else + { + buf = strm.ToArray(); + } + result = new RawResult(new ReadOnlySequence(buf)); } else { @@ -400,8 +414,22 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) var returnType = binder.GetReturnType(invocationId); if (returnType == typeof(RawResult)) { - var str = resultToken.ToString(Formatting.None); - result = new RawResult(new ReadOnlySequence(Encoding.UTF8.GetBytes(str))); + using var strm = new MemoryStream(); + using var writer = new StreamWriter(strm); + using var jsonTextWriter = new JsonTextWriter(writer); + resultToken.WriteTo(jsonTextWriter); + jsonTextWriter.Flush(); + writer.Flush(); + Memory buf; + if (strm.TryGetBuffer(out var segment)) + { + buf = segment.Array.AsMemory(segment.Offset, segment.Count); + } + else + { + buf = strm.ToArray(); + } + result = new RawResult(new ReadOnlySequence(buf)); } else { diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 10bfaae1b307..4d031a9e7f65 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -576,4 +576,42 @@ public async Task ClientDisconnectsWithoutCompletingClientResultOnSecondServer() Assert.Equal("Connection disconnected.", ex.Message); } } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task InvocationsFromDifferentServersUseUniqueIDs() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + await manager2.OnConnectedAsync(connection2).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }); + var invocation2 = Assert.IsType(await client2.ReadAsync().DefaultTimeout()); + + var invoke2 = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + + Assert.NotEqual(invocation1.InvocationId, invocation2.InvocationId); + + await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout(); + await manager2.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout(); + + var res = await invoke1.DefaultTimeout(); + Assert.Equal(2, res); + res = await invoke2.DefaultTimeout(); + Assert.Equal(5, res); + } + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index 5e494375d924..b864dabdd485 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -407,7 +407,8 @@ public override async Task InvokeConnectionAsync(string connectionId, stri var connection = _connections[connectionId]; - var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); + // Needs to be unique across servers, easiest way to do that is prefix with connection ID. + var invocationId = $"{connectionId}{Interlocked.Increment(ref _lastInvocationId)}"; using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, connection?.ConnectionAborted ?? default, out var linkedToken); var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); From cfe089bd13fe0ac86f4d71ca1aac55eaddbfc2da Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 28 Mar 2022 13:14:00 -0700 Subject: [PATCH 06/11] fb --- .../common/Shared/ClientResultsManager.cs | 4 +- .../server/Core/src/HubLifetimeManager.cs | 4 +- .../server/SignalR/test/ClientProxyTests.cs | 79 +++++++++++++++++++ .../src/HubLifetimeManagerTestBase.cs | 4 +- .../src/RedisHubLifetimeManager.cs | 1 - 5 files changed, 85 insertions(+), 7 deletions(-) diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index 00d15af93f38..e23d9771591a 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -49,7 +49,7 @@ public Task TryCompleteResult(string connectionId, CompletionMessage message) { if (item.ConnectionId != connectionId) { - throw new Exception("wrong ID"); + throw new InvalidOperationException($"Connection ID '{connectionId}' is not valid for invocation ID '{message.InvocationId}'."); } // if false the connection disconnected right after the above TryGetValue @@ -90,7 +90,7 @@ public Type GetReturnType(string invocationId) { return type; } - throw new InvalidOperationException(); + throw new InvalidOperationException($"Invocation ID '{invocationId}' is not associated with a pending client result."); } // Unused, here to honor the IInvocationBinder interface but should never be called diff --git a/src/SignalR/server/Core/src/HubLifetimeManager.cs b/src/SignalR/server/Core/src/HubLifetimeManager.cs index 897993d3235e..14a294190876 100644 --- a/src/SignalR/server/Core/src/HubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/HubLifetimeManager.cs @@ -146,7 +146,7 @@ public abstract class HubLifetimeManager where THub : Hub /// The response from the connection. public virtual Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + throw new NotImplementedException($"{GetType().Name} does not support client return values."); } /// @@ -157,7 +157,7 @@ public virtual Task InvokeConnectionAsync(string connectionId, string meth /// A that represents the result being set or being forwarded to another server. public virtual Task SetConnectionResultAsync(string connectionId, CompletionMessage result) { - throw new NotImplementedException(); + throw new NotImplementedException($"{GetType().Name} does not support client return values."); } /// diff --git a/src/SignalR/server/SignalR/test/ClientProxyTests.cs b/src/SignalR/server/SignalR/test/ClientProxyTests.cs index eb0d2613cc6c..ede98b1c505a 100644 --- a/src/SignalR/server/SignalR/test/ClientProxyTests.cs +++ b/src/SignalR/server/SignalR/test/ClientProxyTests.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.Testing; using Moq; using Xunit; @@ -205,4 +206,82 @@ public async Task MultipleClientProxy_SendAsync_ArrayArgumentNotExpanded() Assert.Same(data, arg); } + + [Fact] + public async Task SingleClientProxyWithInvoke_ThrowsNotSupported() + { + var hubLifetimeManager = new EmptyHubLifetimeManager(); + + var proxy = new SingleClientProxyWithInvoke(hubLifetimeManager, ""); + var ex = await Assert.ThrowsAsync(async () => await proxy.InvokeAsync("method")).DefaultTimeout(); + Assert.Equal("EmptyHubLifetimeManager`1 does not support client return values.", ex.Message); + } + + internal class EmptyHubLifetimeManager : HubLifetimeManager where THub : Hub + { + public override Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task OnConnectedAsync(HubConnectionContext connection) + { + throw new NotImplementedException(); + } + + public override Task OnDisconnectedAsync(HubConnectionContext connection) + { + throw new NotImplementedException(); + } + + public override Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendGroupAsync(string groupName, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } } diff --git a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs index 331015d90268..ac2acc1f56f4 100644 --- a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs +++ b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs @@ -248,10 +248,10 @@ public async Task ExceptionWhenIncorrectClientCompletesClientResult() Assert.NotNull(invocation.InvocationId); Assert.Equal("test", invocation.Arguments[0]); - var ex = await Assert.ThrowsAsync(() => + var ex = await Assert.ThrowsAsync(() => manager.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client"))).DefaultTimeout(); - Assert.Equal("wrong ID", ex.Message); + Assert.Equal($"Connection ID '{connection2.ConnectionId}' is not valid for invocation ID '{invocation.InvocationId}'.", ex.Message); // Internal state for invocation isn't affected by wrong client, check that we can still complete the invocation await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index b864dabdd485..2d3269705707 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -4,7 +4,6 @@ using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Globalization; using System.Linq; using System.Text; using Microsoft.AspNetCore.Http.Features; From f4ccc84ed79c664568d15df5634e67d5806cbd2e Mon Sep 17 00:00:00 2001 From: Brennan Date: Tue, 29 Mar 2022 09:14:47 -0700 Subject: [PATCH 07/11] logs --- .../Client.Core/src/HubConnection.Log.cs | 9 +++ .../csharp/Client.Core/src/HubConnection.cs | 8 +- .../clients/ts/signalr/src/HubConnection.ts | 77 ++++++++++--------- .../ts/signalr/tests/HubConnection.test.ts | 6 +- .../common/Shared/ClientResultsManager.cs | 12 +-- .../common/Shared/CreateLinkedToken.cs | 2 +- .../Core/src/Internal/DefaultHubDispatcher.cs | 3 +- .../src/Internal/DefaultHubDispatcherLog.cs | 6 +- .../SignalR/test/HubConnectionHandlerTests.cs | 2 +- .../src/ScaleoutHubLifetimeManagerTests.cs | 15 ++++ .../src/Internal/RedisProtocol.cs | 23 ++++-- .../src/RedisHubLifetimeManager.cs | 10 +-- 12 files changed, 110 insertions(+), 63 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 43f77ae3c86a..bfeb73ca401e 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -310,5 +310,14 @@ public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTime [LoggerMessage(84, LogLevel.Trace, "Client threw an error for stream '{StreamId}'.", EventName = "ErroredStream")] public static partial void ErroredStream(ILogger logger, string streamId, Exception exception); + + [LoggerMessage(85, LogLevel.Warning, "No result given for '{Target}' method and invocation ID '{InvocationId}'.", EventName = "NoResultGiven")] + public static partial void NoResultGiven(ILogger logger, string target, string invocationId); + + [LoggerMessage(86, LogLevel.Warning, "Result given for '{Target}' method but server is not expecting a result.", EventName = "ResultNotExpected")] + public static partial void ResultNotExpected(ILogger logger, string target); + + [LoggerMessage(87, LogLevel.Warning, "Result already provided for '{Target}' only the last one will be sent.", EventName = "IgnoringPreviousResult")] + public static partial void IgnoringPreviousResult(ILogger logger, string target); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 46d009e997d1..f5dd7032d073 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -1030,6 +1030,7 @@ private async Task DispatchInvocationAsync(InvocationMessage invocation, Connect Log.MissingHandler(_logger, invocation.Target); if (expectsResult) { + Log.NoResultGiven(_logger, invocation.Target, invocation.InvocationId!); await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); } return; @@ -1047,6 +1048,10 @@ private async Task DispatchInvocationAsync(InvocationMessage invocation, Connect var task = handler.InvokeAsync(invocation.Arguments); if (handler.HasResult && task is Task resultTask) { + if (hasResult) + { + Log.IgnoringPreviousResult(_logger, invocation.Target); + } hasResult = true; result = await resultTask.ConfigureAwait(false); // ignore previous results' exception, we prefer last .On handler for results @@ -1074,6 +1079,7 @@ private async Task DispatchInvocationAsync(InvocationMessage invocation, Connect } else if (!hasResult) { + Log.NoResultGiven(_logger, invocation.Target, invocation.InvocationId!); await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); } else @@ -1083,7 +1089,7 @@ private async Task DispatchInvocationAsync(InvocationMessage invocation, Connect } else if (hasResult) { - // Log: result given but server didn't ask for one. + Log.ResultNotExpected(_logger, invocation.Target); } } diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index e03d369b62b6..145aa91dda95 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -676,52 +676,55 @@ export class HubConnection { private async _invokeClientMethod(invocationMessage: InvocationMessage) { const methods = this._methods[invocationMessage.target.toLowerCase()]; - if (methods) { - // Avoid issues with handlers removing themselves thus modifying the list while iterating through it - const methodsCopy = methods.slice(); + if (!methods) { + this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target.toLowerCase()}' found.`); - // Server expects a response + // No handlers provided by client but the server is expecting a response still, so we send an error if (invocationMessage.invocationId) { - // We preserve the last result or exception but still call all handlers - let res; - let exception; - for (const m of methodsCopy) { - try { - if (res) { - this._logger.log(LogLevel.Warning, `Result already provided for '${invocationMessage.target.toLowerCase()}' only the last one will be sent.`); - } - res = await m.apply(this, invocationMessage.arguments); - // Ignore exception if we got a result after, the exception will be logged - exception = undefined; - } catch (e) { - exception = e; - this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); - } - } - // If there is an exception that means either no result was given or a handler after a result threw - // And since we prefer handlers registered later we'll use the exception to return to the server. - if (exception) { - await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, `${exception}`, null)); - } else if (res !== undefined) { - await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, null, res)); - } else { - // Client didn't provide a result or throw from a handler, server expects a response so we send an error - await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); - } - } else { + this._logger.log(LogLevel.Warning, `No result given for '${invocationMessage.target.toLowerCase()}' method and invocation ID '${invocationMessage.invocationId}'.`); + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); + } + return; + } + + // Avoid issues with handlers removing themselves thus modifying the list while iterating through it + const methodsCopy = methods.slice(); + + // Server expects a response + if (invocationMessage.invocationId) { + // We preserve the last result or exception but still call all handlers + let res; + let exception; + for (const m of methodsCopy) { try { - methodsCopy.forEach((m) => m.apply(this, invocationMessage.arguments)); + if (res || exception) { + this._logger.log(LogLevel.Warning, `Result already provided for '${invocationMessage.target.toLowerCase()}' only the last one will be sent.`); + } + res = await m.apply(this, invocationMessage.arguments); + // Ignore exception if we got a result after, the exception will be logged + exception = undefined; } catch (e) { + exception = e; this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); } } - } else { - this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target.toLowerCase()}' found.`); - - // No handlers provided by client but the server is expecting a response still, so we send an error - if (invocationMessage.invocationId) { + // If there is an exception that means either no result was given or a handler after a result threw + // And since we prefer handlers registered later we'll use the exception to return to the server. + if (exception) { + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, `${exception}`, null)); + } else if (res !== undefined) { + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, null, res)); + } else { + this._logger.log(LogLevel.Warning, `No result given for '${invocationMessage.target.toLowerCase()}' method and invocation ID '${invocationMessage.invocationId}'.`); + // Client didn't provide a result or throw from a handler, server expects a response so we send an error await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); } + } else { + try { + methodsCopy.forEach((m) => m.apply(this, invocationMessage.arguments)); + } catch (e) { + this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); + } } } diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index bc8b54aed788..bb54a3c13249 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -650,7 +650,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); - expect(warnings).toEqual(["No client method with the name 'message' found."]); + expect(warnings).toEqual(["No client method with the name 'message' found.", + "No result given for 'message' method and invocation ID '0'."]); } finally { await hubConnection.stop(); } @@ -1035,7 +1036,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); - expect(warnings).toEqual(["No client method with the name 'message' found."]); + expect(warnings).toEqual(["No client method with the name 'message' found.", + "No result given for 'message' method and invocation ID '0'."]); hubConnection.off(null!, undefined!); hubConnection.off(undefined!, null!); diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index e23d9771591a..76f153110def 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -18,7 +18,7 @@ internal class ClientResultsManager : IInvocationBinder public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) { var tcs = new TaskCompletionSourceWithCancellation(this, connectionId, invocationId, cancellationToken); - _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, tcs, static (state, completionMessage) => + var result = _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, tcs, static (state, completionMessage) => { var tcs = (TaskCompletionSourceWithCancellation)state; if (completionMessage.HasResult) @@ -32,6 +32,7 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel return Task.CompletedTask; } )); + Debug.Assert(result); tcs.RegisterCancellation(); @@ -40,7 +41,8 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel public void AddInvocation(string invocationId, (Type Type, string ConnectionId, object Tcs, Func Completion) invocationInfo) { - _pendingInvocations.TryAdd(invocationId, invocationInfo); + var result = _pendingInvocations.TryAdd(invocationId, invocationInfo); + Debug.Assert(result); } public Task TryCompleteResult(string connectionId, CompletionMessage message) @@ -107,7 +109,7 @@ public Type GetStreamItemType(string streamId) // Custom TCS type to avoid the extra allocation that would be introduced if we managed the cancellation separately // Also makes it easier to keep track of the CancellationTokenRegistration for disposal - private sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource + internal sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource { private readonly ClientResultsManager _clientResultsManager; private readonly string _connectionId; @@ -149,14 +151,14 @@ public void RegisterCancellation() public new void SetResult(T result) { - base.SetResult(result); _tokenRegistration.Dispose(); + base.SetResult(result); } public new void SetException(Exception exception) { - base.SetException(exception); _tokenRegistration.Dispose(); + base.SetException(exception); } #pragma warning disable IDE0060 // Remove unused parameter diff --git a/src/SignalR/common/Shared/CreateLinkedToken.cs b/src/SignalR/common/Shared/CreateLinkedToken.cs index a5e4170f975c..198bde588696 100644 --- a/src/SignalR/common/Shared/CreateLinkedToken.cs +++ b/src/SignalR/common/Shared/CreateLinkedToken.cs @@ -10,7 +10,7 @@ internal static class CancellationTokenUtils { // Similar to CreateLinkedTokenSource except it will not allocate a new internal LinkedCancellationTokenSource in the case where // one of the tokens passed in isn't cancellable. - // Returns a disposable only when an actual LinkkedTokenSource is created. + // Returns a disposable only when an actual LinkedTokenSource is created. internal static IDisposable? CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken) { if (!token1.CanBeCanceled) diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index bd173ee7079f..c61c6687ccfe 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -182,8 +182,7 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe } else { - // TODO: Retire this log and replace with a more generic one - Log.UnexpectedStreamCompletion(_logger); + Log.UnexpectedCompletion(_logger, completionMessage.InvocationId!); } break; diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs index 4905d4e59f49..f80a970935a3 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs @@ -92,8 +92,7 @@ public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessa [LoggerMessage(19, LogLevel.Debug, "Stream '{StreamId}' closed with error '{Error}'.", EventName = "ClosingStreamWithBindingError")] private static partial void ClosingStreamWithBindingError(ILogger logger, string? streamId, string? error); - [LoggerMessage(20, LogLevel.Debug, "StreamCompletionMessage received unexpectedly.", EventName = "UnexpectedStreamCompletion")] - public static partial void UnexpectedStreamCompletion(ILogger logger); + // Retired [20]UnexpectedStreamCompletion, replaced with more generic [24]UnexpectedCompletion [LoggerMessage(21, LogLevel.Debug, "StreamItemMessage received unexpectedly.", EventName = "UnexpectedStreamItem")] public static partial void UnexpectedStreamItem(ILogger logger); @@ -103,4 +102,7 @@ public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessa [LoggerMessage(23, LogLevel.Debug, "Invocation ID '{InvocationId}' is already in use.", EventName = "InvocationIdInUse")] public static partial void InvocationIdInUse(ILogger logger, string InvocationId); + + [LoggerMessage(24, LogLevel.Debug, "CompletionMessage for invocation ID '{InvocationId}' received unexpectedly.", EventName = "UnexpectedCompletion")] + public static partial void UnexpectedCompletion(ILogger logger, string invocationId); } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index f09d5cb7d416..bce1522e597d 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -3925,7 +3925,7 @@ public async Task UploadStreamCompleteInvalidId() } Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" && - w.EventId.Name == "UnexpectedStreamCompletion")); + w.EventId.Name == "UnexpectedCompletion")); } public static string CustomErrorMessage = "custom error for testing ::::)"; diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 4d031a9e7f65..0c25ce9698ff 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -614,4 +614,19 @@ public async Task InvocationsFromDifferentServersUseUniqueIDs() Assert.Equal(5, res); } } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ConnectionDoesNotExist_FailsInvokeConnectionAsync() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" })).DefaultTimeout(); + Assert.Equal("Connection does not exist.", ex.Message); + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 00d18b80aee7..51f3512e0d43 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -51,8 +51,6 @@ public byte[] WriteInvocation(string methodName, object?[] args, string? invocat if (!string.IsNullOrEmpty(returnChannel)) { writer.WriteArrayHeader(4); - writer.Write(invocationId); - writer.Write(returnChannel); } else { @@ -72,6 +70,15 @@ public byte[] WriteInvocation(string methodName, object?[] args, string? invocat } WriteHubMessage(ref writer, new InvocationMessage(invocationId, methodName, args)); + + // Write last in order to preserve original order for cases where one server is updated and the other isn't. + // Not really a supported scenario, but why not be nice + if (!string.IsNullOrEmpty(returnChannel)) + { + writer.Write(invocationId); + writer.Write(returnChannel); + } + writer.Flush(); return memoryBufferWriter.ToArray(); @@ -170,11 +177,6 @@ public static RedisInvocation ReadInvocation(ReadOnlyMemory data) string? returnChannel = null; string? invocationId = null; - if (length > 3) - { - invocationId = reader.ReadString(); - returnChannel = reader.ReadString(); - } // Read excluded Ids IReadOnlyList? excludedConnectionIds = null; @@ -192,6 +194,13 @@ public static RedisInvocation ReadInvocation(ReadOnlyMemory data) // Read payload var message = ReadSerializedHubMessage(ref reader); + + if (length > 3) + { + invocationId = reader.ReadString(); + returnChannel = reader.ReadString(); + } + return new RedisInvocation(message, excludedConnectionIds, invocationId, returnChannel); } diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index 2d3269705707..d0621031ef88 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -408,6 +408,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri // Needs to be unique across servers, easiest way to do that is prefix with connection ID. var invocationId = $"{connectionId}{Interlocked.Increment(ref _lastInvocationId)}"; + using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, connection?.ConnectionAborted ?? default, out var linkedToken); var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); @@ -554,18 +555,16 @@ private async Task SubscribeToConnection(HubConnectionContext connection) channel.OnMessage(channelMessage => { var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message); + // This is a Client result we need to setup state for the completion and send the message to the client if (!string.IsNullOrEmpty(invocation.InvocationId)) { - object? tokenRegistration = null; + CancellationTokenRegistration? tokenRegistration = null; _clientResultsManager.AddInvocation(invocation.InvocationId, (typeof(RawResult), connection.ConnectionId, null!, (_, completionMessage) => { var protocolName = connection.Protocol.Name; - if (tokenRegistration is not null) - { - ((CancellationTokenRegistration)tokenRegistration).Dispose(); - } + tokenRegistration?.Dispose(); // TODO: acquiring this and then calling RedisProtocol.WriteCompletionMessage will allocate a new MemoryBufferWriter, we can avoid this var memoryBufferWriter = AspNetCore.Internal.MemoryBufferWriter.Get(); try @@ -590,6 +589,7 @@ private async Task SubscribeToConnection(HubConnectionContext connection) }, null); } + // Forward message from other server to client // Normal client method invokes and client result invokes use the same message return connection.WriteAsync(invocation.Message).AsTask(); }); From ccb01d4e5f1b8dc0798d11c0d659a37c109ba310 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 11 Apr 2022 13:57:46 -0700 Subject: [PATCH 08/11] fb --- .../Client.Core/src/HubConnection.Log.cs | 7 +- .../csharp/Client.Core/src/HubConnection.cs | 40 ++++++---- .../UnitTests/HubConnectionTests.Protocol.cs | 16 ++-- .../clients/ts/signalr/src/HubConnection.ts | 59 +++++++------- .../ts/signalr/tests/HubConnection.test.ts | 76 ++++++++++++++++--- .../common/Shared/ClientResultsManager.cs | 14 ++-- .../Core/src/DefaultHubLifetimeManager.cs | 7 +- .../Core/src/Internal/TypedClientBuilder.cs | 15 ++-- .../src/HubLifetimeManagerTestBase.cs | 6 +- .../src/ScaleoutHubLifetimeManagerTests.cs | 6 +- .../src/Internal/RedisLog.cs | 6 ++ .../src/RedisHubLifetimeManager.cs | 29 ++++--- 12 files changed, 179 insertions(+), 102 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index bfeb73ca401e..3f66fd3af762 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -311,13 +311,10 @@ public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTime [LoggerMessage(84, LogLevel.Trace, "Client threw an error for stream '{StreamId}'.", EventName = "ErroredStream")] public static partial void ErroredStream(ILogger logger, string streamId, Exception exception); - [LoggerMessage(85, LogLevel.Warning, "No result given for '{Target}' method and invocation ID '{InvocationId}'.", EventName = "NoResultGiven")] - public static partial void NoResultGiven(ILogger logger, string target, string invocationId); + [LoggerMessage(85, LogLevel.Warning, "Failed to find a value returning handler for '{Target}' method. Sending error to server.", EventName = "MissingResultHandler")] + public static partial void MissingResultHandler(ILogger logger, string target); [LoggerMessage(86, LogLevel.Warning, "Result given for '{Target}' method but server is not expecting a result.", EventName = "ResultNotExpected")] public static partial void ResultNotExpected(ILogger logger, string target); - - [LoggerMessage(87, LogLevel.Warning, "Result already provided for '{Target}' only the last one will be sent.", EventName = "IgnoringPreviousResult")] - public static partial void IgnoringPreviousResult(ILogger logger, string target); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index f5dd7032d073..8e90b23fa48a 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -337,7 +337,7 @@ public virtual IDisposable On(string methodName, Type[] parameterTypes, Func resultTask) { - if (hasResult) - { - Log.IgnoringPreviousResult(_logger, invocation.Target); - } - hasResult = true; result = await resultTask.ConfigureAwait(false); - // ignore previous results' exception, we prefer last .On handler for results - resultException = null; + hasResult = true; } else { @@ -1071,20 +1068,21 @@ private async Task DispatchInvocationAsync(InvocationMessage invocation, Connect } } } + if (expectsResult) { if (resultException is not null) { await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false); } - else if (!hasResult) + else if (hasResult) { - Log.NoResultGiven(_logger, invocation.Target, invocation.InvocationId!); - await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); } else { - await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); + Log.MissingResultHandler(_logger, invocation.Target); + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); } } else if (hasResult) @@ -1700,10 +1698,20 @@ internal InvocationHandler[] GetHandlers() return handlers; } - internal void Add(InvocationHandler handler) + internal void Add(string methodName, InvocationHandler handler) { lock (_invocationHandlers) { + if (handler.HasResult) + { + foreach (var m in _invocationHandlers) + { + if (m.HasResult) + { + throw new InvalidOperationException($"'{methodName}' already has a value returning handler. Multiple return values are not supported."); + } + } + } _invocationHandlers.Add(handler); _copiedHandlers = null; } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index eaed0dab2778..da3621b9c34e 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -702,7 +702,7 @@ public async Task ClientCanReturnResult() } [Fact] - public async Task ClientReturnResultUsesLastResult() + public async Task ThrowsWhenMultipleReturningHandlersRegistered() { var connection = new TestConnection(); var hubConnection = CreateHubConnection(connection); @@ -711,15 +711,9 @@ public async Task ClientReturnResultUsesLastResult() await hubConnection.StartAsync().DefaultTimeout(); hubConnection.On("Result", () => 10); - hubConnection.On("Result", () => 11); - hubConnection.On("Result", () => 14); - hubConnection.On("Result", () => 3); - - await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); - - var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); - - Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":3}", invokeMessage); + var ex = Assert.Throws( + () => hubConnection.On("Result", () => 11)); + Assert.Equal("'Result' already has a value returning handler. Multiple return values are not supported.", ex.Message); } finally { @@ -791,7 +785,7 @@ public async Task ClientResultIgnoresErrorWhenLastHandlerSuccessful() { await hubConnection.StartAsync().DefaultTimeout(); - hubConnection.On("Result", int () => + hubConnection.On("Result", () => { throw new Exception("error from client"); }); diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 145aa91dda95..606745c4845f 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -675,13 +675,14 @@ export class HubConnection { } private async _invokeClientMethod(invocationMessage: InvocationMessage) { - const methods = this._methods[invocationMessage.target.toLowerCase()]; + const methodName = invocationMessage.target.toLowerCase(); + const methods = this._methods[methodName]; if (!methods) { - this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target.toLowerCase()}' found.`); + this._logger.log(LogLevel.Warning, `No client method with the name '${methodName}' found.`); // No handlers provided by client but the server is expecting a response still, so we send an error if (invocationMessage.invocationId) { - this._logger.log(LogLevel.Warning, `No result given for '${invocationMessage.target.toLowerCase()}' method and invocation ID '${invocationMessage.invocationId}'.`); + this._logger.log(LogLevel.Warning, `No result given for '${methodName}' method and invocation ID '${invocationMessage.invocationId}'.`); await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); } return; @@ -691,39 +692,43 @@ export class HubConnection { const methodsCopy = methods.slice(); // Server expects a response - if (invocationMessage.invocationId) { - // We preserve the last result or exception but still call all handlers - let res; - let exception; - for (const m of methodsCopy) { - try { - if (res || exception) { - this._logger.log(LogLevel.Warning, `Result already provided for '${invocationMessage.target.toLowerCase()}' only the last one will be sent.`); - } - res = await m.apply(this, invocationMessage.arguments); - // Ignore exception if we got a result after, the exception will be logged - exception = undefined; - } catch (e) { - exception = e; - this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); + const expects_response = invocationMessage.invocationId ? true : false; + // We preserve the last result or exception but still call all handlers + let res; + let exception; + let completion_message; + for (const m of methodsCopy) { + try { + const prev_res = res; + res = await m.apply(this, invocationMessage.arguments); + if (expects_response && res && prev_res) { + this._logger.log(LogLevel.Error, `Multiple results provided for '${methodName}'. Sending error to server.`); + completion_message = this._createCompletionMessage(invocationMessage.invocationId!, `Client provided multiple results.`, null); } + // Ignore exception if we got a result after, the exception will be logged + exception = undefined; + } catch (e) { + exception = e; + this._logger.log(LogLevel.Error, `A callback for the method '${methodName}' threw error '${e}'.`); } + } + if (completion_message) { + await this._sendWithProtocol(completion_message); + } else if (expects_response) { // If there is an exception that means either no result was given or a handler after a result threw - // And since we prefer handlers registered later we'll use the exception to return to the server. if (exception) { - await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, `${exception}`, null)); + completion_message = this._createCompletionMessage(invocationMessage.invocationId!, `${exception}`, null); } else if (res !== undefined) { - await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, null, res)); + completion_message = this._createCompletionMessage(invocationMessage.invocationId!, null, res); } else { - this._logger.log(LogLevel.Warning, `No result given for '${invocationMessage.target.toLowerCase()}' method and invocation ID '${invocationMessage.invocationId}'.`); + this._logger.log(LogLevel.Warning, `No result given for '${methodName}' method and invocation ID '${invocationMessage.invocationId}'.`); // Client didn't provide a result or throw from a handler, server expects a response so we send an error - await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); + completion_message = this._createCompletionMessage(invocationMessage.invocationId!, "Client didn't provide a result.", null); } + await this._sendWithProtocol(completion_message); } else { - try { - methodsCopy.forEach((m) => m.apply(this, invocationMessage.arguments)); - } catch (e) { - this._logger.log(LogLevel.Error, `A callback for the method '${invocationMessage.target.toLowerCase()}' threw error '${e}'.`); + if (res) { + this._logger.log(LogLevel.Error, `Result given for '${methodName}' method but server is not expecting a result.`); } } } diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index bb54a3c13249..4b8614b5ed0c 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -666,8 +666,12 @@ describe("HubConnection", () => { await hubConnection.start(); let count = 0; + const p = new PromiseSource(); const handler = () => { count++; }; - const secondHandler = () => { count++; }; + const secondHandler = () => { + count++; + p.resolve(); + }; hubConnection.on("inc", handler); hubConnection.on("inc", secondHandler); @@ -678,6 +682,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; hubConnection.off("inc"); connection.receive({ @@ -702,8 +707,12 @@ describe("HubConnection", () => { await hubConnection.start(); let count = 0; + let p = new PromiseSource(); const handler = () => { count++; }; - const secondHandler = () => { count++; }; + const secondHandler = () => { + count++; + p.resolve(); + }; hubConnection.on("inc", handler); hubConnection.on("inc", secondHandler); @@ -714,6 +723,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; + p = new PromiseSource(); hubConnection.off("inc", handler); connection.receive({ @@ -723,6 +734,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; + expect(count).toBe(3); } finally { await hubConnection.stop(); @@ -756,7 +769,7 @@ describe("HubConnection", () => { }); }); - it("callback invoked when servers invokes a method on the client", async () => { + it("callback invoked when server invokes a method on the client", async () => { await VerifyLogger.run(async (logger) => { const connection = new TestConnection(); const hubConnection = createHubConnection(connection, logger); @@ -764,7 +777,11 @@ describe("HubConnection", () => { await hubConnection.start(); let value = ""; - hubConnection.on("message", (v) => value = v); + const p = new PromiseSource(); + hubConnection.on("message", (v) => { + value = v; + p.resolve(); + }); connection.receive({ arguments: ["test"], @@ -773,6 +790,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; expect(value).toBe("test"); } finally { await hubConnection.stop(); @@ -891,8 +909,12 @@ describe("HubConnection", () => { let numInvocations1 = 0; let numInvocations2 = 0; + const p = new PromiseSource(); hubConnection.on("message", () => numInvocations1++); - hubConnection.on("message", () => numInvocations2++); + hubConnection.on("message", () => { + numInvocations2++; + p.resolve(); + }); connection.receive({ arguments: [], @@ -901,6 +923,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; expect(numInvocations1).toBe(1); expect(numInvocations2).toBe(1); } finally { @@ -958,7 +981,11 @@ describe("HubConnection", () => { hubConnection.off(eventToTrack, callback1); numInvocations1++; } - const callback2 = () => numInvocations2++; + let p = new PromiseSource(); + const callback2 = () => { + numInvocations2++; + p.resolve(); + }; hubConnection.on(eventToTrack, callback1); hubConnection.on(eventToTrack, callback2); @@ -970,6 +997,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; + p = new PromiseSource(); expect(numInvocations1).toBe(1); expect(numInvocations2).toBe(1); @@ -980,6 +1009,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; expect(numInvocations1).toBe(1); expect(numInvocations2).toBe(2); } @@ -1177,7 +1207,7 @@ describe("HubConnection", () => { }, "A callback for the method 'message' threw error 'Error: from callback'."); }); - it("multiple results only sends last one", async () => { + it("multiple results sends error", async () => { await VerifyLogger.run(async (logger) => { const connection = new TestConnection(); const hubConnection = createHubConnection(connection, logger); @@ -1201,12 +1231,12 @@ describe("HubConnection", () => { expect(connection.parsedSentData.length).toEqual(2); expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].result).toEqual(4); + expect(connection.parsedSentData[1].error).toEqual('Client provided multiple results.'); expect(connection.parsedSentData[1].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } - }); + }, "Multiple results provided for 'message'. Sending error to server."); }); it("multiple result handlers error from last one sent", async () => { @@ -1334,6 +1364,34 @@ describe("HubConnection", () => { } }); }); + + it("logs error if return result not expected", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 13); + + connection.receive({ + arguments: [], + invocationId: undefined, + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(1); + } finally { + await hubConnection.stop(); + } + }, "Result given for 'message' method but server is not expecting a result."); + }); }); describe("stream", () => { diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index 76f153110def..c6273311d802 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal; // Handles cancellation, cleanup, and completion, so any bugs or improvements can be made in a single place internal class ClientResultsManager : IInvocationBinder { - private readonly ConcurrentDictionary Completion)> _pendingInvocations = new(); + private readonly ConcurrentDictionary Complete)> _pendingInvocations = new(); public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) { @@ -29,7 +29,6 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel { tcs.SetException(new Exception(completionMessage.Error)); } - return Task.CompletedTask; } )); Debug.Assert(result); @@ -39,13 +38,13 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel return tcs.Task; } - public void AddInvocation(string invocationId, (Type Type, string ConnectionId, object Tcs, Func Completion) invocationInfo) + public void AddInvocation(string invocationId, (Type Type, string ConnectionId, object Tcs, Action Complete) invocationInfo) { var result = _pendingInvocations.TryAdd(invocationId, invocationInfo); Debug.Assert(result); } - public Task TryCompleteResult(string connectionId, CompletionMessage message) + public void TryCompleteResult(string connectionId, CompletionMessage message) { if (_pendingInvocations.TryGetValue(message.InvocationId!, out var item)) { @@ -59,17 +58,16 @@ public Task TryCompleteResult(string connectionId, CompletionMessage message) // we'll ignore both cases if (_pendingInvocations.Remove(message.InvocationId!, out _)) { - return item.Completion(item.Tcs, message); + item.Complete(item.Tcs, message); } } else { // connection was disconnected or someone else completed the invocation } - return Task.CompletedTask; } - public (Type Type, string ConnectionId, object Tcs, Func Completion)? RemoveInvocation(string invocationId) + public (Type Type, string ConnectionId, object Tcs, Action Completion)? RemoveInvocation(string invocationId) { _pendingInvocations.Remove(invocationId, out var item); return item; @@ -146,7 +144,7 @@ public void RegisterCancellation() { // TODO: RedisHubLifetimeManager will want to notify the other server (if there is one) about the cancellation // so it can clean up state and potentially forward that info to the connection - _ = _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Canceled")); + _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Canceled")); } public new void SetResult(T result) diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index d2990fad4b11..08b4d0f632e8 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -332,7 +332,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri if (connection == null) { - throw new InvalidOperationException("Connection does not exist."); + throw new IOException($"Connection '{connectionId}' does not exist."); } var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); @@ -363,7 +363,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri // ConnectionAborted will trigger a generic "Canceled" exception from the task, let's convert it into a more specific message. if (connection.ConnectionAborted.IsCancellationRequested) { - throw new Exception("Connection disconnected."); + throw new IOException($"Connection '{connectionId}' disconnected."); } throw; } @@ -372,7 +372,8 @@ public override async Task InvokeConnectionAsync(string connectionId, stri /// public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) { - return _clientResultsManager.TryCompleteResult(connectionId, result); + _clientResultsManager.TryCompleteResult(connectionId, result); + return Task.CompletedTask; } /// diff --git a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs index ec8971c4134d..103f8dad8937 100644 --- a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs +++ b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs @@ -166,7 +166,7 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Ldfld, proxyField); - var notTypeLabel = generator.DefineLabel(); + var isTypeLabel = generator.DefineLabel(); if (isInvoke) { var singleClientProxyType = typeof(ISingleClientProxy); @@ -178,8 +178,13 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo throw new InvalidOperationException("InvokeAsync only works with Single clients."); */ generator.Emit(OpCodes.Isinst, singleClientProxyType); - generator.Emit(OpCodes.Brfalse_S, notTypeLabel); + generator.Emit(OpCodes.Brtrue_S, isTypeLabel); + generator.Emit(OpCodes.Ldstr, "InvokeAsync only works with Single clients."); + generator.Emit(OpCodes.Newobj, typeof(InvalidOperationException).GetConstructor(new Type[] { typeof(string) })!); + generator.Emit(OpCodes.Throw); + + generator.MarkLabel(isTypeLabel); generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Ldfld, proxyField); generator.Emit(OpCodes.Castclass, singleClientProxyType); @@ -221,12 +226,6 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo generator.Emit(OpCodes.Callvirt, invokeMethod); generator.Emit(OpCodes.Ret); // Return the Task returned by 'invokeMethod' - - // Used by InvokeAsync to check if it's being called with ISingleClientProxy otherwise throws - generator.MarkLabel(notTypeLabel); - generator.Emit(OpCodes.Ldstr, "InvokeAsync only works with Single clients."); - generator.Emit(OpCodes.Newobj, typeof(InvalidOperationException).GetConstructor(new Type[] { typeof(string) })!); - generator.Emit(OpCodes.Throw); } private static void BuildFactoryMethod(TypeBuilder type, ConstructorInfo ctor) diff --git a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs index ac2acc1f56f4..19b061d1c6bd 100644 --- a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs +++ b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs @@ -277,7 +277,7 @@ public async Task ConnectionIDNotPresentWhenInvokingClientResult() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); // No client with this ID - await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); } } @@ -334,8 +334,8 @@ public async Task ClientDisconnectsWithoutCompletingClientResult() connection1.Abort(); await manager1.OnDisconnectedAsync(connection1).DefaultTimeout(); - var ex = await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); - Assert.Equal("Connection disconnected.", ex.Message); + var ex = await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + Assert.Equal($"Connection '{connection1.ConnectionId}' disconnected.", ex.Message); } } diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 0c25ce9698ff..5fc1c9637c76 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -544,7 +544,7 @@ public async Task ConnectionIDNotPresentMultiServerWhenInvokingClientResult() await manager1.OnConnectedAsync(connection1).DefaultTimeout(); // No client on any backplanes with this ID - await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); } } @@ -626,7 +626,7 @@ public async Task ConnectionDoesNotExist_FailsInvokeConnectionAsync() var manager1 = CreateNewHubLifetimeManager(backplane); var manager2 = CreateNewHubLifetimeManager(backplane); - var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" })).DefaultTimeout(); - Assert.Equal("Connection does not exist.", ex.Message); + var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" })).DefaultTimeout(); + Assert.Equal("Connection '1234' does not exist.", ex.Message); } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs index ca9ead0a5778..ba7217748af2 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs @@ -52,6 +52,12 @@ public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endp [LoggerMessage(11, LogLevel.Warning, "Error processing message for internal server message.", EventName = "InternalMessageFailed")] public static partial void InternalMessageFailed(ILogger logger, Exception exception); + [LoggerMessage(12, LogLevel.Error, "Received a client result for protocol {HubProtocol} which is not supported by this server. This likely means you have different versions of your server deployed.", EventName = "MismatchedServers")] + public static partial void MismatchedServers(ILogger logger, string hubProtocol); + + [LoggerMessage(13, LogLevel.Error, "Error forwarding client result with ID '{InvocationID}' to server.", EventName = "ErrorForwardingResult")] + public static partial void ErrorForwardingResult(ILogger logger, string invocationId, Exception ex); + // This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer public static void ConnectionMultiplexerMessage(ILogger logger, string? message) { diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index d0621031ef88..be33c79ab116 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -422,7 +422,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri var received = await PublishAsync(_channels.Connection(connectionId), m); if (received < 1) { - throw new InvalidOperationException("Connection does not exist."); + throw new IOException($"Connection '{connectionId}' does not exist."); } } else @@ -431,7 +431,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri // Write message directly to connection without caching it in memory var message = new InvocationMessage(invocationId, methodName, args); - await connection.WriteAsync(message, cancellationToken).AsTask(); + await connection.WriteAsync(message, cancellationToken); } } catch @@ -449,7 +449,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri // ConnectionAborted will trigger a generic "Canceled" exception from the task, let's convert it into a more specific message. if (connection?.ConnectionAborted.IsCancellationRequested == true) { - throw new Exception("Connection disconnected."); + throw new IOException($"Connection '{connectionId}' disconnected."); } throw; } @@ -458,7 +458,8 @@ public override async Task InvokeConnectionAsync(string connectionId, stri /// public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) { - return _clientResultsManager.TryCompleteResult(connectionId, result); + _clientResultsManager.TryCompleteResult(connectionId, result); + return Task.CompletedTask; } /// @@ -561,7 +562,7 @@ private async Task SubscribeToConnection(HubConnectionContext connection) { CancellationTokenRegistration? tokenRegistration = null; _clientResultsManager.AddInvocation(invocation.InvocationId, - (typeof(RawResult), connection.ConnectionId, null!, (_, completionMessage) => + (typeof(RawResult), connection.ConnectionId, null!, async (_, completionMessage) => { var protocolName = connection.Protocol.Name; tokenRegistration?.Dispose(); @@ -572,14 +573,17 @@ private async Task SubscribeToConnection(HubConnectionContext connection) connection.Protocol.WriteMessage(completionMessage, memoryBufferWriter); // TODO: we can avoid this ToArray call var message = RedisProtocol.WriteCompletionMessage(new ReadOnlySequence(memoryBufferWriter.ToArray()), protocolName); - return PublishAsync(invocation.ReturnChannel!, message); + await PublishAsync(invocation.ReturnChannel!, message); + } + catch (Exception ex) + { + RedisLog.ErrorForwardingResult(_logger, completionMessage.InvocationId, ex); } finally { memoryBufferWriter.Dispose(); } - } - )); + })); // TODO: this isn't great tokenRegistration = connection.ConnectionAborted.UnsafeRegister(_ => @@ -670,7 +674,14 @@ private async Task SubscribeToReturnResultsAsync() break; } } - Debug.Assert(protocol is not null); + + // Should only happen if you have different versions of servers and don't have the same protocols registered on both + if (protocol is null) + { + RedisLog.MismatchedServers(_logger, completion.ProtocolName); + return; + } + var ros = completion.CompletionMessage; var parseSuccess = protocol.TryParseMessage(ref ros, _clientResultsManager, out var hubMessage); Debug.Assert(parseSuccess); From d077749b109662b71d02a9dcdeb93b07107f4321 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Wed, 13 Apr 2022 09:05:02 -0700 Subject: [PATCH 09/11] Alloc --- .../clients/ts/signalr/src/HubConnection.ts | 24 ++++----- ...re.SignalR.Protocols.NewtonsoftJson.csproj | 1 + .../src/Protocol/NewtonsoftJsonHubProtocol.cs | 22 ++++---- .../common/Shared/ClientResultsManager.cs | 2 +- .../common/Shared/MemoryBufferWriter.cs | 53 ++++++++++++++++++- .../src/Internal/RedisProtocol.cs | 26 +++++---- .../src/RedisHubLifetimeManager.cs | 21 ++++---- .../test/RedisProtocolTests.cs | 7 ++- 8 files changed, 106 insertions(+), 50 deletions(-) diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 606745c4845f..67c32aef1fa1 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -692,18 +692,18 @@ export class HubConnection { const methodsCopy = methods.slice(); // Server expects a response - const expects_response = invocationMessage.invocationId ? true : false; + const expectsResponse = invocationMessage.invocationId ? true : false; // We preserve the last result or exception but still call all handlers let res; let exception; - let completion_message; + let completionMessage; for (const m of methodsCopy) { try { - const prev_res = res; + const prevRes = res; res = await m.apply(this, invocationMessage.arguments); - if (expects_response && res && prev_res) { + if (expectsResponse && res && prevRes) { this._logger.log(LogLevel.Error, `Multiple results provided for '${methodName}'. Sending error to server.`); - completion_message = this._createCompletionMessage(invocationMessage.invocationId!, `Client provided multiple results.`, null); + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, `Client provided multiple results.`, null); } // Ignore exception if we got a result after, the exception will be logged exception = undefined; @@ -712,20 +712,20 @@ export class HubConnection { this._logger.log(LogLevel.Error, `A callback for the method '${methodName}' threw error '${e}'.`); } } - if (completion_message) { - await this._sendWithProtocol(completion_message); - } else if (expects_response) { + if (completionMessage) { + await this._sendWithProtocol(completionMessage); + } else if (expectsResponse) { // If there is an exception that means either no result was given or a handler after a result threw if (exception) { - completion_message = this._createCompletionMessage(invocationMessage.invocationId!, `${exception}`, null); + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, `${exception}`, null); } else if (res !== undefined) { - completion_message = this._createCompletionMessage(invocationMessage.invocationId!, null, res); + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, null, res); } else { this._logger.log(LogLevel.Warning, `No result given for '${methodName}' method and invocation ID '${invocationMessage.invocationId}'.`); // Client didn't provide a result or throw from a handler, server expects a response so we send an error - completion_message = this._createCompletionMessage(invocationMessage.invocationId!, "Client didn't provide a result.", null); + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, "Client didn't provide a result.", null); } - await this._sendWithProtocol(completion_message); + await this._sendWithProtocol(completionMessage); } else { if (res) { this._logger.log(LogLevel.Error, `Result given for '${methodName}' method but server is not expecting a result.`); diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj b/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj index 51d392697c8d..2167aa456d47 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj @@ -14,6 +14,7 @@ + diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index 89e4b03edaa4..555db0092e5b 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -218,22 +218,20 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (returnType == typeof(RawResult)) { var token = JToken.Load(reader); - using var strm = new MemoryStream(); - using var writer = new StreamWriter(strm); - using var jsonTextWriter = new JsonTextWriter(writer); - token.WriteTo(jsonTextWriter); - jsonTextWriter.Flush(); - writer.Flush(); - Memory buf; - if (strm.TryGetBuffer(out var segment)) + var strm = MemoryBufferWriter.Get(); + try { - buf = segment.Array.AsMemory(segment.Offset, segment.Count); + using var writer = new StreamWriter(strm); + using var jsonTextWriter = new JsonTextWriter(writer); + token.WriteTo(jsonTextWriter); + jsonTextWriter.Flush(); + writer.Flush(); + result = new RawResult(new ReadOnlySequence(strm.ToArray())); } - else + finally { - buf = strm.ToArray(); + MemoryBufferWriter.Return(strm); } - result = new RawResult(new ReadOnlySequence(buf)); } else { diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index c6273311d802..68be58269ebc 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -69,7 +69,7 @@ public void TryCompleteResult(string connectionId, CompletionMessage message) public (Type Type, string ConnectionId, object Tcs, Action Completion)? RemoveInvocation(string invocationId) { - _pendingInvocations.Remove(invocationId, out var item); + _pendingInvocations.TryRemove(invocationId, out var item); return item; } diff --git a/src/SignalR/common/Shared/MemoryBufferWriter.cs b/src/SignalR/common/Shared/MemoryBufferWriter.cs index ca4dc6407ad6..f1c198ba1cf6 100644 --- a/src/SignalR/common/Shared/MemoryBufferWriter.cs +++ b/src/SignalR/common/Shared/MemoryBufferWriter.cs @@ -326,6 +326,30 @@ public override void Write(ReadOnlySpan span) } #endif + public WrittenBuffers DetachAndReset() + { + WrittenBuffers written; + + if (_currentSegment != null) + { + if (_completedSegments is null) + { + _completedSegments = new List(); + } + _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); + } + + _completedSegments ??= new List(); + written = new WrittenBuffers(_completedSegments, _bytesWritten); + + _currentSegment = null; + _completedSegments = null; + _bytesWritten = 0; + _position = 0; + + return written; + } + protected override void Dispose(bool disposing) { if (disposing) @@ -334,10 +358,37 @@ protected override void Dispose(bool disposing) } } + /// + /// Holds the written segments from a MemoryBufferWriter and is no longer attached to a MemoryBufferWriter. + /// You are now responsible for calling Dispose on this type to return the memory to the pool. + /// + internal readonly ref struct WrittenBuffers + { + public readonly List Segments; + private readonly int _bytesWritten; + + public WrittenBuffers(List segments, int bytesWritten) + { + Segments = segments; + _bytesWritten = bytesWritten; + } + + public int ByteLength => _bytesWritten; + + public void Dispose() + { + for (var i = 0; i < Segments.Count; i++) + { + Segments[i].Return(); + } + Segments.Clear(); + } + } + /// /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. /// - private readonly struct CompletedBuffer + internal readonly struct CompletedBuffer { public byte[] Buffer { get; } public int Length { get; } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 51f3512e0d43..7a60143c97d4 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -143,30 +143,28 @@ public static byte[] WriteAck(int messageId) } } - public static byte[] WriteCompletionMessage(ReadOnlySequence completionMessage, string protocolName) + public static byte[] WriteCompletionMessage(MemoryBufferWriter writer, string protocolName) { // Written as a MessagePack 'arr' containing at least these items: // * A 'str': The name of the HubProtocol used for the serialization of the Completion Message // * [A serialized Completion Message which is a 'bin'] // Any additional items are discarded. - var memoryBufferWriter = MemoryBufferWriter.Get(); - try - { - var writer = new MessagePackWriter(memoryBufferWriter); - - writer.WriteArrayHeader(2); - writer.Write(protocolName); - writer.Write(completionMessage); + var completionMessage = writer.DetachAndReset(); + var msgPackWriter = new MessagePackWriter(writer); - writer.Flush(); + msgPackWriter.WriteArrayHeader(2); + msgPackWriter.Write(protocolName); - return memoryBufferWriter.ToArray(); - } - finally + msgPackWriter.WriteBinHeader(completionMessage.ByteLength); + foreach (var segment in completionMessage.Segments) { - MemoryBufferWriter.Return(memoryBufferWriter); + msgPackWriter.WriteRaw(segment.Span); } + completionMessage.Dispose(); + + msgPackWriter.Flush(); + return writer.ToArray(); } public static RedisInvocation ReadInvocation(ReadOnlyMemory data) diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index be33c79ab116..6af795083e08 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -557,7 +557,7 @@ private async Task SubscribeToConnection(HubConnectionContext connection) { var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message); - // This is a Client result we need to setup state for the completion and send the message to the client + // This is a Client result we need to setup state for the completion and forward the message to the client if (!string.IsNullOrEmpty(invocation.InvocationId)) { CancellationTokenRegistration? tokenRegistration = null; @@ -566,23 +566,26 @@ private async Task SubscribeToConnection(HubConnectionContext connection) { var protocolName = connection.Protocol.Name; tokenRegistration?.Dispose(); - // TODO: acquiring this and then calling RedisProtocol.WriteCompletionMessage will allocate a new MemoryBufferWriter, we can avoid this + var memoryBufferWriter = AspNetCore.Internal.MemoryBufferWriter.Get(); + byte[] message; try { - connection.Protocol.WriteMessage(completionMessage, memoryBufferWriter); - // TODO: we can avoid this ToArray call - var message = RedisProtocol.WriteCompletionMessage(new ReadOnlySequence(memoryBufferWriter.ToArray()), protocolName); + try + { + connection.Protocol.WriteMessage(completionMessage, memoryBufferWriter); + message = RedisProtocol.WriteCompletionMessage(memoryBufferWriter, protocolName); + } + finally + { + memoryBufferWriter.Dispose(); + } await PublishAsync(invocation.ReturnChannel!, message); } catch (Exception ex) { RedisLog.ErrorForwardingResult(_logger, completionMessage.InvocationId, ex); } - finally - { - memoryBufferWriter.Dispose(); - } })); // TODO: this isn't great diff --git a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs index b87ee91e564d..0a9d56a233f5 100644 --- a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs +++ b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; @@ -264,7 +265,11 @@ public void WriteCompletionMessage(string testName) { var testData = _completionMessageTestData[testName]; - var encoded = RedisProtocol.WriteCompletionMessage(testData.Decoded.CompletionMessage, testData.Decoded.ProtocolName); + var writer = MemoryBufferWriter.Get(); + writer.Write(testData.Decoded.CompletionMessage.ToArray()); + + var encoded = RedisProtocol.WriteCompletionMessage(writer, testData.Decoded.ProtocolName); + MemoryBufferWriter.Return(writer); Assert.Equal(testData.Encoded, encoded); } From 5951fa2e4c570882c921c7b408b32504ad7794d1 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 18 Apr 2022 20:40:16 -0700 Subject: [PATCH 10/11] fb --- .../src/HubConnectionExtensions.OnResult.cs | 2 +- .../UnitTests/HubConnectionTests.Protocol.cs | 25 ++++++++++ .../src/Protocol/JsonHubProtocol.cs | 3 +- .../src/Protocol/NewtonsoftJsonHubProtocol.cs | 49 +++++++------------ .../common/Shared/MemoryBufferWriter.cs | 11 ++--- .../Core/src/Internal/DefaultHubDispatcher.cs | 1 - .../Core/src/Internal/TypedClientBuilder.cs | 2 +- 7 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs index 878ac788681e..d8167d9707ff 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs @@ -31,7 +31,7 @@ private static IDisposable On(this HubConnection hubConnection, string /// A subscription that can be disposed to unsubscribe from the hub method. public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func> handler) { - return hubConnection.On(methodName, parameterTypes, async (parameters, state) => + return hubConnection.On(methodName, parameterTypes, static async (parameters, state) => { var currentHandler = (Func>)state; return await currentHandler(parameters).ConfigureAwait(false); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index da3621b9c34e..caf666292686 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -851,5 +851,30 @@ public async Task ClientResultReturnsErrorIfNoResultFromClient() await connection.DisposeAsync().DefaultTimeout(); } } + + [Fact] + public async Task ClientResultCanReturnNullResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + // No result provided + hubConnection.On("Result", object () => null); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":null}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } } } diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 4fa2f313945d..4f288b1fb5e7 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -738,8 +738,9 @@ private static HubMessage BindInvocationMessage(string? invocationId, string tar reader.Skip(); var end = reader.BytesConsumed; var sequence = input.Slice(start, end - start); - // Review: Technically we could pass the sequence without copying into a new array + // Review: Technically the sequence doesn't need to be copied to a new array in RawResult // but in the future we could break this if we dispatched the CompletionMessage and the underlying Pipe read would be advanced + // instead we could try pooling in RawResult, but it would need release/dispose semantics return new RawResult(sequence); } return BindType(ref reader, type); diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index 555db0092e5b..51eb9310dd99 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -218,20 +218,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (returnType == typeof(RawResult)) { var token = JToken.Load(reader); - var strm = MemoryBufferWriter.Get(); - try - { - using var writer = new StreamWriter(strm); - using var jsonTextWriter = new JsonTextWriter(writer); - token.WriteTo(jsonTextWriter); - jsonTextWriter.Flush(); - writer.Flush(); - result = new RawResult(new ReadOnlySequence(strm.ToArray())); - } - finally - { - MemoryBufferWriter.Return(strm); - } + result = GetRawResult(token); } else { @@ -412,22 +399,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) var returnType = binder.GetReturnType(invocationId); if (returnType == typeof(RawResult)) { - using var strm = new MemoryStream(); - using var writer = new StreamWriter(strm); - using var jsonTextWriter = new JsonTextWriter(writer); - resultToken.WriteTo(jsonTextWriter); - jsonTextWriter.Flush(); - writer.Flush(); - Memory buf; - if (strm.TryGetBuffer(out var segment)) - { - buf = segment.Array.AsMemory(segment.Offset, segment.Count); - } - else - { - buf = strm.ToArray(); - } - result = new RawResult(new ReadOnlySequence(buf)); + result = GetRawResult(resultToken); } else { @@ -878,6 +850,23 @@ private static HubMessage ApplyHeaders(HubMessage message, Dictionary(strm.ToArray())); + } + finally + { + MemoryBufferWriter.Return(strm); + } + } internal static JsonSerializerSettings CreateDefaultSerializerSettings() { return new JsonSerializerSettings { ContractResolver = new CamelCasePropertyNamesContractResolver() }; diff --git a/src/SignalR/common/Shared/MemoryBufferWriter.cs b/src/SignalR/common/Shared/MemoryBufferWriter.cs index f1c198ba1cf6..999c3143b4e6 100644 --- a/src/SignalR/common/Shared/MemoryBufferWriter.cs +++ b/src/SignalR/common/Shared/MemoryBufferWriter.cs @@ -328,19 +328,14 @@ public override void Write(ReadOnlySpan span) public WrittenBuffers DetachAndReset() { - WrittenBuffers written; + _completedSegments ??= new List(); - if (_currentSegment != null) + if (_currentSegment is not null) { - if (_completedSegments is null) - { - _completedSegments = new List(); - } _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); } - _completedSegments ??= new List(); - written = new WrittenBuffers(_completedSegments, _bytesWritten); + var written = new WrittenBuffers(_completedSegments, _bytesWritten); _currentSegment = null; _completedSegments = null; diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index c61c6687ccfe..f4f299cd80b6 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -174,7 +174,6 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe { Log.CompletingStream(_logger, completionMessage); } - // TODO: this relies on the lifetime manager keeping state for the return type after deserializing the message, is that ok? // InvocationId is always required on CompletionMessage, it's nullable because of the base type else if (_hubLifetimeManager.TryGetReturnType(completionMessage.InvocationId!, out _)) { diff --git a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs index 103f8dad8937..69be30e9c761 100644 --- a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs +++ b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs @@ -269,7 +269,7 @@ private static void VerifyInterface(Type interfaceType) private static void VerifyMethod(MethodInfo interfaceMethod) { - if (interfaceMethod.ReturnType != typeof(Task) && interfaceMethod.ReturnType?.BaseType != typeof(Task)) + if (!typeof(Task).IsAssignableFrom(interfaceMethod.ReturnType)) { throw new InvalidOperationException( $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'."); From 9eda15593fe7f7ce7c9abc079e577de6028c0375 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Tue, 19 Apr 2022 09:20:35 -0700 Subject: [PATCH 11/11] minor update --- .../Protocols.Json/src/Protocol/JsonHubProtocol.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 4f288b1fb5e7..bed3c046ffc5 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -537,8 +537,15 @@ private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter wr { if (message.Result is RawResult result) { - Debug.Assert(result.RawSerializedData.IsSingleSegment); - writer.WriteRawValue(result.RawSerializedData.First.Span, skipInputValidation: true); + if (result.RawSerializedData.IsSingleSegment) + { + writer.WriteRawValue(result.RawSerializedData.First.Span, skipInputValidation: true); + } + else + { + // https://github.com/dotnet/runtime/issues/68223 + writer.WriteRawValue(result.RawSerializedData.ToArray(), skipInputValidation: true); + } } else {