diff --git a/samples/AzureFunctionsApp/AzureFunctionsApp.csproj b/samples/AzureFunctionsApp/AzureFunctionsApp.csproj index 49dc1988d..c2e463efa 100644 --- a/samples/AzureFunctionsApp/AzureFunctionsApp.csproj +++ b/samples/AzureFunctionsApp/AzureFunctionsApp.csproj @@ -8,10 +8,10 @@ - - + + - + diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index 29932f141..ec40154c9 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -1,11 +1,22 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Reflection; +using System.Threading.Tasks; + namespace Microsoft.DurableTask.Entities; /// /// The task entity contract. /// +/// +/// Entity State +/// +/// All entity implementations are required to be serializable by the configured . An entity +/// will have its state deserialized before executing an operation, and then the new state will be the serialized value +/// of the implementation instance post-operation. +/// +/// public interface ITaskEntity { /// @@ -15,3 +26,221 @@ public interface ITaskEntity /// The response to the caller, if any. ValueTask RunAsync(TaskEntityOperation operation); } + +/// +/// An which dispatches its operations to public instance methods or properties. +/// +/// +/// Method Binding +/// +/// When using this base class, all public methods will be considered valid entity operations. +/// +/// Only public methods are considered (private, internal, and protected are not.) +/// Properties are not considered. +/// Operation matching is case insensitive. +/// is thrown if no matching public method is found for an operation. +/// is thrown if there are multiple public overloads for an operation name. +/// +/// +/// +/// Parameter Binding +/// +/// Operation methods supports parameter binding as follows: +/// +/// Can bind to the context by adding a parameter of type . +/// Can bind to the raw operation by adding a parameter of type . +/// Can bind to the operation input directly by adding any parameter which does not match a previously described +/// binding candidate. The operation input, if available, will be deserialized to that type. +/// Default parameters can be used for input to allow for an operation to execute (with the default value) without +/// an input being provided. +/// +/// +/// will be thrown if: +/// +/// There is a redundant parameter binding (ie: two context, operation, or input matches) +/// There is an input binding, but no input was provided. +/// There is another unknown type present which does not match context, operation, or input. +/// +/// +/// +/// Return Value +/// +/// Any value returned by the bound method will be returned to the operation caller. Not all callers wait for a return +/// value, such as signal-only callers. The return value is ignored in these cases. +/// +/// +/// Entity State +/// +/// Unchanged from . Entity state is the serialized value of the entity after an operation +/// completes. +/// +/// +public abstract class TaskEntity : ITaskEntity +{ + /** + * TODO: + * 1. Consider caching a compiled delegate for a given operation name. + */ + static readonly BindingFlags InstanceBindingFlags + = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; + + /// + public ValueTask RunAsync(TaskEntityOperation operation) + { + Check.NotNull(operation); + if (!this.TryDispatchMethod(operation, out object? result, out Type returnType)) + { + throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); + } + + if (typeof(Task).IsAssignableFrom(returnType)) + { + // Task or Task + return new(AsGeneric((Task)result!, returnType)); // we assume a declared Task return type is never null. + } + + if (returnType == typeof(ValueTask)) + { + // ValueTask + return AsGeneric((ValueTask)result!); // we assume a declared ValueTask return type is never null. + } + + if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) + { + // ValueTask + return AsGeneric(result!, returnType); // No inheritance, have to do purely via reflection. + } + + return new(result); + } + + static bool TryGetInput(ParameterInfo parameter, TaskEntityOperation operation, out object? input) + { + if (!operation.HasInput) + { + if (parameter.HasDefaultValue) + { + input = parameter.DefaultValue; + return true; + } + + input = null; + return false; + } + + input = operation.GetInput(parameter.ParameterType); + return true; + } + + static async Task AsGeneric(Task task, Type declared) + { + await task; + if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>)) + { + return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + } + + return null; + } + + static ValueTask AsGeneric(ValueTask t) + { + static async Task Await(ValueTask t) + { + await t; + return null; + } + + if (t.IsCompletedSuccessfully) + { + return default; + } + + return new(Await(t)); + } + + static ValueTask AsGeneric(object result, Type type) + { + // result and type here must be some form of ValueTask. + if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result)) + { + return new(type.GetProperty("Result").GetValue(result)); + } + else + { + Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public) + .Invoke(result, null); + return new(t.ToGeneric()); + } + } + + bool TryDispatchMethod(TaskEntityOperation operation, out object? result, out Type returnType) + { + Type t = this.GetType(); + + // Will throw AmbiguousMatchException if more than 1 overload for the method name exists. + MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags); + if (method is null) + { + result = null; + returnType = typeof(void); + return false; + } + + ParameterInfo[] parameters = method.GetParameters(); + object?[] inputs = new object[parameters.Length]; + + int i = 0; + ParameterInfo? inputResolved = null; + ParameterInfo? contextResolved = null; + ParameterInfo? operationResolved = null; + foreach (ParameterInfo parameter in parameters) + { + if (parameter.ParameterType == typeof(TaskEntityContext)) + { + ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation); + inputs[i] = operation.Context; + contextResolved = parameter; + } + else if (parameter.ParameterType == typeof(TaskEntityOperation)) + { + ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation); + inputs[i] = operation; + operationResolved = parameter; + } + else + { + ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation); + if (TryGetInput(parameter, operation, out object? input)) + { + inputs[i] = input; + inputResolved = parameter; + } + else + { + throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" + + $"There was an error binding parameter '{parameter}'. The operation expected an input value, " + + "but no input was provided by the caller."); + } + } + + i++; + } + + result = method.Invoke(this, inputs); + returnType = method.ReturnType; + return true; + + static void ThrowIfDuplicateBinding( + ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation) + { + if (existing is not null) + { + throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" + + $"Unable to bind {bindingConcept} to '{parameter}' because it has " + + $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " + + $"'{parameter.Member}'.\nEntity operation: {operation}."); + } + } + } +} diff --git a/src/Abstractions/Entities/TaskEntityOperation.cs b/src/Abstractions/Entities/TaskEntityOperation.cs index 0e83572a5..80f894614 100644 --- a/src/Abstractions/Entities/TaskEntityOperation.cs +++ b/src/Abstractions/Entities/TaskEntityOperation.cs @@ -18,10 +18,21 @@ public abstract class TaskEntityOperation /// public abstract TaskEntityContext Context { get; } + /// + /// Gets a value indicating whether this operation has input or not. + /// + public abstract bool HasInput { get; } + /// /// Gets the input for this operation. /// /// The type to deserialize the input as. /// The deserialized input type. public abstract object? GetInput(Type inputType); + + /// + public override string ToString() + { + return $"{this.Context.Id.Name}/{this.Name}"; + } } diff --git a/src/Shared/Core/TaskExtensions.cs b/src/Shared/Core/TaskExtensions.cs new file mode 100644 index 000000000..116c6623a --- /dev/null +++ b/src/Shared/Core/TaskExtensions.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; +using Microsoft.DurableTask; + +namespace System; + +/// +/// Extensions for and . +/// +static class TaskExtensions +{ + /// + /// Converts a to a generic . + /// + /// The generic type param to convert to. + /// The task to convert. + /// The converted task. + public static async Task ToGeneric(this Task task) + { + await Check.NotNull(task); + + Type t = task.GetType(); + if (t.IsGenericType) + { + return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + } + + return default; + } + + /// + /// Converts a to a . + /// + /// The generic type param to convert to. + /// The value task to convert. + /// The converted value task. + public static async ValueTask ToGeneric(this ValueTask task) + { + await task; + return default; + } +} diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/TaskEntityTests.cs new file mode 100644 index 000000000..ea5129c83 --- /dev/null +++ b/test/Abstractions.Tests/Entities/TaskEntityTests.cs @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; +using DotNext; + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TaskEntityTests +{ + [Theory] + [InlineData("doesNotExist")] // method does not exist. + [InlineData("add")] // private method, should not work. + [InlineData("staticMethod")] // public static methods are not supported. + public async Task OperationNotSupported_Fails(string name) + { + Operation operation = new(name, Mock.Of(), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [CombinatorialData] + public async Task TaskOperation_Success( + [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync) + { + object? expected = name.Contains("OfString") ? "success" : null; + Operation operation = new(name, Mock.Of(), sync); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().Be(expected); + } + + [Theory] + [CombinatorialData] + public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase) + { + int start = Random.Shared.Next(0, 10); + int toAdd = Random.Shared.Next(0, 10); + string opName = lowercase ? "add" : "Add"; + Operation operation = new($"{opName}{method}", Mock.Of(), toAdd); + TestEntity entity = new() { Value = start }; + + object? result = await entity.RunAsync(operation); + + int expected = start + toAdd; + entity.Value.Should().Be(expected); + result.Should().BeOfType().Which.Should().Be(expected); + } + + [Theory] + [CombinatorialData] + public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowercase) + { + int expected = Random.Shared.Next(0, 10); + string opName = lowercase ? "get" : "Get"; + Operation operation = new($"{opName}{method}", Mock.Of(), default); + TestEntity entity = new() { Value = expected }; + + object? result = await entity.RunAsync(operation); + + entity.Value.Should().Be(expected); + result.Should().BeOfType().Which.Should().Be(expected); + } + + [Fact] + public async Task Add_NoInput_Fails() + { + Operation operation = new("add0", Mock.Of(), default); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [CombinatorialData] + public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method) + { + Operation operation = new($"ambiguousArgs{method}", Mock.Of(), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task Dispatch_AmbiguousMatch_Fails() + { + Operation operation = new("ambiguousMatch", Mock.Of(), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task DefaultValue_NoInput_Succeeds() + { + Operation operation = new("defaultValue", Mock.Of(), default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be("default"); + } + + [Fact] + public async Task DefaultValue_Input_Succeeds() + { + Operation operation = new("defaultValue", Mock.Of(), "not-default"); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be("not-default"); + } + + class Operation : TaskEntityOperation + { + readonly Optional input; + + public Operation(string name, TaskEntityContext context, Optional input) + { + this.Name = name; + this.Context = context; + this.input = input; + } + + public override string Name { get; } + + public override TaskEntityContext Context { get; } + + public override bool HasInput => this.input.IsPresent; + + public override object? GetInput(Type inputType) + { + if (!this.input.IsPresent) + { + throw new InvalidOperationException("No input available."); + } + + if (this.input.Value is null) + { + return null; + } + + if (!inputType.IsAssignableFrom(this.input.Value.GetType())) + { + throw new InvalidCastException("Cannot convert input type."); + } + + return this.input.Value; + } + } + + class TestEntity : TaskEntity + { + public int Value { get; set; } + + public static string StaticMethod() => throw new NotImplementedException(); + + // All possible permutations of the 3 inputs we support: object, context, operation + // 14 via Add, 2 via Get: 16 total. + public int Add0(int value) => this.Add(value, default, default); + + public int Add1(int value, TaskEntityContext context) => this.Add(value, context, default); + + public int Add2(int value, TaskEntityOperation operation) => this.Add(value, default, operation); + + public int Add3(int value, TaskEntityContext context, TaskEntityOperation operation) + => this.Add(value, context, operation); + + public int Add4(int value, TaskEntityOperation operation, TaskEntityContext context) + => this.Add(value, context, operation); + + public int Add5(TaskEntityOperation operation) => this.Add(default, default, operation); + + public int Add6(TaskEntityOperation operation, int value) => this.Add(value, default, operation); + + public int Add7(TaskEntityOperation operation, TaskEntityContext context) + => this.Add(default, context, operation); + + public int Add8(TaskEntityOperation operation, int value, TaskEntityContext context) + => this.Add(value, context, operation); + + public int Add9(TaskEntityOperation operation, TaskEntityContext context, int value) + => this.Add(value, context, operation); + + public int Add10(TaskEntityContext context, int value) + => this.Add(value, context, default); + + public int Add11(TaskEntityContext context, TaskEntityOperation operation) + => this.Add(default, context, operation); + + public int Add12(TaskEntityContext context, int value, TaskEntityOperation operation) + => this.Add(value, context, operation); + + public int Add13(TaskEntityContext context, TaskEntityOperation operation, int value) + => this.Add(value, context, operation); + + public int Get0() => this.Get(default); + + public int Get1(TaskEntityContext context) => this.Get(context); + + public int AmbiguousMatch(TaskEntityContext context) => this.Value; + + public int AmbiguousMatch(TaskEntityOperation operation) => this.Value; + + public int AmbiguousArgs0(int value, object other) => this.Add0(value); + + public int AmbiguousArgs1(int value, TaskEntityContext context, TaskEntityContext context2) => this.Add0(value); + + public int AmbiguousArgs2(int value, TaskEntityOperation operation, TaskEntityOperation operation2) + => this.Add0(value); + + public string DefaultValue(string toReturn = "default") => toReturn; + + public Task TaskOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + } + + return sync ? Task.CompletedTask : Slow(); + } + + public Task TaskOfStringOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + return "success"; + } + + return sync ? Task.FromResult("success") : Slow(); + } + + public ValueTask ValueTaskOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + } + + return sync ? default : new(Slow()); + } + + public ValueTask ValueTaskOfStringOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + return "success"; + } + + return sync ? new("success") : new(Slow()); + } + + int Add(int? value, Optional context, Optional operation) + { + if (context.IsPresent) + { + context.Value.Should().NotBeNull(); + } + + if (operation.IsPresent) + { + operation.Value.Should().NotBeNull(); + } + + if (!value.HasValue && operation.TryGet(out TaskEntityOperation op)) + { + value = (int)op.GetInput(typeof(int))!; + } + + value.HasValue.Should().BeTrue(); + return this.Value += value!.Value; + } + + int Get(Optional context) + { + if (context.IsPresent) + { + context.Value.Should().NotBeNull(); + } + + return this.Value; + } + } +} diff --git a/test/Directory.Build.targets b/test/Directory.Build.targets index d4e59594a..294f47da1 100644 --- a/test/Directory.Build.targets +++ b/test/Directory.Build.targets @@ -4,6 +4,7 @@ Condition=" '$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory)../, $(_DirectoryBuildTargetsFile)))' != '' " /> + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -13,6 +14,7 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive all