diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 43f77ae3c86a..3f66fd3af762 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -310,5 +310,11 @@ public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTime [LoggerMessage(84, LogLevel.Trace, "Client threw an error for stream '{StreamId}'.", EventName = "ErroredStream")] public static partial void ErroredStream(ILogger logger, string streamId, Exception exception); + + [LoggerMessage(85, LogLevel.Warning, "Failed to find a value returning handler for '{Target}' method. Sending error to server.", EventName = "MissingResultHandler")] + public static partial void MissingResultHandler(ILogger logger, string target); + + [LoggerMessage(86, LogLevel.Warning, "Result given for '{Target}' method but server is not expecting a result.", EventName = "ResultNotExpected")] + public static partial void ResultNotExpected(ILogger logger, string target); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 46101e410213..8e90b23fa48a 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -259,7 +259,7 @@ private async Task StartAsyncInner(CancellationToken cancellationToken = default throw new InvalidOperationException($"The {nameof(HubConnection)} cannot be started while {nameof(StopAsync)} is running."); } - using (CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken)) + using (CancellationTokenUtils.CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken)) { await StartAsyncCore(linkedToken).ConfigureAwait(false); } @@ -312,6 +312,39 @@ public virtual async ValueTask DisposeAsync() } } + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The name of the hub method to define. + /// The parameters types expected by the hub method. + /// The handler that will be raised when the hub method is invoked. + /// A state object that will be passed to the handler. + /// A subscription that can be disposed to unsubscribe from the hub method. + /// + /// This is a low level method for registering a handler. Using an On extension method is recommended. + /// + public virtual IDisposable On(string methodName, Type[] parameterTypes, Func> handler, object state) + { + Log.RegisteringHandler(_logger, methodName); + + CheckDisposed(); + + // It's OK to be disposed while registering a callback, we'll just never call the callback anyway (as with all the callbacks registered before disposal). + var invocationHandler = new InvocationHandler(parameterTypes, handler, state); + var invocationList = _handlers.AddOrUpdate(methodName, _ => new InvocationHandlerList(invocationHandler), + (_, invocations) => + { + lock (invocations) + { + invocations.Add(methodName, invocationHandler); + } + return invocations; + }); + + return new Subscription(invocationHandler, invocationList); + } + // If the registered callback blocks it can cause the client to stop receiving messages. If you need to block, get off the current thread first. /// /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. @@ -337,7 +370,7 @@ public virtual IDisposable On(string methodName, Type[] parameterTypes, Func resultTask) + { + result = await resultTask.ConfigureAwait(false); + hasResult = true; + } + else + { + await task.ConfigureAwait(false); + } } catch (Exception ex) { Log.ErrorInvokingClientSideMethod(_logger, invocation.Target, ex); + if (handler.HasResult) + { + resultException = ex; + } + } + } + + if (expectsResult) + { + if (resultException is not null) + { + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false); } + else if (hasResult) + { + await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); + } + else + { + Log.MissingResultHandler(_logger, invocation.Target); + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + } + } + else if (hasResult) + { + Log.ResultNotExpected(_logger, invocation.Target); } } @@ -1073,7 +1152,7 @@ private async Task HandshakeAsync(ConnectionState startingConnectionState, Cance try { // cancellationToken already contains _state.StopCts.Token, so we don't have to link it again - using (CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken)) + using (CancellationTokenUtils.CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken)) { while (true) { @@ -1178,7 +1257,7 @@ async Task StartProcessingInvocationMessages(ChannelReader in { while (invocationMessageChannelReader.TryRead(out var invocationMessage)) { - await DispatchInvocationAsync(invocationMessage).ConfigureAwait(false); + await DispatchInvocationAsync(invocationMessage, connectionState).ConfigureAwait(false); } } } @@ -1562,26 +1641,6 @@ async Task RunReconnectedEventAsync() } } - private static IDisposable? CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken) - { - if (!token1.CanBeCanceled) - { - linkedToken = token2; - return null; - } - else if (!token2.CanBeCanceled) - { - linkedToken = token1; - return null; - } - else - { - var cts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); - linkedToken = cts.Token; - return cts; - } - } - // Debug.Assert plays havoc with Unit Tests. But I want something that I can "assert" only in Debug builds. [Conditional("DEBUG")] private static void SafeAssert(bool condition, string message, [CallerMemberName] string? memberName = null, [CallerFilePath] string? fileName = null, [CallerLineNumber] int lineNumber = 0) @@ -1639,10 +1698,20 @@ internal InvocationHandler[] GetHandlers() return handlers; } - internal void Add(InvocationHandler handler) + internal void Add(string methodName, InvocationHandler handler) { lock (_invocationHandlers) { + if (handler.HasResult) + { + foreach (var m in _invocationHandlers) + { + if (m.HasResult) + { + throw new InvalidOperationException($"'{methodName}' already has a value returning handler. Multiple return values are not supported."); + } + } + } _invocationHandlers.Add(handler); _copiedHandlers = null; } @@ -1663,6 +1732,7 @@ internal void Remove(InvocationHandler handler) private readonly struct InvocationHandler { public Type[] ParameterTypes { get; } + public bool HasResult => _callback.Method.ReturnType == typeof(Task); private readonly Func _callback; private readonly object _state; diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs new file mode 100644 index 000000000000..d8167d9707ff --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.OnResult.cs @@ -0,0 +1,486 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Client; +public static partial class HubConnectionExtensions +{ + private static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func handler) + { + return hubConnection.On(methodName, parameterTypes, static (parameters, state) => + { + var currentHandler = (Func)state; + return Task.FromResult(currentHandler(parameters)); + }, handler); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The parameters types expected by the hub method. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func> handler) + { + return hubConnection.On(methodName, parameterTypes, static async (parameters, state) => + { + var currentHandler = (Func>)state; + return await currentHandler(parameters).ConfigureAwait(false); + }, handler); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, Type.EmptyTypes, args => handler()); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, Type.EmptyTypes, args => handler()); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1) }, + args => handler((T1)args[0]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2) }, + args => handler((T1)args[0]!, (T2)args[1]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// The eighth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!, (T8)args[7]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1) }, + args => handler((T1)args[0]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2) }, + args => handler((T1)args[0]!, (T2)args[1]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!)); + } + + /// + /// Registers a handler that will be invoked when the hub method with the specified method name is invoked. + /// Returns value returned by handler to server if the server requests a result. + /// + /// The first argument type. + /// The second argument type. + /// The third argument type. + /// The fourth argument type. + /// The fifth argument type. + /// The sixth argument type. + /// The seventh argument type. + /// The eighth argument type. + /// The return type the handler returns. + /// The hub connection. + /// The name of the hub method to define. + /// The handler that will be raised when the hub method is invoked. + /// A subscription that can be disposed to unsubscribe from the hub method. + public static IDisposable On(this HubConnection hubConnection, string methodName, Func> handler) + { + if (hubConnection == null) + { + throw new ArgumentNullException(nameof(hubConnection)); + } + + return hubConnection.On(methodName, + new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8) }, + args => handler((T1)args[0]!, (T2)args[1]!, (T3)args[2]!, (T4)args[3]!, (T5)args[4]!, (T6)args[5]!, (T7)args[6]!, (T8)args[7]!)); + } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs index e8afe03949ef..ea131eff6b35 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionExtensions.cs @@ -13,7 +13,7 @@ public static partial class HubConnectionExtensions { private static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Action handler) { - return hubConnection.On(methodName, parameterTypes, (parameters, state) => + return hubConnection.On(methodName, parameterTypes, static (parameters, state) => { var currentHandler = (Action)state; currentHandler(parameters); diff --git a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj index 7c0548133757..0fe14fda5c80 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj +++ b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj @@ -13,6 +13,7 @@ + diff --git a/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt b/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..40077fde5c02 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/clients/csharp/Client.Core/src/PublicAPI.Unshipped.txt @@ -1 +1,21 @@ #nullable enable +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func!>! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Func! handler) -> System.IDisposable! +static Microsoft.AspNetCore.SignalR.Client.HubConnectionExtensions.On(this Microsoft.AspNetCore.SignalR.Client.HubConnection! hubConnection, string! methodName, System.Type![]! parameterTypes, System.Func!>! handler) -> System.IDisposable! +virtual Microsoft.AspNetCore.SignalR.Client.HubConnection.On(string! methodName, System.Type![]! parameterTypes, System.Func!>! handler, object! state) -> System.IDisposable! diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs index 7f60b7791161..1ae2bf44df54 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Extensions.cs @@ -4,6 +4,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Testing; +using Newtonsoft.Json.Linq; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests; @@ -396,5 +397,308 @@ await connection.ReceiveJsonMessage( await hubConnection.DisposeAsync().DefaultTimeout(); } } + + [Fact] + public async Task OnWithResult() + { + var returnValue = 46; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + () => + { + tcs.SetResult(new object[0]); + return returnValue; + }), + new object[0]); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnAsyncWithResult() + { + var returnValue = 1220; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + async () => + { + tcs.SetResult(new object[0]); + await Task.CompletedTask; + return returnValue; + }), + new object[0]); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT1WithResult() + { + var returnValue = "buffalo"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + r => + { + tcs.SetResult(new object[] { r }); + return returnValue; + }), + new object[] { 42 }); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT1AsyncWithResult() + { + var returnValue = 2; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + async r => + { + tcs.SetResult(new object[] { r }); + await Task.CompletedTask; + return returnValue; + }), + new object[] { 42 }); + + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT2WithResult() + { + var returnValue = "ret"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2) => + { + tcs.SetResult(new object[] { r1, r2 }); + return returnValue; + }), + new object[] { 42, "abc" }); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT2AsyncWithResult() + { + var returnResult = 928; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2) => + { + tcs.SetResult(new object[] { r1, r2 }); + return Task.FromResult(returnResult); + }), + new object[] { 42, "abc" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT3WithResult() + { + var returnValue = "bob"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3) => + { + tcs.SetResult(new object[] { r1, r2, r3 }); + return returnValue; + }), + new object[] { 42, "abc", 24.0f }); + Assert.Equal(returnValue, result); + } + + [Fact] + public async Task OnT3AsyncWithResult() + { + var returnResult = "random"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3) => + { + tcs.SetResult(new object[] { r1, r2, r3 }); + return Task.FromResult(returnResult); + }), + new object[] { 42, "abc", 24.0f }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT4WithResult() + { + var returnResult = 233; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT4AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT5WithResult() + { + var returnResult = 3004; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT5AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT6WithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24 }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT6AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24 }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT7WithResult() + { + var returnResult = 100; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c' }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT7AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c' }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT8WithResult() + { + var returnResult = 102; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7, r8) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7, r8 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c', "XYZ" }); + Assert.Equal(returnResult, result); + } + + [Fact] + public async Task OnT8AsyncWithResult() + { + var returnResult = "alphabet"; + var result = await InvokeOnWithResult( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7, r8) => + { + tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7, r8 }); + return returnResult; + }), + new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c', "XYZ" }); + Assert.Equal(returnResult, result); + } + + private async Task InvokeOnWithResult(Action> onAction, object[] args) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + var handlerTcs = new TaskCompletionSource(); + try + { + onAction(hubConnection, handlerTcs); + await hubConnection.StartAsync(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "1", + type = 1, + target = "Foo", + arguments = args + }).DefaultTimeout(); + + await handlerTcs.Task.DefaultTimeout(); + var json = await connection.ReadSentJsonAsync(); + var result = json["result"]; + return result; + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + } + } } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index 98078fb198ef..caf666292686 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -676,5 +676,205 @@ public async Task ClientWithInherentKeepAliveDoesNotPing() await connection.DisposeAsync().DefaultTimeout(); } } + + [Fact] + public async Task ClientCanReturnResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", () => 10); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":10}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ThrowsWhenMultipleReturningHandlersRegistered() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", () => 10); + var ex = Assert.Throws( + () => hubConnection.On("Result", () => 11)); + Assert.Equal("'Result' already has a value returning handler. Multiple return values are not supported.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientReturnHandlerCanMixWithNonReturnHandler() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + hubConnection.On("Result", () => 40); + hubConnection.On("Result", tcs.SetResult); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":40}", invokeMessage); + await tcs.Task.DefaultTimeout(); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientCanThrowErrorResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", int () => + { + throw new Exception("error from client"); + }); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"error\":\"error from client\"}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultIgnoresErrorWhenLastHandlerSuccessful() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + hubConnection.On("Result", () => + { + throw new Exception("error from client"); + }); + + hubConnection.On("Result", () => 20); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":20}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultReturnsErrorIfNoHandlerFromClient() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"error\":\"Client didn't provide a result.\"}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultReturnsErrorIfNoResultFromClient() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + // No result provided + hubConnection.On("Result", () => { }); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"error\":\"Client didn't provide a result.\"}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task ClientResultCanReturnNullResult() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + // No result provided + hubConnection.On("Result", object () => null); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout(); + + Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":null}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } } } diff --git a/src/SignalR/clients/ts/FunctionalTests/Startup.cs b/src/SignalR/clients/ts/FunctionalTests/Startup.cs index 509c7378da60..ef67be4b91a6 100644 --- a/src/SignalR/clients/ts/FunctionalTests/Startup.cs +++ b/src/SignalR/clients/ts/FunctionalTests/Startup.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Globalization; using System.IdentityModel.Tokens.Jwt; using System.Reflection; using System.Security.Claims; @@ -223,6 +224,19 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILogger< return context.Response.WriteAsync(GenerateJwtToken()); }); + endpoints.MapGet("/clientresult/{id}", async (IHubContext hubContext, string id) => + { + try + { + var result = await hubContext.Clients.Single(id).InvokeAsync("Result"); + return result.ToString(CultureInfo.InvariantCulture); + } + catch (Exception ex) + { + return ex.Message; + } + }); + endpoints.MapGet("/deployment", context => { var attributes = Assembly.GetAssembly(typeof(Startup)).GetCustomAttributes(); diff --git a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index c463fb7096b6..2b9d6efacbd9 100644 --- a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -513,6 +513,60 @@ describe("hubConnection", () => { await hubConnection.stop(); } }); + + it("can return result to server", async () => { + const hubConnection = getConnectionBuilder(transportType, undefined, { httpClient }) + .withHubProtocol(protocol) + .build(); + + hubConnection.on("Result", () => { + return 10; + }); + + await hubConnection.start(); + + const response = await httpClient.get(ENDPOINT_BASE_URL + `/clientresult/${hubConnection.connectionId}`); + + expect(response.content).toEqual("10"); + + await hubConnection.stop(); + }); + + it("can throw result to server", async () => { + const hubConnection = getConnectionBuilder(transportType, undefined, { httpClient }) + .withHubProtocol(protocol) + .build(); + + hubConnection.on("Result", () => { + throw new Error("from callback"); + }); + + try { + await hubConnection.start(); + + const response = await httpClient.get(ENDPOINT_BASE_URL + `/clientresult/${hubConnection.connectionId}`); + + expect(response.content).toEqual("Error: from callback"); + } finally { + await hubConnection.stop(); + } + }); + + it("returns result error to server when no result given", async () => { + const hubConnection = getConnectionBuilder(transportType, undefined, { httpClient }) + .withHubProtocol(protocol) + .build(); + + try { + await hubConnection.start(); + + const response = await httpClient.get(ENDPOINT_BASE_URL + `/clientresult/${hubConnection.connectionId}`); + + expect(response.content).toEqual("Client didn't provide a result."); + } finally { + await hubConnection.stop(); + } + }); }); }); diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 0338521956c5..67c32aef1fa1 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -39,7 +39,7 @@ export class HubConnection { private _protocol: IHubProtocol; private _handshakeProtocol: HandshakeProtocol; private _callbacks: { [invocationId: string]: (invocationEvent: StreamItemMessage | CompletionMessage | null, error?: Error) => void }; - private _methods: { [name: string]: ((...args: any[]) => void)[] }; + private _methods: { [name: string]: (((...args: any[]) => void) | ((...args: any[]) => any))[] }; private _invocationId: number; private _closedCallbacks: ((error?: Error) => void)[]; @@ -443,6 +443,7 @@ export class HubConnection { * @param {string} methodName The name of the hub method to define. * @param {Function} newMethod The handler that will be raised when the hub method is invoked. */ + public on(methodName: string, newMethod: (...args: any[]) => any): void public on(methodName: string, newMethod: (...args: any[]) => void): void { if (!methodName || !newMethod) { return; @@ -546,6 +547,7 @@ export class HubConnection { for (const message of messages) { switch (message.type) { case MessageType.Invocation: + // eslint-disable-next-line @typescript-eslint/no-floating-promises this._invokeClientMethod(message); break; case MessageType.StreamItem: @@ -672,26 +674,62 @@ export class HubConnection { this.connection.stop(new Error("Server timeout elapsed without receiving a message from the server.")); } - private _invokeClientMethod(invocationMessage: InvocationMessage) { - const methods = this._methods[invocationMessage.target.toLowerCase()]; - if (methods) { - const methodsCopy = methods.slice(); - try { - methodsCopy.forEach((m) => m.apply(this, invocationMessage.arguments)); - } catch (e) { - this._logger.log(LogLevel.Error, `A callback for the method ${invocationMessage.target.toLowerCase()} threw error '${e}'.`); - } + private async _invokeClientMethod(invocationMessage: InvocationMessage) { + const methodName = invocationMessage.target.toLowerCase(); + const methods = this._methods[methodName]; + if (!methods) { + this._logger.log(LogLevel.Warning, `No client method with the name '${methodName}' found.`); + // No handlers provided by client but the server is expecting a response still, so we send an error if (invocationMessage.invocationId) { - // This is not supported in v1. So we return an error to avoid blocking the server waiting for the response. - const message = "Server requested a response, which is not supported in this version of the client."; - this._logger.log(LogLevel.Error, message); + this._logger.log(LogLevel.Warning, `No result given for '${methodName}' method and invocation ID '${invocationMessage.invocationId}'.`); + await this._sendWithProtocol(this._createCompletionMessage(invocationMessage.invocationId, "Client didn't provide a result.", null)); + } + return; + } + + // Avoid issues with handlers removing themselves thus modifying the list while iterating through it + const methodsCopy = methods.slice(); - // We don't want to wait on the stop itself. - this._stopPromise = this._stopInternal(new Error(message)); + // Server expects a response + const expectsResponse = invocationMessage.invocationId ? true : false; + // We preserve the last result or exception but still call all handlers + let res; + let exception; + let completionMessage; + for (const m of methodsCopy) { + try { + const prevRes = res; + res = await m.apply(this, invocationMessage.arguments); + if (expectsResponse && res && prevRes) { + this._logger.log(LogLevel.Error, `Multiple results provided for '${methodName}'. Sending error to server.`); + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, `Client provided multiple results.`, null); + } + // Ignore exception if we got a result after, the exception will be logged + exception = undefined; + } catch (e) { + exception = e; + this._logger.log(LogLevel.Error, `A callback for the method '${methodName}' threw error '${e}'.`); } + } + if (completionMessage) { + await this._sendWithProtocol(completionMessage); + } else if (expectsResponse) { + // If there is an exception that means either no result was given or a handler after a result threw + if (exception) { + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, `${exception}`, null); + } else if (res !== undefined) { + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, null, res); + } else { + this._logger.log(LogLevel.Warning, `No result given for '${methodName}' method and invocation ID '${invocationMessage.invocationId}'.`); + // Client didn't provide a result or throw from a handler, server expects a response so we send an error + completionMessage = this._createCompletionMessage(invocationMessage.invocationId!, "Client didn't provide a result.", null); + } + await this._sendWithProtocol(completionMessage); } else { - this._logger.log(LogLevel.Warning, `No client method with the name '${invocationMessage.target}' found.`); + if (res) { + this._logger.log(LogLevel.Error, `Result given for '${methodName}' method but server is not expecting a result.`); + } } } diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index 18b5a6de3c09..4b8614b5ed0c 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -650,7 +650,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); - expect(warnings).toEqual(["No client method with the name 'message' found."]); + expect(warnings).toEqual(["No client method with the name 'message' found.", + "No result given for 'message' method and invocation ID '0'."]); } finally { await hubConnection.stop(); } @@ -665,8 +666,12 @@ describe("HubConnection", () => { await hubConnection.start(); let count = 0; + const p = new PromiseSource(); const handler = () => { count++; }; - const secondHandler = () => { count++; }; + const secondHandler = () => { + count++; + p.resolve(); + }; hubConnection.on("inc", handler); hubConnection.on("inc", secondHandler); @@ -677,6 +682,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; hubConnection.off("inc"); connection.receive({ @@ -701,8 +707,12 @@ describe("HubConnection", () => { await hubConnection.start(); let count = 0; + let p = new PromiseSource(); const handler = () => { count++; }; - const secondHandler = () => { count++; }; + const secondHandler = () => { + count++; + p.resolve(); + }; hubConnection.on("inc", handler); hubConnection.on("inc", secondHandler); @@ -713,6 +723,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; + p = new PromiseSource(); hubConnection.off("inc", handler); connection.receive({ @@ -722,6 +734,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; + expect(count).toBe(3); } finally { await hubConnection.stop(); @@ -755,7 +769,7 @@ describe("HubConnection", () => { }); }); - it("callback invoked when servers invokes a method on the client", async () => { + it("callback invoked when server invokes a method on the client", async () => { await VerifyLogger.run(async (logger) => { const connection = new TestConnection(); const hubConnection = createHubConnection(connection, logger); @@ -763,7 +777,11 @@ describe("HubConnection", () => { await hubConnection.start(); let value = ""; - hubConnection.on("message", (v) => value = v); + const p = new PromiseSource(); + hubConnection.on("message", (v) => { + value = v; + p.resolve(); + }); connection.receive({ arguments: ["test"], @@ -772,6 +790,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; expect(value).toBe("test"); } finally { await hubConnection.stop(); @@ -890,8 +909,12 @@ describe("HubConnection", () => { let numInvocations1 = 0; let numInvocations2 = 0; + const p = new PromiseSource(); hubConnection.on("message", () => numInvocations1++); - hubConnection.on("message", () => numInvocations2++); + hubConnection.on("message", () => { + numInvocations2++; + p.resolve(); + }); connection.receive({ arguments: [], @@ -900,6 +923,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; expect(numInvocations1).toBe(1); expect(numInvocations2).toBe(1); } finally { @@ -957,7 +981,11 @@ describe("HubConnection", () => { hubConnection.off(eventToTrack, callback1); numInvocations1++; } - const callback2 = () => numInvocations2++; + let p = new PromiseSource(); + const callback2 = () => { + numInvocations2++; + p.resolve(); + }; hubConnection.on(eventToTrack, callback1); hubConnection.on(eventToTrack, callback2); @@ -969,6 +997,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; + p = new PromiseSource(); expect(numInvocations1).toBe(1); expect(numInvocations2).toBe(1); @@ -979,6 +1009,7 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); + await p; expect(numInvocations1).toBe(1); expect(numInvocations2).toBe(2); } @@ -1035,7 +1066,8 @@ describe("HubConnection", () => { type: MessageType.Invocation, }); - expect(warnings).toEqual(["No client method with the name 'message' found."]); + expect(warnings).toEqual(["No client method with the name 'message' found.", + "No result given for 'message' method and invocation ID '0'."]); hubConnection.off(null!, undefined!); hubConnection.off(undefined!, null!); @@ -1048,6 +1080,318 @@ describe("HubConnection", () => { } }); }); + + it("can return result from callback", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 10); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(10); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("can return null result from callback", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => null); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toBeNull(); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("can return task result from callback", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + const p = new PromiseSource(); + hubConnection.on("message", () => p); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + p.resolve(13); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(13); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("can throw from callback when expecting result", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => { throw new Error("from callback"); }); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Error: from callback"); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "A callback for the method 'message' threw error 'Error: from callback'."); + }); + + it("multiple results sends error", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 3); + hubConnection.on("message", () => 4); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual('Client provided multiple results.'); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "Multiple results provided for 'message'. Sending error to server."); + }); + + it("multiple result handlers error from last one sent", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 3); + hubConnection.on("message", () => { throw new Error("from callback"); }); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Error: from callback"); + expect(connection.parsedSentData[1].result).toBeUndefined(); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "A callback for the method 'message' threw error 'Error: from callback'."); + }); + + it("multiple result handlers ignore error if last one has result", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => { throw new Error("from callback"); }); + hubConnection.on("message", () => 3); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].result).toEqual(3); + expect(connection.parsedSentData[1].error).toBeUndefined(); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }, "A callback for the method 'message' threw error 'Error: from callback'."); + }); + + it("sends completion error if return result expected but not returned", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => {}); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Client didn't provide a result."); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("sends completion error if return result expected but no handlers", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + connection.receive({ + arguments: [], + invocationId: "1", + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(2); + expect(connection.parsedSentData[1].type).toEqual(3); + expect(connection.parsedSentData[1].error).toEqual("Client didn't provide a result."); + expect(connection.parsedSentData[1].invocationId).toEqual("1"); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("logs error if return result not expected", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + hubConnection.on("message", () => 13); + + connection.receive({ + arguments: [], + invocationId: undefined, + nonblocking: true, + target: "message", + type: MessageType.Invocation, + }); + + // nothing to wait on and the code is all synchronous, but because of how JS and async works we need to trigger + // async here to guarantee the sent message is written + await delayUntil(1); + + expect(connection.parsedSentData.length).toEqual(1); + } finally { + await hubConnection.stop(); + } + }, "Result given for 'message' method but server is not expecting a result."); + }); }); describe("stream", () => { diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index f35cba60c611..bed3c046ffc5 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -200,8 +200,6 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) { hasResult = true; - reader.CheckRead(); - if (string.IsNullOrEmpty(invocationId)) { // If we don't have an invocation id then we need to value copy the reader so we can parse it later @@ -213,7 +211,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) { // If we have an invocation id already we can parse the end result var returnType = binder.GetReturnType(invocationId); - result = BindType(ref reader, returnType); + result = BindType(ref reader, input, returnType); } } else if (reader.ValueTextEquals(ItemPropertyNameBytes.EncodedUtf8Bytes)) @@ -391,7 +389,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (hasResultToken) { var returnType = binder.GetReturnType(invocationId); - result = BindType(ref resultToken, returnType); + result = BindType(ref resultToken, input, returnType); } message = BindCompletionMessage(invocationId, error, result, hasResult); @@ -537,7 +535,22 @@ private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter wr } else { - JsonSerializer.Serialize(writer, message.Result, message.Result.GetType(), _payloadSerializerOptions); + if (message.Result is RawResult result) + { + if (result.RawSerializedData.IsSingleSegment) + { + writer.WriteRawValue(result.RawSerializedData.First.Span, skipInputValidation: true); + } + else + { + // https://github.com/dotnet/runtime/issues/68223 + writer.WriteRawValue(result.RawSerializedData.ToArray(), skipInputValidation: true); + } + } + else + { + JsonSerializer.Serialize(writer, message.Result, message.Result.GetType(), _payloadSerializerOptions); + } } } } @@ -724,6 +737,22 @@ private static HubMessage BindInvocationMessage(string? invocationId, string tar return new InvocationMessage(invocationId, target, arguments, streamIds); } + private object? BindType(ref Utf8JsonReader reader, ReadOnlySequence input, Type type) + { + if (type == typeof(RawResult)) + { + var start = reader.BytesConsumed; + reader.Skip(); + var end = reader.BytesConsumed; + var sequence = input.Slice(start, end - start); + // Review: Technically the sequence doesn't need to be copied to a new array in RawResult + // but in the future we could break this if we dispatched the CompletionMessage and the underlying Pipe read would be advanced + // instead we could try pooling in RawResult, but it would need release/dispose semantics + return new RawResult(sequence); + } + return BindType(ref reader, type); + } + private object? BindType(ref Utf8JsonReader reader, Type type) { return JsonSerializer.Deserialize(ref reader, type, _payloadSerializerOptions); diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs index b12baeb3b14c..e665b67aa629 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs @@ -162,7 +162,14 @@ private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, break; case NonVoidResult: var itemType = binder.GetReturnType(invocationId); - result = DeserializeObject(ref reader, itemType, "argument"); + if (itemType == typeof(RawResult)) + { + result = new RawResult(reader.ReadRaw()); + } + else + { + result = DeserializeObject(ref reader, itemType, "argument"); + } hasResult = true; break; case VoidResult: @@ -434,6 +441,10 @@ private void WriteArgument(object? argument, ref MessagePackWriter writer) { writer.WriteNil(); } + else if (argument is RawResult result) + { + writer.WriteRaw(result.RawSerializedData); + } else { Serialize(ref writer, argument.GetType(), argument); diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj b/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj index 51d392697c8d..2167aa456d47 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson.csproj @@ -14,6 +14,7 @@ + diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index 694831ae5e8c..51eb9310dd99 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -8,6 +8,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO; using System.Runtime.ExceptionServices; +using System.Text; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; @@ -214,7 +215,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) throw new JsonReaderException("Unexpected end when reading JSON"); } - result = PayloadSerializer.Deserialize(reader, returnType); + if (returnType == typeof(RawResult)) + { + var token = JToken.Load(reader); + result = GetRawResult(token); + } + else + { + result = PayloadSerializer.Deserialize(reader, returnType); + } } break; case ItemPropertyName: @@ -388,7 +397,14 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) if (resultToken != null) { var returnType = binder.GetReturnType(invocationId); - result = resultToken.ToObject(returnType, PayloadSerializer); + if (returnType == typeof(RawResult)) + { + result = GetRawResult(resultToken); + } + else + { + result = resultToken.ToObject(returnType, PayloadSerializer); + } } message = BindCompletionMessage(invocationId, error, result, hasResult); @@ -531,7 +547,18 @@ private void WriteCompletionMessage(CompletionMessage message, JsonTextWriter wr else if (message.HasResult) { writer.WritePropertyName(ResultPropertyName); - PayloadSerializer.Serialize(writer, message.Result); + if (message.Result is RawResult result) + { +#if NETCOREAPP2_1_OR_GREATER + writer.WriteRawValue(Encoding.UTF8.GetString(result.RawSerializedData)); +#else + writer.WriteRawValue(Encoding.UTF8.GetString(result.RawSerializedData.ToArray())); +#endif + } + else + { + PayloadSerializer.Serialize(writer, message.Result); + } } } @@ -823,6 +850,23 @@ private static HubMessage ApplyHeaders(HubMessage message, Dictionary(strm.ToArray())); + } + finally + { + MemoryBufferWriter.Return(strm); + } + } internal static JsonSerializerSettings CreateDefaultSerializerSettings() { return new JsonSerializerSettings { ContractResolver = new CamelCasePropertyNamesContractResolver() }; diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs new file mode 100644 index 000000000000..68be58269ebc --- /dev/null +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -0,0 +1,193 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +// Common type used by our HubLifetimeManager implementations to manage client results. +// Handles cancellation, cleanup, and completion, so any bugs or improvements can be made in a single place +internal class ClientResultsManager : IInvocationBinder +{ + private readonly ConcurrentDictionary Complete)> _pendingInvocations = new(); + + public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSourceWithCancellation(this, connectionId, invocationId, cancellationToken); + var result = _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, tcs, static (state, completionMessage) => + { + var tcs = (TaskCompletionSourceWithCancellation)state; + if (completionMessage.HasResult) + { + tcs.SetResult((T)completionMessage.Result); + } + else + { + tcs.SetException(new Exception(completionMessage.Error)); + } + } + )); + Debug.Assert(result); + + tcs.RegisterCancellation(); + + return tcs.Task; + } + + public void AddInvocation(string invocationId, (Type Type, string ConnectionId, object Tcs, Action Complete) invocationInfo) + { + var result = _pendingInvocations.TryAdd(invocationId, invocationInfo); + Debug.Assert(result); + } + + public void TryCompleteResult(string connectionId, CompletionMessage message) + { + if (_pendingInvocations.TryGetValue(message.InvocationId!, out var item)) + { + if (item.ConnectionId != connectionId) + { + throw new InvalidOperationException($"Connection ID '{connectionId}' is not valid for invocation ID '{message.InvocationId}'."); + } + + // if false the connection disconnected right after the above TryGetValue + // or someone else completed the invocation (likely a bad client) + // we'll ignore both cases + if (_pendingInvocations.Remove(message.InvocationId!, out _)) + { + item.Complete(item.Tcs, message); + } + } + else + { + // connection was disconnected or someone else completed the invocation + } + } + + public (Type Type, string ConnectionId, object Tcs, Action Completion)? RemoveInvocation(string invocationId) + { + _pendingInvocations.TryRemove(invocationId, out var item); + return item; + } + + public bool TryGetType(string invocationId, [NotNullWhen(true)] out Type? type) + { + if (_pendingInvocations.TryGetValue(invocationId, out var item)) + { + type = item.Type; + return true; + } + type = null; + return false; + } + + public Type GetReturnType(string invocationId) + { + if (TryGetType(invocationId, out var type)) + { + return type; + } + throw new InvalidOperationException($"Invocation ID '{invocationId}' is not associated with a pending client result."); + } + + // Unused, here to honor the IInvocationBinder interface but should never be called + public IReadOnlyList GetParameterTypes(string methodName) + { + throw new NotImplementedException(); + } + + // Unused, here to honor the IInvocationBinder interface but should never be called + public Type GetStreamItemType(string streamId) + { + throw new NotImplementedException(); + } + + // Custom TCS type to avoid the extra allocation that would be introduced if we managed the cancellation separately + // Also makes it easier to keep track of the CancellationTokenRegistration for disposal + internal sealed class TaskCompletionSourceWithCancellation : TaskCompletionSource + { + private readonly ClientResultsManager _clientResultsManager; + private readonly string _connectionId; + private readonly string _invocationId; + private readonly CancellationToken _token; + + private CancellationTokenRegistration _tokenRegistration; + + public TaskCompletionSourceWithCancellation(ClientResultsManager clientResultsManager, string connectionId, string invocationId, + CancellationToken cancellationToken) + : base(TaskCreationOptions.RunContinuationsAsynchronously) + { + _clientResultsManager = clientResultsManager; + _connectionId = connectionId; + _invocationId = invocationId; + _token = cancellationToken; + } + + // Needs to be called after adding the completion to the dictionary in order to avoid synchronous completions of the token registration + // not canceling when the dictionary hasn't been updated yet. + public void RegisterCancellation() + { + if (_token.CanBeCanceled) + { + _tokenRegistration = _token.UnsafeRegister(static o => + { + var tcs = (TaskCompletionSourceWithCancellation)o!; + tcs.SetCanceled(); + }, this); + } + } + + public new void SetCanceled() + { + // TODO: RedisHubLifetimeManager will want to notify the other server (if there is one) about the cancellation + // so it can clean up state and potentially forward that info to the connection + _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Canceled")); + } + + public new void SetResult(T result) + { + _tokenRegistration.Dispose(); + base.SetResult(result); + } + + public new void SetException(Exception exception) + { + _tokenRegistration.Dispose(); + base.SetException(exception); + } + +#pragma warning disable IDE0060 // Remove unused parameter + // Just making sure we don't accidentally call one of these without knowing + public static new void SetCanceled(CancellationToken cancellationToken) => Debug.Assert(false); + public static new void SetException(IEnumerable exceptions) => Debug.Assert(false); + public static new bool TrySetCanceled() + { + Debug.Assert(false); + return false; + } + public static new bool TrySetCanceled(CancellationToken cancellationToken) + { + Debug.Assert(false); + return false; + } + public static new bool TrySetException(IEnumerable exceptions) + { + Debug.Assert(false); + return false; + } + public static new bool TrySetException(Exception exception) + { + Debug.Assert(false); + return false; + } + public static new bool TrySetResult(T result) + { + Debug.Assert(false); + return false; + } +#pragma warning restore IDE0060 // Remove unused parameter + } +} diff --git a/src/SignalR/common/Shared/CreateLinkedToken.cs b/src/SignalR/common/Shared/CreateLinkedToken.cs new file mode 100644 index 000000000000..198bde588696 --- /dev/null +++ b/src/SignalR/common/Shared/CreateLinkedToken.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal static class CancellationTokenUtils +{ + // Similar to CreateLinkedTokenSource except it will not allocate a new internal LinkedCancellationTokenSource in the case where + // one of the tokens passed in isn't cancellable. + // Returns a disposable only when an actual LinkedTokenSource is created. + internal static IDisposable? CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken) + { + if (!token1.CanBeCanceled) + { + linkedToken = token2; + return null; + } + else if (!token2.CanBeCanceled) + { + linkedToken = token1; + return null; + } + else + { + var cts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); + linkedToken = cts.Token; + return cts; + } + } +} diff --git a/src/SignalR/common/Shared/MemoryBufferWriter.cs b/src/SignalR/common/Shared/MemoryBufferWriter.cs index ca4dc6407ad6..999c3143b4e6 100644 --- a/src/SignalR/common/Shared/MemoryBufferWriter.cs +++ b/src/SignalR/common/Shared/MemoryBufferWriter.cs @@ -326,6 +326,25 @@ public override void Write(ReadOnlySpan span) } #endif + public WrittenBuffers DetachAndReset() + { + _completedSegments ??= new List(); + + if (_currentSegment is not null) + { + _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); + } + + var written = new WrittenBuffers(_completedSegments, _bytesWritten); + + _currentSegment = null; + _completedSegments = null; + _bytesWritten = 0; + _position = 0; + + return written; + } + protected override void Dispose(bool disposing) { if (disposing) @@ -334,10 +353,37 @@ protected override void Dispose(bool disposing) } } + /// + /// Holds the written segments from a MemoryBufferWriter and is no longer attached to a MemoryBufferWriter. + /// You are now responsible for calling Dispose on this type to return the memory to the pool. + /// + internal readonly ref struct WrittenBuffers + { + public readonly List Segments; + private readonly int _bytesWritten; + + public WrittenBuffers(List segments, int bytesWritten) + { + Segments = segments; + _bytesWritten = bytesWritten; + } + + public int ByteLength => _bytesWritten; + + public void Dispose() + { + for (var i = 0; i < Segments.Count; i++) + { + Segments[i].Return(); + } + Segments.Clear(); + } + } + /// /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. /// - private readonly struct CompletedBuffer + internal readonly struct CompletedBuffer { public byte[] Buffer { get; } public int Length { get; } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs index 1645e7f3ed5e..440e431c6bb8 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs @@ -36,7 +36,7 @@ public class CompletionMessage : HubInvocationMessage public CompletionMessage(string invocationId, string? error, object? result, bool hasResult) : base(invocationId) { - if (error != null && result != null) + if (error is not null && hasResult) { throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both"); } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs new file mode 100644 index 000000000000..7431df26cef7 --- /dev/null +++ b/src/SignalR/common/SignalR.Common/src/Protocol/RawResult.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +/// +/// Type returned to implementations to let them know the object being deserialized should be +/// stored as raw serialized bytes in the format of the protocol being used. +/// +/// +/// In Json that would mean storing the byte representation of ascii {"prop":10} as an example. +/// +public sealed class RawResult +{ + /// + /// Stores the raw serialized bytes of a for forwarding to another server. + /// Will copy the passed in bytes to internal storage. + /// + /// The raw bytes from the client. + public RawResult(ReadOnlySequence rawBytes) + { + // Review: If we want to use an ArrayPool we would need some sort of release mechanism + RawSerializedData = new ReadOnlySequence(rawBytes.ToArray()); + } + + /// + /// The raw serialized bytes from the client. + /// + public ReadOnlySequence RawSerializedData { get; private set; } +} diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..76fd06c813c2 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI.Unshipped.txt @@ -1 +1,4 @@ #nullable enable +Microsoft.AspNetCore.SignalR.Protocol.RawResult +Microsoft.AspNetCore.SignalR.Protocol.RawResult.RawResult(System.Buffers.ReadOnlySequence rawBytes) -> void +Microsoft.AspNetCore.SignalR.Protocol.RawResult.RawSerializedData.get -> System.Buffers.ReadOnlySequence diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs index 66fd2de266e1..deccbb188e14 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs @@ -387,6 +387,67 @@ public void VerifyMessageSize(string testDataName) } } + public static IDictionary ClientResultData => new[] + { + new ClientResultTestData("SimpleResult", "{\"type\":3,\"invocationId\":\"1\",\"result\":45}", typeof(int), 45), + new ClientResultTestData("SimpleResult_InvocationIdLast", "{\"type\":3,\"result\":45,\"invocationId\":\"1\"}", typeof(int), 45), + new ClientResultTestData("MissingResult", "{\"type\":3,\"invocationId\":\"1\"}", typeof(int), null), + + new ClientResultTestData("ComplexResult", "{\"type\":3,\"invocationId\":\"1\",\"result\":{\"stringProp\":\"test\",\"doubleProp\":1.1,\"intProp\":0,\"dateTimeProp\":\"0001-01-01T00:00:00\",\"nullProp\":null,\"byteArrProp\":\"AgQG\"}}", typeof(CustomObject), + new CustomObject() + { + ByteArrProp = new byte[] { 2, 4, 6 }, + IntProp = default, + DoubleProp = 1.1, + StringProp = "test", + DateTimeProp = default + }), + new ClientResultTestData("ComplexResult_InvocationIdLast", "{\"type\":3,\"result\":{\"stringProp\":\"test\",\"doubleProp\":1.1,\"intProp\":0,\"dateTimeProp\":\"0001-01-01T00:00:00\",\"nullProp\":null,\"byteArrProp\":\"AgQG\"},\"invocationId\":\"1\"}", typeof(CustomObject), + new CustomObject() + { + ByteArrProp = new byte[] { 2, 4, 6 }, + IntProp = default, + DoubleProp = 1.1, + StringProp = "test", + DateTimeProp = default + }), + }.ToDictionary(t => t.Name); + + public static IEnumerable ClientResultDataNames => ClientResultData.Keys.Select(name => new object[] { name }); + + [Theory] + [MemberData(nameof(ClientResultDataNames))] + public void RawResultRoundTripsProperly(string testDataName) + { + var testData = ClientResultData[testDataName]; + + var binder = new TestBinder(null, typeof(RawResult)); + var input = Frame(testData.Message); + var data = new ReadOnlySequence(Encoding.UTF8.GetBytes(input)); + Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out var message)); + var completion = Assert.IsType(message); + + var writer = MemoryBufferWriter.Get(); + try + { + // WriteMessage should handle RawResult as Raw Json and write it properly + JsonHubProtocol.WriteMessage(completion, writer); + + // Now we check if the Raw Json was written properly and can be read using the expected type + binder = new TestBinder(null, testData.ResultType); + var written = writer.ToArray(); + data = new ReadOnlySequence(written); + Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out message)); + + completion = Assert.IsType(message); + Assert.Equal(testData.Result, completion.Result); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + public static string Frame(string input) { var data = Encoding.UTF8.GetBytes(input); @@ -436,4 +497,22 @@ public MessageSizeTestData(string name, HubMessage message, int size) public override string ToString() => Name; } + + public class ClientResultTestData + { + public string Name { get; } + public string Message { get; } + public Type ResultType { get; } + public object Result { get; } + + public ClientResultTestData(string name, string message, Type resultType, object result) + { + Name = name; + Message = message; + ResultType = resultType; + Result = result; + } + + public override string ToString() => Name; + } } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs index b1f6acab656b..978d729cc4d5 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -203,4 +203,67 @@ public void WriteMessages(string testDataName) TestWriteMessages(testData); } + + public static IDictionary ClientResultData => new[] + { + new ClientResultTestData("SimpleResult", "lQOAo3h5egMq", typeof(int), 42), + new ClientResultTestData("NullResult", "lQOAo3h5egPA", typeof(CustomObject), null), + + new ClientResultTestData("ComplexResult", "lQOAo3h5egOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM=", typeof(CustomObject), + new CustomObject()), + }.ToDictionary(t => t.Name); + + public static IEnumerable ClientResultDataNames => ClientResultData.Keys.Select(name => new object[] { name }); + + [Theory] + [MemberData(nameof(ClientResultDataNames))] + public void RawResultRoundTripsProperly(string testDataName) + { + var testData = ClientResultData[testDataName]; + var bytes = Convert.FromBase64String(testData.Message); + + var binder = new TestBinder(null, typeof(RawResult)); + var input = Frame(bytes); + var data = new ReadOnlySequence(input); + Assert.True(HubProtocol.TryParseMessage(ref data, binder, out var message)); + var completion = Assert.IsType(message); + + var writer = MemoryBufferWriter.Get(); + try + { + // WriteMessage should handle RawResult as Raw Json and write it properly + HubProtocol.WriteMessage(completion, writer); + + // Now we check if the Raw Json was written properly and can be read using the expected type + binder = new TestBinder(null, testData.ResultType); + var written = writer.ToArray(); + data = new ReadOnlySequence(written); + Assert.True(HubProtocol.TryParseMessage(ref data, binder, out message)); + + completion = Assert.IsType(message); + Assert.Equal(testData.Result, completion.Result); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + + public class ClientResultTestData + { + public string Name { get; } + public string Message { get; } + public Type ResultType { get; } + public object Result { get; } + + public ClientResultTestData(string name, string message, Type resultType, object result) + { + Name = name; + Message = message; + ResultType = resultType; + Result = result; + } + + public override string ToString() => Name; + } } diff --git a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 896088fe63c5..bdbf8443d752 100644 --- a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -29,13 +29,15 @@ public void GlobalSetup() var serviceScopeFactory = provider.GetService(); + var hubLifetimeManager = new DefaultHubLifetimeManager(NullLogger>.Instance); _dispatcher = new DefaultHubDispatcher( serviceScopeFactory, - new HubContext(new DefaultHubLifetimeManager(NullLogger>.Instance)), + new HubContext(hubLifetimeManager), enableDetailedErrors: false, disableImplicitFromServiceParameters: true, new Logger>(NullLoggerFactory.Instance), - hubFilters: null); + hubFilters: null, + hubLifetimeManager); var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); diff --git a/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs b/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs index 5f92a6dddb53..98c38dbd390b 100644 --- a/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs @@ -45,8 +45,8 @@ public void GlobalSetup() _writtenAck = RedisProtocol.WriteAck(42); _writtenGroupCommand = RedisProtocol.WriteGroupCommand(_groupCommand); _writtenInvocationNoExclusions = _protocol.WriteInvocation(_methodName, _args, null); - _writtenInvocationSmallExclusions = _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsSmall); - _writtenInvocationLargeExclusions = _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsLarge); + _writtenInvocationSmallExclusions = _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsSmall); + _writtenInvocationLargeExclusions = _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsLarge); } [Benchmark] @@ -70,13 +70,13 @@ public void WriteInvocationNoExclusions() [Benchmark] public void WriteInvocationSmallExclusions() { - _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsSmall); + _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsSmall); } [Benchmark] public void WriteInvocationLargeExclusions() { - _protocol.WriteInvocation(_methodName, _args, _excludedConnectionIdsLarge); + _protocol.WriteInvocation(_methodName, _args, excludedConnectionIds: _excludedConnectionIdsLarge); } [Benchmark] diff --git a/src/SignalR/samples/SignalRSamples/Program.cs b/src/SignalR/samples/SignalRSamples/Program.cs index 92e56f306369..3675e5c32b37 100644 --- a/src/SignalR/samples/SignalRSamples/Program.cs +++ b/src/SignalR/samples/SignalRSamples/Program.cs @@ -25,17 +25,18 @@ public static Task Main(string[] args) { factory.AddConfiguration(c.Configuration.GetSection("Logging")); factory.AddConsole(); + //factory.SetMinimumLevel(LogLevel.Trace); }) .UseKestrel(options => { // Default port - options.ListenLocalhost(5000); + options.ListenAnyIP(0); // Hub bound to TCP end point - options.Listen(IPAddress.Any, 9001, builder => - { - builder.UseHub(); - }); + //options.Listen(IPAddress.Any, 9001, builder => + //{ + // builder.UseHub(); + //}); }) .UseContentRoot(Directory.GetCurrentDirectory()) .UseIISIntegration() diff --git a/src/SignalR/server/Core/src/ClientProxyExtensions.cs b/src/SignalR/server/Core/src/ClientProxyExtensions.cs index 6b23cf59b3c2..7b2f0c8c8b8f 100644 --- a/src/SignalR/server/Core/src/ClientProxyExtensions.cs +++ b/src/SignalR/server/Core/src/ClientProxyExtensions.cs @@ -163,7 +163,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] @@ -185,7 +185,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The ninth argument. /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous invoke. @@ -208,7 +208,7 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec /// The fifth argument. /// The sixth argument. /// The seventh argument. - /// The eigth argument. + /// The eighth argument. /// The ninth argument. /// The tenth argument. /// The token to monitor for cancellation requests. The default value is . @@ -218,4 +218,202 @@ public static Task SendAsync(this IClientProxy clientProxy, string method, objec { return clientProxy.SendCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 }, cancellationToken); } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, Array.Empty(), cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The eighth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The eighth argument. + /// The ninth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9 }, cancellationToken); + } + + /// + /// Invokes a method on the connection represented by the instance and waits for a response. + /// + /// The . + /// The name of the method to invoke. + /// The first argument. + /// The second argument. + /// The third argument. + /// The fourth argument. + /// The fifth argument. + /// The sixth argument. + /// The seventh argument. + /// The eighth argument. + /// The ninth argument. + /// The tenth argument. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple overloads with optional parameters", Justification = "Required to maintain compatibility")] + public static Task InvokeAsync(this ISingleClientProxy clientProxy, string method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, CancellationToken cancellationToken = default) + { + return clientProxy.InvokeCoreAsync(method, new[] { arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 }, cancellationToken); + } } diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 18e4891a16b9..08b4d0f632e8 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; @@ -17,6 +19,8 @@ public class DefaultHubLifetimeManager : HubLifetimeManager where TH private readonly HubConnectionStore _connections = new HubConnectionStore(); private readonly HubGroupList _groups = new HubGroupList(); private readonly ILogger _logger; + private readonly ClientResultsManager _clientResultsManager = new(); + private ulong _lastInvocationId; /// /// Initializes a new instance of the class. @@ -294,6 +298,7 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) { _connections.Remove(connection); _groups.RemoveDisconnectedConnection(connection.ConnectionId); + return Task.CompletedTask; } @@ -314,4 +319,71 @@ public override Task SendUsersAsync(IReadOnlyList userIds, string method { return SendToAllConnections(methodName, args, (connection, state) => ((IReadOnlyList)state!).Contains(connection.UserIdentifier), userIds, cancellationToken); } + + /// + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + var connection = _connections[connectionId]; + + if (connection == null) + { + throw new IOException($"Connection '{connectionId}' does not exist."); + } + + var invocationId = Interlocked.Increment(ref _lastInvocationId).ToString(NumberFormatInfo.InvariantInfo); + using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, + connection.ConnectionAborted, out var linkedToken); + var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); + + try + { + // We're sending to a single connection + // Write message directly to connection without caching it in memory + var message = new InvocationMessage(invocationId, methodName, args); + + await connection.WriteAsync(message, cancellationToken); + } + catch + { + _clientResultsManager.RemoveInvocation(invocationId); + throw; + } + + try + { + return await task; + } + catch + { + // ConnectionAborted will trigger a generic "Canceled" exception from the task, let's convert it into a more specific message. + if (connection.ConnectionAborted.IsCancellationRequested) + { + throw new IOException($"Connection '{connectionId}' disconnected."); + } + throw; + } + } + + /// + public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + _clientResultsManager.TryCompleteResult(connectionId, result); + return Task.CompletedTask; + } + + /// + public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) + { + if (_clientResultsManager.TryGetType(invocationId, out type)) + { + return true; + } + type = null; + return false; + } } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index dfdcbe3ce822..3bb5566e6a7a 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -96,7 +96,8 @@ IServiceScopeFactory serviceScopeFactory _enableDetailedErrors, disableImplicitFromServiceParameters, new Logger>(loggerFactory), - hubFilters); + hubFilters, + lifetimeManager); } /// @@ -240,7 +241,7 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection) var protocol = connection.Protocol; connection.BeginClientTimeout(); - var binder = new HubConnectionBinder(_dispatcher, connection); + var binder = new HubConnectionBinder(_dispatcher, _lifetimeManager, connection); while (true) { diff --git a/src/SignalR/server/Core/src/HubLifetimeManager.cs b/src/SignalR/server/Core/src/HubLifetimeManager.cs index 140e0aea1672..14a294190876 100644 --- a/src/SignalR/server/Core/src/HubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/HubLifetimeManager.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; +using Microsoft.AspNetCore.SignalR.Protocol; + namespace Microsoft.AspNetCore.SignalR; /// @@ -131,4 +134,41 @@ public abstract class HubLifetimeManager where THub : Hub /// The token to monitor for cancellation requests. The default value is . /// A that represents the asynchronous remove. public abstract Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default); + + /// + /// Sends an invocation message to the specified connection and waits for a response. + /// + /// The type of the response expected. + /// The connection ID. + /// The invocation method name. + /// The invocation arguments. + /// The token to monitor for cancellation requests. The default value is . + /// The response from the connection. + public virtual Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException($"{GetType().Name} does not support client return values."); + } + + /// + /// Sets the connection result for an in progress call. + /// + /// The connection ID. + /// The result from the connection. + /// A that represents the result being set or being forwarded to another server. + public virtual Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + throw new NotImplementedException($"{GetType().Name} does not support client return values."); + } + + /// + /// Tells implementations what the expected type from a connection result is. + /// + /// The ID of the in progress invocation. + /// The type the connection is expected to send. Or if the result is intended for another server. + /// + public virtual bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) + { + type = null; + return false; + } } diff --git a/src/SignalR/server/Core/src/IHubCallerClients.cs b/src/SignalR/server/Core/src/IHubCallerClients.cs index 0a9fdaa6b404..82968013d23f 100644 --- a/src/SignalR/server/Core/src/IHubCallerClients.cs +++ b/src/SignalR/server/Core/src/IHubCallerClients.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR; @@ -6,4 +6,12 @@ namespace Microsoft.AspNetCore.SignalR; /// /// A clients caller abstraction for a hub. /// -public interface IHubCallerClients : IHubCallerClients { } +public interface IHubCallerClients : IHubCallerClients +{ + /// + /// Gets a proxy that can be used to invoke methods on a single client connected to the hub and receive results. + /// + /// The connection ID. + /// A client caller. + new ISingleClientProxy Single(string connectionId) => throw new NotImplementedException(); +} diff --git a/src/SignalR/server/Core/src/IHubClients.cs b/src/SignalR/server/Core/src/IHubClients.cs index 1f4299a83d2e..3646d4bc8258 100644 --- a/src/SignalR/server/Core/src/IHubClients.cs +++ b/src/SignalR/server/Core/src/IHubClients.cs @@ -6,4 +6,12 @@ namespace Microsoft.AspNetCore.SignalR; /// /// An abstraction that provides access to client connections. /// -public interface IHubClients : IHubClients { } +public interface IHubClients : IHubClients +{ + /// + /// Gets a proxy that can be used to invoke methods on a single client connected to the hub and receive results. + /// + /// The connection ID. + /// A client caller. + new ISingleClientProxy Single(string connectionId) => throw new NotImplementedException(); +} diff --git a/src/SignalR/server/Core/src/IHubClients`T.cs b/src/SignalR/server/Core/src/IHubClients`T.cs index 9f792d51a224..0dee6f33b19a 100644 --- a/src/SignalR/server/Core/src/IHubClients`T.cs +++ b/src/SignalR/server/Core/src/IHubClients`T.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR; @@ -9,6 +9,13 @@ namespace Microsoft.AspNetCore.SignalR; /// The client invoker type. public interface IHubClients { + /// + /// Gets a that can be used to invoke methods on a single client connected to the hub and receive results. + /// + /// The connection ID. + /// A client caller. + T Single(string connectionId) => throw new NotImplementedException(); + /// /// Gets a that can be used to invoke methods on all clients connected to the hub. /// @@ -72,4 +79,3 @@ public interface IHubClients /// A client caller. T Users(IReadOnlyList userIds); } - diff --git a/src/SignalR/server/Core/src/ISingleClientProxy.cs b/src/SignalR/server/Core/src/ISingleClientProxy.cs new file mode 100644 index 000000000000..f400b13e6acc --- /dev/null +++ b/src/SignalR/server/Core/src/ISingleClientProxy.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.SignalR; + +/// +/// A proxy abstraction for invoking hub methods on the client and getting a result. +/// +public interface ISingleClientProxy : IClientProxy +{ + // client proxy method is called InvokeCoreAsync instead of InvokeAsync so that arrays of references + // like string[], e.g. InvokeAsync(string, string[]), do not choose InvokeAsync(string, object[]) + // over InvokeAsync(string, object) overload + + /// + /// Invokes a method on the connection represented by the instance and waits for a result. + /// + /// + /// Name of the method to invoke. + /// A collection of arguments to pass to the client. + /// The token to monitor for cancellation requests. The default value is . + /// A that represents the asynchronous invoke and wait for a client result. + Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default); +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index b3f51f7fe4bb..f4f299cd80b6 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -26,14 +26,16 @@ internal partial class DefaultHubDispatcher : HubDispatcher where TH private readonly Func>? _invokeMiddleware; private readonly Func? _onConnectedMiddleware; private readonly Func? _onDisconnectedMiddleware; + private readonly HubLifetimeManager _hubLifetimeManager; public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext hubContext, bool enableDetailedErrors, - bool disableImplicitFromServiceParameters, ILogger> logger, List? hubFilters) + bool disableImplicitFromServiceParameters, ILogger> logger, List? hubFilters, HubLifetimeManager lifetimeManager) { _serviceScopeFactory = serviceScopeFactory; _hubContext = hubContext; _enableDetailedErrors = enableDetailedErrors; _logger = logger; + _hubLifetimeManager = lifetimeManager; DiscoverHubMethods(disableImplicitFromServiceParameters); var count = hubFilters?.Count ?? 0; @@ -70,7 +72,7 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { await using var scope = _serviceScopeFactory.CreateAsyncScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); + connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit is not null); var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); @@ -165,16 +167,21 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe case StreamItemMessage streamItem: return ProcessStreamItem(connection, streamItem); - case CompletionMessage streamCompleteMessage: + case CompletionMessage completionMessage: // closes channels, removes from Lookup dict // user's method can see the channel is complete and begin wrapping up - if (connection.StreamTracker.TryComplete(streamCompleteMessage)) + if (connection.StreamTracker.TryComplete(completionMessage)) { - Log.CompletingStream(_logger, streamCompleteMessage); + Log.CompletingStream(_logger, completionMessage); + } + // InvocationId is always required on CompletionMessage, it's nullable because of the base type + else if (_hubLifetimeManager.TryGetReturnType(completionMessage.InvocationId!, out _)) + { + return _hubLifetimeManager.SetConnectionResultAsync(connection.ConnectionId, completionMessage); } else { - Log.UnexpectedStreamCompletion(_logger); + Log.UnexpectedCompletion(_logger, completionMessage.InvocationId!); } break; @@ -247,7 +254,7 @@ private Task ProcessInvocation(HubConnectionContext connection, bool isStreamCall = descriptor.StreamingParameters != null; if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) { - return connection.ActiveInvocationLimit.RunAsync(state => + return connection.ActiveInvocationLimit.RunAsync(static state => { var (dispatcher, descriptor, connection, invocationMessage) = state; return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs index 4905d4e59f49..f80a970935a3 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs @@ -92,8 +92,7 @@ public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessa [LoggerMessage(19, LogLevel.Debug, "Stream '{StreamId}' closed with error '{Error}'.", EventName = "ClosingStreamWithBindingError")] private static partial void ClosingStreamWithBindingError(ILogger logger, string? streamId, string? error); - [LoggerMessage(20, LogLevel.Debug, "StreamCompletionMessage received unexpectedly.", EventName = "UnexpectedStreamCompletion")] - public static partial void UnexpectedStreamCompletion(ILogger logger); + // Retired [20]UnexpectedStreamCompletion, replaced with more generic [24]UnexpectedCompletion [LoggerMessage(21, LogLevel.Debug, "StreamItemMessage received unexpectedly.", EventName = "UnexpectedStreamItem")] public static partial void UnexpectedStreamItem(ILogger logger); @@ -103,4 +102,7 @@ public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessa [LoggerMessage(23, LogLevel.Debug, "Invocation ID '{InvocationId}' is already in use.", EventName = "InvocationIdInUse")] public static partial void InvocationIdInUse(ILogger logger, string InvocationId); + + [LoggerMessage(24, LogLevel.Debug, "CompletionMessage for invocation ID '{InvocationId}' received unexpectedly.", EventName = "UnexpectedCompletion")] + public static partial void UnexpectedCompletion(ILogger logger, string invocationId); } diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index a07c83c373af..6d3a31332f30 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -8,12 +8,43 @@ internal class HubCallerClients : IHubCallerClients private readonly string _connectionId; private readonly IHubClients _hubClients; private readonly string[] _currentConnectionId; + private readonly bool _parallelEnabled; - public HubCallerClients(IHubClients hubClients, string connectionId) + public HubCallerClients(IHubClients hubClients, string connectionId, bool parallelEnabled) { _connectionId = connectionId; _hubClients = hubClients; _currentConnectionId = new[] { _connectionId }; + _parallelEnabled = parallelEnabled; + } + + private class NotParallelSingleClientProxy : ISingleClientProxy + { + private readonly ISingleClientProxy _proxy; + + public NotParallelSingleClientProxy(ISingleClientProxy hubClients) + { + _proxy = hubClients; + } + + public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException("Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1."); + } + + public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _proxy.SendCoreAsync(method, args, cancellationToken); + } + } + + public ISingleClientProxy Single(string connectionId) + { + if (!_parallelEnabled) + { + return new NotParallelSingleClientProxy(_hubClients.Single(connectionId)); + } + return _hubClients.Single(connectionId); } public IClientProxy Caller => _hubClients.Client(_connectionId); diff --git a/src/SignalR/server/Core/src/Internal/HubClients.cs b/src/SignalR/server/Core/src/Internal/HubClients.cs index 0920e40ad6cf..203b3fe4ff92 100644 --- a/src/SignalR/server/Core/src/Internal/HubClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubClients.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR.Internal; @@ -13,6 +13,11 @@ public HubClients(HubLifetimeManager lifetimeManager) All = new AllClientProxy(_lifetimeManager); } + public ISingleClientProxy Single(string connectionId) + { + return new SingleClientProxyWithInvoke(_lifetimeManager, connectionId); + } + public IClientProxy All { get; } public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) diff --git a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs index 2f69c7827329..086624a81d23 100644 --- a/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs +++ b/src/SignalR/server/Core/src/Internal/HubConnectionBinder.cs @@ -7,11 +7,13 @@ internal class HubConnectionBinder : IInvocationBinder where THub : Hub { private readonly HubDispatcher _dispatcher; private readonly HubConnectionContext _connection; + private readonly HubLifetimeManager _hubLifetimeManager; - public HubConnectionBinder(HubDispatcher dispatcher, HubConnectionContext connection) + public HubConnectionBinder(HubDispatcher dispatcher, HubLifetimeManager lifetimeManager, HubConnectionContext connection) { _dispatcher = dispatcher; _connection = connection; + _hubLifetimeManager = lifetimeManager; } public IReadOnlyList GetParameterTypes(string methodName) @@ -21,7 +23,11 @@ public IReadOnlyList GetParameterTypes(string methodName) public Type GetReturnType(string invocationId) { - return typeof(object); + if (_hubLifetimeManager.TryGetReturnType(invocationId, out var type)) + { + return type; + } + throw new InvalidOperationException($"Unknown invocation ID '{invocationId}'."); } public Type GetStreamItemType(string streamId) diff --git a/src/SignalR/server/Core/src/Internal/Proxies.cs b/src/SignalR/server/Core/src/Internal/Proxies.cs index 951b8c7c8941..7dead0caf95a 100644 --- a/src/SignalR/server/Core/src/Internal/Proxies.cs +++ b/src/SignalR/server/Core/src/Internal/Proxies.cs @@ -155,3 +155,25 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance return _lifetimeManager.SendConnectionsAsync(_connectionIds, method, args, cancellationToken); } } + +internal class SingleClientProxyWithInvoke : ISingleClientProxy where THub : Hub +{ + private readonly string _connectionId; + private readonly HubLifetimeManager _lifetimeManager; + + public SingleClientProxyWithInvoke(HubLifetimeManager lifetimeManager, string connectionId) + { + _lifetimeManager = lifetimeManager; + _connectionId = connectionId; + } + + public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _lifetimeManager.SendConnectionAsync(_connectionId, method, args, cancellationToken); + } + + public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _lifetimeManager.InvokeConnectionAsync(_connectionId, method, args ?? Array.Empty(), cancellationToken); + } +} diff --git a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs index 182153f31f47..69be30e9c761 100644 --- a/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs +++ b/src/SignalR/server/Core/src/Internal/TypedClientBuilder.cs @@ -113,12 +113,25 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo var parameters = interfaceMethodInfo.GetParameters(); var paramTypes = parameters.Select(param => param.ParameterType).ToArray(); + var returnType = interfaceMethodInfo.ReturnType; + bool isInvoke = returnType != typeof(Task); var methodBuilder = type.DefineMethod(interfaceMethodInfo.Name, methodAttributes); - var invokeMethod = typeof(IClientProxy).GetMethod( - nameof(IClientProxy.SendCoreAsync), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, - new[] { typeof(string), typeof(object[]), typeof(CancellationToken) }, null)!; + MethodInfo invokeMethod; + if (isInvoke) + { + invokeMethod = typeof(ISingleClientProxy).GetMethod( + nameof(ISingleClientProxy.InvokeCoreAsync), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, + new[] { typeof(string), typeof(object[]), typeof(CancellationToken) }, null)! + .MakeGenericMethod(returnType.GenericTypeArguments); + } + else + { + invokeMethod = typeof(IClientProxy).GetMethod( + nameof(IClientProxy.SendCoreAsync), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, + new[] { typeof(string), typeof(object[]), typeof(CancellationToken) }, null)!; + } methodBuilder.SetReturnType(interfaceMethodInfo.ReturnType); methodBuilder.SetParameters(paramTypes); @@ -153,6 +166,30 @@ private static void BuildMethod(TypeBuilder type, MethodInfo interfaceMethodInfo generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Ldfld, proxyField); + var isTypeLabel = generator.DefineLabel(); + if (isInvoke) + { + var singleClientProxyType = typeof(ISingleClientProxy); + /* + if (_proxy is ISingleClientProxy singleClientProxy) + { + return ((ISingleClientProxy)_proxy).InvokeAsync(methodName, args, cancellationToken); + } + throw new InvalidOperationException("InvokeAsync only works with Single clients."); + */ + generator.Emit(OpCodes.Isinst, singleClientProxyType); + generator.Emit(OpCodes.Brtrue_S, isTypeLabel); + + generator.Emit(OpCodes.Ldstr, "InvokeAsync only works with Single clients."); + generator.Emit(OpCodes.Newobj, typeof(InvalidOperationException).GetConstructor(new Type[] { typeof(string) })!); + generator.Emit(OpCodes.Throw); + + generator.MarkLabel(isTypeLabel); + generator.Emit(OpCodes.Ldarg_0); + generator.Emit(OpCodes.Ldfld, proxyField); + generator.Emit(OpCodes.Castclass, singleClientProxyType); + } + // The first argument to IClientProxy.SendCoreAsync is this method's name generator.Emit(OpCodes.Ldstr, methodName); @@ -232,10 +269,10 @@ private static void VerifyInterface(Type interfaceType) private static void VerifyMethod(MethodInfo interfaceMethod) { - if (interfaceMethod.ReturnType != typeof(Task)) + if (!typeof(Task).IsAssignableFrom(interfaceMethod.ReturnType)) { throw new InvalidOperationException( - $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. All client proxy methods must return '{typeof(Task).FullName}'."); + $"Cannot generate proxy implementation for '{typeof(T).FullName}.{interfaceMethod.Name}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'."); } foreach (var parameter in interfaceMethod.GetParameters()) diff --git a/src/SignalR/server/Core/src/Internal/TypedHubClients.cs b/src/SignalR/server/Core/src/Internal/TypedHubClients.cs index 5bc8ce693ff7..0bffb3d83abb 100644 --- a/src/SignalR/server/Core/src/Internal/TypedHubClients.cs +++ b/src/SignalR/server/Core/src/Internal/TypedHubClients.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.AspNetCore.SignalR.Internal; @@ -12,6 +12,8 @@ public TypedHubClients(IHubCallerClients dynamicContext) _hubClients = dynamicContext; } + public T Single(string connectionId) => TypedClientBuilder.Build(_hubClients.Single(connectionId)); + public T All => TypedClientBuilder.Build(_hubClients.All); public T Caller => TypedClientBuilder.Build(_hubClients.Caller); diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index c5ae1cb3718f..fd607f3d832b 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -17,6 +17,8 @@ + + diff --git a/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt b/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt index eb9071344d2a..a3686554ff72 100644 --- a/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/server/Core/src/PublicAPI.Unshipped.txt @@ -5,3 +5,25 @@ Microsoft.AspNetCore.SignalR.HubOptions.DisableImplicitFromServicesParameters.se Microsoft.AspNetCore.SignalR.HubConnectionHandler.HubConnectionHandler(Microsoft.AspNetCore.SignalR.HubLifetimeManager! lifetimeManager, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! protocolResolver, Microsoft.Extensions.Options.IOptions! globalHubOptions, Microsoft.Extensions.Options.IOptions!>! hubOptions, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory, Microsoft.AspNetCore.SignalR.IUserIdProvider! userIdProvider, Microsoft.Extensions.DependencyInjection.IServiceScopeFactory! serviceScopeFactory) -> void *REMOVED*~Microsoft.AspNetCore.SignalR.HubOptionsSetup.HubOptionsSetup(Microsoft.Extensions.Options.IOptions! options) -> void Microsoft.AspNetCore.SignalR.HubOptionsSetup.HubOptionsSetup(Microsoft.Extensions.Options.IOptions! options) -> void +Microsoft.AspNetCore.SignalR.IHubCallerClients.Single(string! connectionId) -> Microsoft.AspNetCore.SignalR.ISingleClientProxy! +Microsoft.AspNetCore.SignalR.IHubClients.Single(string! connectionId) -> Microsoft.AspNetCore.SignalR.ISingleClientProxy! +Microsoft.AspNetCore.SignalR.IHubClients.Single(string! connectionId) -> T +Microsoft.AspNetCore.SignalR.ISingleClientProxy +Microsoft.AspNetCore.SignalR.ISingleClientProxy.InvokeCoreAsync(string! method, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.DefaultHubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +static Microsoft.AspNetCore.SignalR.ClientProxyExtensions.InvokeAsync(this Microsoft.AspNetCore.SignalR.ISingleClientProxy! clientProxy, string! method, object? arg1, object? arg2, object? arg3, object? arg4, object? arg5, object? arg6, object? arg7, object? arg8, object? arg9, object? arg10, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! +virtual Microsoft.AspNetCore.SignalR.HubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool diff --git a/src/SignalR/server/SignalR/test/ClientProxyTests.cs b/src/SignalR/server/SignalR/test/ClientProxyTests.cs index eb0d2613cc6c..ede98b1c505a 100644 --- a/src/SignalR/server/SignalR/test/ClientProxyTests.cs +++ b/src/SignalR/server/SignalR/test/ClientProxyTests.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.Testing; using Moq; using Xunit; @@ -205,4 +206,82 @@ public async Task MultipleClientProxy_SendAsync_ArrayArgumentNotExpanded() Assert.Same(data, arg); } + + [Fact] + public async Task SingleClientProxyWithInvoke_ThrowsNotSupported() + { + var hubLifetimeManager = new EmptyHubLifetimeManager(); + + var proxy = new SingleClientProxyWithInvoke(hubLifetimeManager, ""); + var ex = await Assert.ThrowsAsync(async () => await proxy.InvokeAsync("method")).DefaultTimeout(); + Assert.Equal("EmptyHubLifetimeManager`1 does not support client return values.", ex.Message); + } + + internal class EmptyHubLifetimeManager : HubLifetimeManager where THub : Hub + { + public override Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task OnConnectedAsync(HubConnectionContext connection) + { + throw new NotImplementedException(); + } + + public override Task OnDisconnectedAsync(HubConnectionContext connection) + { + throw new NotImplementedException(); + } + + public override Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendGroupAsync(string groupName, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index d4b6a7df5f35..db7359370a33 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -332,6 +332,12 @@ public async Task BlockingMethod() await tcs.Task; } + + public async Task GetClientResult(int num) + { + var sum = await Clients.Single(Context.ConnectionId).InvokeAsync("Sum", num); + return sum; + } } internal class SelfRef diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs new file mode 100644 index 000000000000..b179f92acbc5 --- /dev/null +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -0,0 +1,106 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.SignalR.Tests; + +public partial class HubConnectionHandlerTests +{ + [Fact] + public async Task CanReturnClientResultToHub() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(9L, completion.Result); + Assert.Equal(invocationId, completion.InvocationId); + } + } + } + + [Fact] + public async Task CanReturnClientResultErrorToHub() + { + using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + builder.AddSignalR(o => + { + o.MaximumParallelInvocationsPerClient = 2; + o.EnableDetailedErrors = true; + }); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + await client.SendHubMessageAsync(CompletionMessage.WithError(invocationMessage.InvocationId, "Client error")).DefaultTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. Exception: Client error", completion.Error); + Assert.Equal(invocationId, completion.InvocationId); + } + } + } + + [Fact] + public async Task ThrowsWhenParallelHubInvokesNotEnabled() + { + using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSignalR(o => + { + o.MaximumParallelInvocationsPerClient = 1; + o.EnableDetailedErrors = true; + }); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(invocationId, completionMessage.InvocationId); + Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. InvalidOperationException: Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1.", + completionMessage.Error); + } + } + } +} diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index a8b675bdb3dc..bce1522e597d 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -30,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests; -public class HubConnectionHandlerTests : VerifiableLoggedTest +public partial class HubConnectionHandlerTests : VerifiableLoggedTest { [Fact] [LogLevel(LogLevel.Trace)] @@ -240,7 +240,7 @@ public void FailsToLoadInvalidTypedHubClient() var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); var ex = Assert.Throws(() => serviceProvider.GetRequiredService>()); - Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidReturningTypedHubClient).FullName}.{nameof(IVoidReturningTypedHubClient.Send)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidReturningTypedHubClient).FullName}.{nameof(IVoidReturningTypedHubClient.Send)}'. All client proxy methods must return '{typeof(Task).FullName}' or 'System.Threading.Tasks.Task'.", ex.Message); } } @@ -3925,7 +3925,7 @@ public async Task UploadStreamCompleteInvalidId() } Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" && - w.EventId.Name == "UnexpectedStreamCompletion")); + w.EventId.Name == "UnexpectedCompletion")); } public static string CustomErrorMessage = "custom error for testing ::::)"; diff --git a/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs b/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs index 95688baadb57..93c730c8ffa7 100644 --- a/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs +++ b/src/SignalR/server/SignalR/test/Internal/TypedClientBuilderTests.cs @@ -1,13 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.Testing; -using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests.Internal; @@ -162,7 +157,7 @@ public void ThrowsIfInterfaceHasVoidReturningMethod() { var clientProxy = new MockProxy(); var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); - Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidMethodClient).FullName}.{nameof(IVoidMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IVoidMethodClient).FullName}.{nameof(IVoidMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'.", ex.Message); } [Fact] @@ -170,7 +165,7 @@ public void ThrowsIfInterfaceHasNonTaskReturns() { var clientProxy = new MockProxy(); var ex = Assert.Throws(() => TypedClientBuilder.Build(clientProxy)); - Assert.Equal($"Cannot generate proxy implementation for '{typeof(IStringMethodClient).FullName}.{nameof(IStringMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}'.", ex.Message); + Assert.Equal($"Cannot generate proxy implementation for '{typeof(IStringMethodClient).FullName}.{nameof(IStringMethodClient.Method)}'. All client proxy methods must return '{typeof(Task).FullName}' or '{typeof(Task).FullName}'.", ex.Message); } [Fact] @@ -207,9 +202,86 @@ public void ThrowsIfInterfaceHasEvents() Assert.Equal("Type must not contain events.", ex.Message); } + [Fact] + public async Task ProducesImplementationThatProxiesMethodsToISingleClientProxyAsync() + { + var clientProxy = new MockSingleClientProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + + var objArg = new object(); + var task = typedProxy.GetValue(1008, objArg, "test"); + Assert.False(task.IsCompleted); + + Assert.Collection(clientProxy.Sends, + send => + { + Assert.Equal("GetValue", send.Method); + Assert.Collection(send.Arguments, + arg1 => Assert.Equal(1008, arg1), + arg2 => Assert.Same(objArg, arg2), + arg3 => Assert.Same("test", arg3)); + Assert.Equal(CancellationToken.None, send.CancellationToken); + send.Complete(); + }); + + var result = await task.DefaultTimeout(); + Assert.Equal(default(int), result); + } + + [Fact] + public async Task ThrowsIfReturnMethodUsedWithoutSingleClientProxy() + { + var clientProxy = new MockProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + + var objArg = new object(); + var ex = await Assert.ThrowsAsync(() => typedProxy.GetValue(102, objArg, "test")).DefaultTimeout(); + Assert.Equal("InvokeAsync only works with Single clients.", ex.Message); + + Assert.Empty(clientProxy.Sends); + } + + [Fact] + public async Task ResultMethodSupportsCancellationToken() + { + var clientProxy = new MockSingleClientProxy(); + var typedProxy = TypedClientBuilder.Build(clientProxy); + CancellationTokenSource cts1 = new CancellationTokenSource(); + var task1 = typedProxy.MethodReturning("foo", cts1.Token); + Assert.False(task1.IsCompleted); + + CancellationTokenSource cts2 = new CancellationTokenSource(); + var task2 = typedProxy.NoArgumentMethodReturning(cts2.Token); + Assert.False(task2.IsCompleted); + + Assert.Collection(clientProxy.Sends, + send1 => + { + Assert.Equal("MethodReturning", send1.Method); + Assert.Single(send1.Arguments); + Assert.Collection(send1.Arguments, + arg1 => Assert.Equal("foo", arg1)); + Assert.Equal(cts1.Token, send1.CancellationToken); + send1.Complete(); + }, + send2 => + { + Assert.Equal("NoArgumentMethodReturning", send2.Method); + Assert.Empty(send2.Arguments); + Assert.Equal(cts2.Token, send2.CancellationToken); + send2.Complete(); + }); + + var result = await task1.DefaultTimeout(); + Assert.Equal(default(string), result); + var result2 = await task2.DefaultTimeout(); + Assert.Equal(default(int), result2); + } + public interface ITestClient { Task Method(string arg1, int arg2, object arg3); + Task GetValue(int arg1, object arg2, string arg3); } public interface IRenamedTestClient @@ -247,6 +319,9 @@ public interface ICancellationTokenMethod { Task Method(string foo, CancellationToken cancellationToken); Task NoArgumentMethod(CancellationToken cancellationToken); + + Task NoArgumentMethodReturning(CancellationToken cancellationToken); + Task MethodReturning(string foo, CancellationToken cancellationToken); } public interface IPropertiesClient @@ -259,6 +334,11 @@ public interface IEventsClient event EventHandler Event; } + public interface ITestReturnValueClient + { + Task GetValue(); + } + private class MockProxy : IClientProxy { public IList Sends { get; } = new List(); @@ -273,6 +353,30 @@ public Task SendCoreAsync(string method, object[] args, CancellationToken cancel } } + private class MockSingleClientProxy : ISingleClientProxy + { + public IList Sends { get; } = new List(); + + public Task SendCoreAsync(string method, object[] args, CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSource(); + + Sends.Add(new SendContext(method, args, cancellationToken, tcs)); + + return tcs.Task; + } + + public async Task InvokeCoreAsync(string method, object[] args, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(); + + Sends.Add(new SendContext(method, args, cancellationToken, tcs)); + + await tcs.Task; + return default(T); + } + } + private struct SendContext { private readonly TaskCompletionSource _tcs; diff --git a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs index 5a64beb0e792..19b061d1c6bd 100644 --- a/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs +++ b/src/SignalR/server/Specification.Tests/src/HubLifetimeManagerTestBase.cs @@ -170,4 +170,199 @@ public async Task SendConnectionAsyncWritesToConnectionOutput() Assert.Equal("World", (string)message.Arguments[0]); } } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnResult() + { + var manager = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).DefaultTimeout(); + + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); + + var res = await resultTask.DefaultTimeout(); + Assert.Equal(10L, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnErrorResult() + { + var manager = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager.OnConnectedAsync(connection1).DefaultTimeout(); + + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client")).DefaultTimeout(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.Equal("Error from client", ex.Message); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ExceptionWhenIncorrectClientCompletesClientResult() + { + var manager = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).DefaultTimeout(); + await manager.OnConnectedAsync(connection2).DefaultTimeout(); + + var resultTask = manager.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + var ex = await Assert.ThrowsAsync(() => + manager.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client"))).DefaultTimeout(); + + Assert.Equal($"Connection ID '{connection2.ConnectionId}' is not valid for invocation ID '{invocation.InvocationId}'.", ex.Message); + + // Internal state for invocation isn't affected by wrong client, check that we can still complete the invocation + await manager.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); + + var res = await resultTask.DefaultTimeout(); + Assert.Equal(10L, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ConnectionIDNotPresentWhenInvokingClientResult() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // No client with this ID + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task InvokesForMultipleClientsDoNotCollide() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + await manager1.OnConnectedAsync(connection2).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invoke2 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }); + + var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + var invocation2 = Assert.IsType(await client2.ReadAsync().DefaultTimeout()); + await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithError(invocation2.InvocationId, "error")); + + await Assert.ThrowsAnyAsync(() => invoke2).DefaultTimeout(); + Assert.False(invoke1.IsCompleted); + + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 3)); + Assert.Equal(3, await invoke1.DefaultTimeout()); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ClientDisconnectsWithoutCompletingClientResult() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + + connection1.Abort(); + await manager1.OnDisconnectedAsync(connection1).DefaultTimeout(); + + var ex = await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + Assert.Equal($"Connection '{connection1.ConnectionId}' disconnected.", ex.Message); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanCancelClientResult() + { + var manager1 = CreateNewHubLifetimeManager(); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + var cts = new CancellationTokenSource(); + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cts.Token); + var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + cts.Cancel(); + + await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + + // Noop, just checking that it doesn't throw. This could be caused by an inflight response from a client while the server cancels the token + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 1)); + } + } } diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 509d09d477e2..5fc1c9637c76 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -463,4 +463,170 @@ public async Task StillSubscribedToUserAfterOneOfMultipleConnectionsAssociatedWi await AssertMessageAsync(client2); } } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnResultAcrossServers() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // Server2 asks for a result from client1 on Server1 + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + // Server1 gets the result from client1 and forwards to Server2 + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, 10)).DefaultTimeout(); + + var res = await resultTask.DefaultTimeout(); + Assert.Equal(10L, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task CanProcessClientReturnErrorResultAcrossServers() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // Server2 asks for a result from client1 on Server1 + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + // Server1 gets the result from client1 and forwards to Server2 + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithError(invocation.InvocationId, "Error from client")).DefaultTimeout(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.Equal("Error from client", ex.Message); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ConnectionIDNotPresentMultiServerWhenInvokingClientResult() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // No client on any backplanes with this ID + await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("none", "Result", new object[] { "test" })).DefaultTimeout(); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ClientDisconnectsWithoutCompletingClientResultOnSecondServer() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager2.OnConnectedAsync(connection1).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + + connection1.Abort(); + await manager2.OnDisconnectedAsync(connection1).DefaultTimeout(); + + // Server should propogate connection closure so task isn't blocked + var ex = await Assert.ThrowsAsync(() => invoke1).DefaultTimeout(); + Assert.Equal("Connection disconnected.", ex.Message); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task InvocationsFromDifferentServersUseUniqueIDs() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + await manager2.OnConnectedAsync(connection2).DefaultTimeout(); + + var invoke1 = manager1.InvokeConnectionAsync(connection2.ConnectionId, "Result", new object[] { "test" }); + var invocation2 = Assert.IsType(await client2.ReadAsync().DefaultTimeout()); + + var invoke2 = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }); + var invocation1 = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + + Assert.NotEqual(invocation1.InvocationId, invocation2.InvocationId); + + await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout(); + await manager2.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout(); + + var res = await invoke1.DefaultTimeout(); + Assert.Equal(2, res); + res = await invoke2.DefaultTimeout(); + Assert.Equal(5, res); + } + } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ConnectionDoesNotExist_FailsInvokeConnectionAsync() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" })).DefaultTimeout(); + Assert.Equal("Connection '1234' does not exist.", ex.Message); + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs index 8f306f3082b6..eea3d9499247 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs @@ -71,4 +71,15 @@ public string Ack(string serverName) { return _prefix + ":internal:ack:" + serverName; } + + /// + /// Gets the name of the client return results channel for the specified server. + /// + /// The name of the server to get the client return results channel for. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string ReturnResults(string serverName) + { + return _prefix + ":internal:return:" + serverName; + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisCompletion.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisCompletion.cs new file mode 100644 index 000000000000..c950d2f2bbfd --- /dev/null +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisCompletion.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; + +internal readonly struct RedisCompletion +{ + public ReadOnlySequence CompletionMessage { get; } + + public string ProtocolName { get; } + + public RedisCompletion(string protocolName, ReadOnlySequence completionMessage) + { + ProtocolName = protocolName; + CompletionMessage = completionMessage; + } +} diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs index 8474a8ed780f..1f441873d6f1 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisInvocation.cs @@ -16,9 +16,16 @@ internal readonly struct RedisInvocation /// public SerializedHubMessage Message { get; } - public RedisInvocation(SerializedHubMessage message, IReadOnlyList? excludedConnectionIds) + public string? ReturnChannel { get; } + + public string? InvocationId { get; } + + public RedisInvocation(SerializedHubMessage message, IReadOnlyList? excludedConnectionIds, + string? invocationId = null, string? returnChannel = null) { Message = message; ExcludedConnectionIds = excludedConnectionIds; + ReturnChannel = returnChannel; + InvocationId = invocationId; } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs index ca9ead0a5778..ba7217748af2 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs @@ -52,6 +52,12 @@ public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endp [LoggerMessage(11, LogLevel.Warning, "Error processing message for internal server message.", EventName = "InternalMessageFailed")] public static partial void InternalMessageFailed(ILogger logger, Exception exception); + [LoggerMessage(12, LogLevel.Error, "Received a client result for protocol {HubProtocol} which is not supported by this server. This likely means you have different versions of your server deployed.", EventName = "MismatchedServers")] + public static partial void MismatchedServers(ILogger logger, string hubProtocol); + + [LoggerMessage(13, LogLevel.Error, "Error forwarding client result with ID '{InvocationID}' to server.", EventName = "ErrorForwardingResult")] + public static partial void ErrorForwardingResult(ILogger logger, string invocationId, Exception ex); + // This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer public static void ConnectionMultiplexerMessage(ILogger logger, string? message) { diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 719413a08136..7a60143c97d4 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -25,20 +25,22 @@ public RedisProtocol(DefaultHubMessageSerializer messageSerializer) // * Invocations are sent to the All, Group, Connection and User channels // * Group Commands are sent to the GroupManagement channel // * Acks are sent to the Acknowledgement channel. + // * Completion messages (client results) are sent to the server specific Result channel // * See the Write[type] methods for a description of the protocol for each in-depth. // * The "Variable length integer" is the length-prefixing format used by BinaryReader/BinaryWriter: // * https://docs.microsoft.com/dotnet/api/system.io.binarywriter.write?view=netcore-2.2 // * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter: // * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8. - public byte[] WriteInvocation(string methodName, object?[] args) => - WriteInvocation(methodName, args, excludedConnectionIds: null); - - public byte[] WriteInvocation(string methodName, object?[] args, IReadOnlyList? excludedConnectionIds) + public byte[] WriteInvocation(string methodName, object?[] args, string? invocationId = null, + IReadOnlyList? excludedConnectionIds = null, string? returnChannel = null) { // Written as a MessagePack 'arr' containing at least these items: // * A MessagePack 'arr' of 'str's representing the excluded ids // * [The output of WriteSerializedHubMessage, which is an 'arr'] + // For invocations expecting a result + // * InvocationID + // * Redis return channel // Any additional items are discarded. var memoryBufferWriter = MemoryBufferWriter.Get(); @@ -46,7 +48,14 @@ public byte[] WriteInvocation(string methodName, object?[] args, IReadOnlyList 0) { writer.WriteArrayHeader(excludedConnectionIds.Count); @@ -60,7 +69,16 @@ public byte[] WriteInvocation(string methodName, object?[] args, IReadOnlyList data) { // See WriteInvocation for the format var reader = new MessagePackReader(data); - ValidateArraySize(ref reader, 2, "Invocation"); + var length = ValidateArraySize(ref reader, 2, "Invocation"); + + string? returnChannel = null; + string? invocationId = null; // Read excluded Ids IReadOnlyList? excludedConnectionIds = null; @@ -147,7 +192,14 @@ public static RedisInvocation ReadInvocation(ReadOnlyMemory data) // Read payload var message = ReadSerializedHubMessage(ref reader); - return new RedisInvocation(message, excludedConnectionIds); + + if (length > 3) + { + invocationId = reader.ReadString(); + returnChannel = reader.ReadString(); + } + + return new RedisInvocation(message, excludedConnectionIds, invocationId, returnChannel); } public static RedisGroupCommand ReadGroupCommand(ReadOnlyMemory data) @@ -209,7 +261,18 @@ public static SerializedHubMessage ReadSerializedHubMessage(ref MessagePackReade return new SerializedHubMessage(serializations); } - private static void ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType) + public static RedisCompletion ReadCompletion(ReadOnlyMemory data) + { + // See WriteCompletionMessage for the format + var reader = new MessagePackReader(data); + ValidateArraySize(ref reader, 2, "CompletionMessage"); + + var protocolName = reader.ReadString(); + var ros = reader.ReadBytes(); + return new RedisCompletion(protocolName, ros ?? new ReadOnlySequence()); + } + + private static int ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType) { var length = reader.ReadArrayHeader(); @@ -217,5 +280,6 @@ private static void ValidateArraySize(ref MessagePackReader reader, int expected { throw new InvalidDataException($"Insufficient items in {messageType} array."); } + return length; } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj b/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj index ae0ff95ad8cb..e3ed270e72ef 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj +++ b/src/SignalR/server/StackExchangeRedis/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj @@ -9,6 +9,8 @@ + + diff --git a/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt b/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt index 68e96ee62e38..bdac4aa9fd92 100644 --- a/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/server/StackExchangeRedis/src/PublicAPI.Unshipped.txt @@ -3,3 +3,6 @@ *REMOVED*~Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger!>! logger, Microsoft.Extensions.Options.IOptions! options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! hubProtocolResolver, Microsoft.Extensions.Options.IOptions? globalHubOptions, Microsoft.Extensions.Options.IOptions!>? hubOptions) -> void Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger!>! logger, Microsoft.Extensions.Options.IOptions! options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! hubProtocolResolver) -> void Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger!>! logger, Microsoft.Extensions.Options.IOptions! options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver! hubProtocolResolver, Microsoft.Extensions.Options.IOptions? globalHubOptions, Microsoft.Extensions.Options.IOptions!>? hubOptions) -> void +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.InvokeConnectionAsync(string! connectionId, string! methodName, object?[]! args, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.SetConnectionResultAsync(string! connectionId, Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage! result) -> System.Threading.Tasks.Task! +override Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager.TryGetReturnType(string! invocationId, out System.Type? type) -> bool diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index f3533789a336..6af795083e08 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; using Microsoft.AspNetCore.Http.Features; @@ -30,9 +33,12 @@ public class RedisHubLifetimeManager : HubLifetimeManager, IDisposab private readonly string _serverName = GenerateServerName(); private readonly RedisProtocol _protocol; private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1); + private readonly IHubProtocolResolver _hubProtocolResolver; + private readonly ClientResultsManager _clientResultsManager = new(); private readonly AckHandler _ackHandler; - private int _internalId; + private int _internalAckId; + private ulong _lastInvocationId; /// /// Constructs the with types from Dependency Injection. @@ -61,6 +67,7 @@ public RedisHubLifetimeManager(ILogger> logger, IOptions? globalHubOptions, IOptions>? hubOptions) { + _hubProtocolResolver = hubProtocolResolver; _logger = logger; _options = options.Value; _ackHandler = new AckHandler(); @@ -144,7 +151,7 @@ public override Task SendAllAsync(string methodName, object?[] args, Cancellatio /// public override Task SendAllExceptAsync(string methodName, object?[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) { - var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds); + var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds: excludedConnectionIds); return PublishAsync(_channels.All, message); } @@ -188,7 +195,7 @@ public override Task SendGroupExceptAsync(string groupName, string methodName, o throw new ArgumentNullException(nameof(groupName)); } - var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds); + var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds: excludedConnectionIds); return PublishAsync(_channels.Group(groupName), message); } @@ -306,11 +313,11 @@ public override Task SendUsersAsync(IReadOnlyList userIds, string method return Task.CompletedTask; } - private async Task PublishAsync(string channel, byte[] payload) + private async Task PublishAsync(string channel, byte[] payload) { await EnsureRedisServerConnection(); RedisLog.PublishToChannel(_logger, channel); - await _bus!.PublishAsync(channel, payload); + return await _bus!.PublishAsync(channel, payload); } private Task AddGroupAsyncCore(HubConnectionContext connection, string groupName) @@ -358,7 +365,7 @@ await _groups.RemoveSubscriptionAsync(groupChannel, connection, channelName => private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action) { - var id = Interlocked.Increment(ref _internalId); + var id = Interlocked.Increment(ref _internalAckId); var ack = _ackHandler.CreateAck(id); // Send Add/Remove Group to other servers and wait for an ack or timeout var message = RedisProtocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId)); @@ -388,6 +395,79 @@ public void Dispose() _ackHandler.Dispose(); } + /// + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + // send thing + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + var connection = _connections[connectionId]; + + // Needs to be unique across servers, easiest way to do that is prefix with connection ID. + var invocationId = $"{connectionId}{Interlocked.Increment(ref _lastInvocationId)}"; + + using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, + connection?.ConnectionAborted ?? default, out var linkedToken); + var task = _clientResultsManager.AddInvocation(connectionId, invocationId, linkedToken); + + try + { + if (connection == null) + { + // TODO: Need to handle other server going away while waiting for connection result + var m = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName)); + var received = await PublishAsync(_channels.Connection(connectionId), m); + if (received < 1) + { + throw new IOException($"Connection '{connectionId}' does not exist."); + } + } + else + { + // We're sending to a single connection + // Write message directly to connection without caching it in memory + var message = new InvocationMessage(invocationId, methodName, args); + + await connection.WriteAsync(message, cancellationToken); + } + } + catch + { + _clientResultsManager.RemoveInvocation(invocationId); + throw; + } + + try + { + return await task; + } + catch + { + // ConnectionAborted will trigger a generic "Canceled" exception from the task, let's convert it into a more specific message. + if (connection?.ConnectionAborted.IsCancellationRequested == true) + { + throw new IOException($"Connection '{connectionId}' disconnected."); + } + throw; + } + } + + /// + public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + _clientResultsManager.TryCompleteResult(connectionId, result); + return Task.CompletedTask; + } + + /// + public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type? type) + { + return _clientResultsManager.TryGetType(invocationId, out type); + } + private async Task SubscribeToAll() { RedisLog.Subscribing(_logger, _channels.All); @@ -476,6 +556,48 @@ private async Task SubscribeToConnection(HubConnectionContext connection) channel.OnMessage(channelMessage => { var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message); + + // This is a Client result we need to setup state for the completion and forward the message to the client + if (!string.IsNullOrEmpty(invocation.InvocationId)) + { + CancellationTokenRegistration? tokenRegistration = null; + _clientResultsManager.AddInvocation(invocation.InvocationId, + (typeof(RawResult), connection.ConnectionId, null!, async (_, completionMessage) => + { + var protocolName = connection.Protocol.Name; + tokenRegistration?.Dispose(); + + var memoryBufferWriter = AspNetCore.Internal.MemoryBufferWriter.Get(); + byte[] message; + try + { + try + { + connection.Protocol.WriteMessage(completionMessage, memoryBufferWriter); + message = RedisProtocol.WriteCompletionMessage(memoryBufferWriter, protocolName); + } + finally + { + memoryBufferWriter.Dispose(); + } + await PublishAsync(invocation.ReturnChannel!, message); + } + catch (Exception ex) + { + RedisLog.ErrorForwardingResult(_logger, completionMessage.InvocationId, ex); + } + })); + + // TODO: this isn't great + tokenRegistration = connection.ConnectionAborted.UnsafeRegister(_ => + { + var invocationInfo = _clientResultsManager.RemoveInvocation(invocation.InvocationId); + invocationInfo?.Completion(null!, CompletionMessage.WithError(invocation.InvocationId, "Connection disconnected.")); + }, null); + } + + // Forward message from other server to client + // Normal client method invokes and client result invokes use the same message return connection.WriteAsync(invocation.Message).AsTask(); }); } @@ -540,6 +662,38 @@ private async Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore }); } + private async Task SubscribeToReturnResultsAsync() + { + var channel = await _bus!.SubscribeAsync(_channels.ReturnResults(_serverName)); + channel.OnMessage((channelMessage) => + { + var completion = RedisProtocol.ReadCompletion(channelMessage.Message); + IHubProtocol? protocol = null; + foreach (var hubProtocol in _hubProtocolResolver.AllProtocols) + { + if (hubProtocol.Name.Equals(completion.ProtocolName)) + { + protocol = hubProtocol; + break; + } + } + + // Should only happen if you have different versions of servers and don't have the same protocols registered on both + if (protocol is null) + { + RedisLog.MismatchedServers(_logger, completion.ProtocolName); + return; + } + + var ros = completion.CompletionMessage; + var parseSuccess = protocol.TryParseMessage(ref ros, _clientResultsManager, out var hubMessage); + Debug.Assert(parseSuccess); + + var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!); + invocationInfo?.Completion(invocationInfo?.Tcs!, (CompletionMessage)hubMessage!); + }); + } + private async Task EnsureRedisServerConnection() { if (_redisServerConnection == null) @@ -589,6 +743,7 @@ private async Task EnsureRedisServerConnection() await SubscribeToAll(); await SubscribeToGroupManagementChannel(); await SubscribeToAckChannel(); + await SubscribeToReturnResultsAsync(); } } finally diff --git a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs index dac4a79894c3..0a9d56a233f5 100644 --- a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs +++ b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Generic; using System.Linq; +using System.Text; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; @@ -176,7 +179,7 @@ public void WriteInvocation(string testName) // Actual invocation doesn't matter because we're using a dummy hub protocol. // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. var expected = testData.Decoded(); - var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds); + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, excludedConnectionIds: expected.ExcludedConnectionIds); Assert.Equal(testData.Encoded, encoded); } @@ -192,7 +195,81 @@ public void WriteInvocationWithHubMessageSerializer(string testName) // Actual invocation doesn't matter because we're using a dummy hub protocol. // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. var expected = testData.Decoded(); - var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds); + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, excludedConnectionIds: expected.ExcludedConnectionIds); + + Assert.Equal(testData.Encoded, encoded); + } + + private static readonly Dictionary> _completionMessageTestData = new[] + { + CreateTestData( + "JsonMessageForwarded", + new RedisCompletion("json", new ReadOnlySequence(Encoding.UTF8.GetBytes("{\"type\":3,\"invocationId\":\"1\",\"result\":1}"))), + 0x92, + 0xa4, + (byte)'j', + (byte)'s', + (byte)'o', + (byte)'n', + 0xc4, // 'bin' + 0x28, // length + 0x7b, 0x22, 0x74, 0x79, 0x70, 0x65, 0x22, 0x3a, 0x33, 0x2c, 0x22, 0x69, 0x6e, 0x76, 0x6f, 0x63, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x3a, 0x22, 0x31, 0x22, 0x2c, 0x22, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x3a, + 0x31, 0x7d), + CreateTestData( + "MsgPackMessageForwarded", + new RedisCompletion("messagepack", new ReadOnlySequence(new byte[] { 0x95, 0x03, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x03, 0x2a })), + 0x92, + 0xab, + (byte)'m', + (byte)'e', + (byte)'s', + (byte)'s', + (byte)'a', + (byte)'g', + (byte)'e', + (byte)'p', + (byte)'a', + (byte)'c', + (byte)'k', + 0xc4, // 'bin' + 0x09, // 'bin' length + 0x95, // 5 array elements + 0x03, // type: 3 + 0x80, // empty headers + 0xa3, // 'str' + (byte)'x', + (byte)'y', + (byte)'z', + 0x03, // has result + 0x2a), // 42 + }.ToDictionary(t => t.Name); + + public static IEnumerable CompletionMessageTestData = _completionMessageTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(CompletionMessageTestData))] + public void ParseCompletionMessage(string testName) + { + var testData = _completionMessageTestData[testName]; + + var completionMessage = RedisProtocol.ReadCompletion(testData.Encoded); + + Assert.Equal(testData.Decoded.ProtocolName, completionMessage.ProtocolName); + Assert.Equal(testData.Decoded.CompletionMessage.ToArray(), completionMessage.CompletionMessage.ToArray()); + } + + [Theory] + [MemberData(nameof(CompletionMessageTestData))] + public void WriteCompletionMessage(string testName) + { + var testData = _completionMessageTestData[testName]; + + var writer = MemoryBufferWriter.Get(); + writer.Write(testData.Decoded.CompletionMessage.ToArray()); + + var encoded = RedisProtocol.WriteCompletionMessage(writer, testData.Decoded.ProtocolName); + MemoryBufferWriter.Return(writer); Assert.Equal(testData.Encoded, encoded); }