diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index ec40154c9..837df7ab3 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System.Reflection; -using System.Threading.Tasks; namespace Microsoft.DurableTask.Entities; @@ -12,9 +11,7 @@ namespace Microsoft.DurableTask.Entities; /// /// Entity State /// -/// All entity implementations are required to be serializable by the configured . An entity -/// will have its state deserialized before executing an operation, and then the new state will be the serialized value -/// of the implementation instance post-operation. +/// The state of an entity can be retrieved and updated via . /// /// public interface ITaskEntity @@ -30,6 +27,7 @@ public interface ITaskEntity /// /// An which dispatches its operations to public instance methods or properties. /// +/// The state type held by this entity. /// /// Method Binding /// @@ -71,176 +69,99 @@ public interface ITaskEntity /// /// Entity State /// -/// Unchanged from . Entity state is the serialized value of the entity after an operation -/// completes. +/// Entity state will be hydrated into the property. See for +/// more information. /// /// -public abstract class TaskEntity : ITaskEntity +public abstract class TaskEntity : ITaskEntity { - /** - * TODO: - * 1. Consider caching a compiled delegate for a given operation name. - */ - static readonly BindingFlags InstanceBindingFlags - = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; - - /// - public ValueTask RunAsync(TaskEntityOperation operation) - { - Check.NotNull(operation); - if (!this.TryDispatchMethod(operation, out object? result, out Type returnType)) - { - throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); - } - - if (typeof(Task).IsAssignableFrom(returnType)) - { - // Task or Task - return new(AsGeneric((Task)result!, returnType)); // we assume a declared Task return type is never null. - } - - if (returnType == typeof(ValueTask)) - { - // ValueTask - return AsGeneric((ValueTask)result!); // we assume a declared ValueTask return type is never null. - } - - if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) - { - // ValueTask - return AsGeneric(result!, returnType); // No inheritance, have to do purely via reflection. - } - - return new(result); - } + /// + /// Gets a value indicating whether dispatching operations to is allowed. State dispatch + /// will only be attempted if entity-level dispatch does not succeed. Default is false. Dispatching to state + /// follows the same rules as dispatching to this entity. + /// + protected virtual bool AllowStateDispatch => false; - static bool TryGetInput(ParameterInfo parameter, TaskEntityOperation operation, out object? input) - { - if (!operation.HasInput) - { - if (parameter.HasDefaultValue) - { - input = parameter.DefaultValue; - return true; - } + /// + /// Gets or sets the state for this entity. + /// + /// + /// Initialization + /// + /// This will be hydrated as part of . will + /// be called when state is null at the start of an operation only. + /// + /// Persistence + /// + /// The contents of this property will be persisted to at the end + /// of the operation. + /// + /// Deletion + /// + /// Deleting entity state is possible by setting this to null. Setting to default of a value-type will + /// not perform a delete. This means deleting entity state is only possible for reference types or using ? + /// on a value-type (ie: TaskEntity<int?>). + /// + /// + protected TState State { get; set; } = default!; - input = null; - return false; - } + /// + /// Gets the entity operation. + /// + protected TaskEntityOperation Operation { get; private set; } = null!; - input = operation.GetInput(parameter.ParameterType); - return true; - } + /// + /// Gets the entity context. + /// + protected TaskEntityContext Context => this.Operation.Context; - static async Task AsGeneric(Task task, Type declared) + /// + public ValueTask RunAsync(TaskEntityOperation operation) { - await task; - if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>)) + this.Operation = Check.NotNull(operation); + object? state = operation.Context.GetState(typeof(TState)); + this.State = state is null ? this.InitializeState() : (TState)state; + if (!operation.TryDispatch(this, out object? result, out Type returnType) + && !this.TryDispatchState(out result, out returnType)) { - return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); } - return null; + return TaskEntityHelpers.UnwrapAsync(this.Context, () => this.State, result, returnType); } - static ValueTask AsGeneric(ValueTask t) + /// + /// Initializes the entity state. This is only called when there is no current state for this entity. + /// + /// The entity state. + /// The default implementation uses . + protected virtual TState InitializeState() { - static async Task Await(ValueTask t) - { - await t; - return null; - } - - if (t.IsCompletedSuccessfully) + if (Nullable.GetUnderlyingType(typeof(TState)) is Type t) { - return default; + // Activator.CreateInstance>() returns null. To avoid this, we will instantiate via underlying + // type if it is Nullable. This keeps the experience consistent between value and reference type. If an + // implementation wants null, they must override this method and explicitly provide null. + return (TState)Activator.CreateInstance(t); } - return new(Await(t)); + return Activator.CreateInstance(); } - static ValueTask AsGeneric(object result, Type type) + bool TryDispatchState(out object? result, out Type returnType) { - // result and type here must be some form of ValueTask. - if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result)) - { - return new(type.GetProperty("Result").GetValue(result)); - } - else - { - Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public) - .Invoke(result, null); - return new(t.ToGeneric()); - } - } - - bool TryDispatchMethod(TaskEntityOperation operation, out object? result, out Type returnType) - { - Type t = this.GetType(); - - // Will throw AmbiguousMatchException if more than 1 overload for the method name exists. - MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags); - if (method is null) + if (!this.AllowStateDispatch) { result = null; returnType = typeof(void); return false; } - ParameterInfo[] parameters = method.GetParameters(); - object?[] inputs = new object[parameters.Length]; - - int i = 0; - ParameterInfo? inputResolved = null; - ParameterInfo? contextResolved = null; - ParameterInfo? operationResolved = null; - foreach (ParameterInfo parameter in parameters) + if (this.State is null) { - if (parameter.ParameterType == typeof(TaskEntityContext)) - { - ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation); - inputs[i] = operation.Context; - contextResolved = parameter; - } - else if (parameter.ParameterType == typeof(TaskEntityOperation)) - { - ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation); - inputs[i] = operation; - operationResolved = parameter; - } - else - { - ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation); - if (TryGetInput(parameter, operation, out object? input)) - { - inputs[i] = input; - inputResolved = parameter; - } - else - { - throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" + - $"There was an error binding parameter '{parameter}'. The operation expected an input value, " + - "but no input was provided by the caller."); - } - } - - i++; + throw new InvalidOperationException("Attempting to dispatch to state, but entity state is null."); } - result = method.Invoke(this, inputs); - returnType = method.ReturnType; - return true; - - static void ThrowIfDuplicateBinding( - ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation) - { - if (existing is not null) - { - throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" + - $"Unable to bind {bindingConcept} to '{parameter}' because it has " + - $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " + - $"'{parameter.Member}'.\nEntity operation: {operation}."); - } - } + return this.Operation.TryDispatch(this.State, out result, out returnType); } } diff --git a/src/Abstractions/Entities/TaskEntityContext.cs b/src/Abstractions/Entities/TaskEntityContext.cs index 1b5512567..8ce86fb94 100644 --- a/src/Abstractions/Entities/TaskEntityContext.cs +++ b/src/Abstractions/Entities/TaskEntityContext.cs @@ -53,11 +53,16 @@ public abstract void StartOrchestration( TaskName name, object? input = null, StartOrchestrationOptions? options = null); /// - /// Deletes the state of this entity after the current operation completes. + /// Gets the current state for the entity this context is for. This will return null if no state is present, + /// regardless if is a value-type or not. /// - /// - /// The state deletion only takes effect after the current operation completes. Any state changes made during the - /// current operation will be ignored in favor of the deletion. - /// - public abstract void DeleteState(); + /// The type to retrieve the state as. + /// The entity state. + public abstract object? GetState(Type type); + + /// + /// Sets the entity state. Setting of null will delete entity state. + /// + /// The state to set. + public abstract void SetState(object? state); } diff --git a/src/Abstractions/Entities/TaskEntityHelpers.cs b/src/Abstractions/Entities/TaskEntityHelpers.cs new file mode 100644 index 000000000..4437d981c --- /dev/null +++ b/src/Abstractions/Entities/TaskEntityHelpers.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; + +namespace Microsoft.DurableTask.Entities; + +/// +/// Helpers for task entities. +/// +static class TaskEntityHelpers +{ + /// + /// Unwraps a dispatched result for a into a . + /// + /// The entity context. + /// Delegate to resolve new state for the entity. + /// The result of the operation. + /// The declared type of the result (may be different that actual type). + /// A value task which holds the result of the operation and sets state before it completes. + public static ValueTask UnwrapAsync( + TaskEntityContext context, Func state, object? result, Type resultType) + { + // NOTE: Func is used for state so that we can lazily resolve it AFTER the operation has ran. + Check.NotNull(context); + Check.NotNull(resultType); + + if (typeof(Task).IsAssignableFrom(resultType)) + { + // Task or Task + // We assume a declared Task return type is never null. + return new(UnwrapTask(context, state, (Task)result!, resultType)); + } + + if (resultType == typeof(ValueTask)) + { + // ValueTask + // We assume a declared ValueTask return type is never null. + return UnwrapValueTask(context, state, (ValueTask)result!); + } + + if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(ValueTask<>)) + { + // ValueTask + // No inheritance, have to do purely via reflection. + return UnwrapValueTaskOfT(context, state, result!, resultType); + } + + context.SetState(state()); + return new(result); + } + + static async Task UnwrapTask(TaskEntityContext context, Func state, Task task, Type declared) + { + await task; + context.SetState(state()); + if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>)) + { + return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + } + + return null; + } + + static ValueTask UnwrapValueTask(TaskEntityContext context, Func state, ValueTask t) + { + async Task Await(ValueTask t) + { + await t; + context.SetState(state()); + return null; + } + + if (t.IsCompletedSuccessfully) + { + context.SetState(state()); + return default; + } + + return new(Await(t)); + } + + static ValueTask UnwrapValueTaskOfT( + TaskEntityContext context, Func state, object result, Type type) + { + // Result and type here must be some form of ValueTask. + // TODO: can this amount of reflection be avoided? + if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result)) + { + context.SetState(state()); + return new(type.GetProperty("Result").GetValue(result)); + } + else + { + Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public).Invoke(result, null); + Type taskType = typeof(Task<>).MakeGenericType(type.GetGenericArguments()[0]); + return new(UnwrapTask(context, state, t, taskType)); + } + } +} diff --git a/src/Abstractions/Entities/TaskEntityOperationExtensions.cs b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs new file mode 100644 index 000000000..303032da8 --- /dev/null +++ b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; + +namespace Microsoft.DurableTask.Entities; + +/// +/// Extensions for . +/// +public static class TaskEntityOperationExtensions +{ + /** + * TODO: + * 1. Consider caching a compiled delegate for a given operation name. + */ + static readonly BindingFlags InstanceBindingFlags + = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; + + /// + /// Try to dispatch this operation via reflection to a method on . + /// + /// The operation that is being dispatched. + /// The target to dispatch to. + /// The result of the dispatch. + /// The declared return type of the dispatched method. + /// True if dispatch successful, false otherwise. + internal static bool TryDispatch( + this TaskEntityOperation operation, object target, out object? result, out Type returnType) + { + Check.NotNull(operation); + Check.NotNull(target); + Type t = target.GetType(); + + // Will throw AmbiguousMatchException if more than 1 overload for the method name exists. + MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags); + if (method is null) + { + result = null; + returnType = typeof(void); + return false; + } + + ParameterInfo[] parameters = method.GetParameters(); + object?[] inputs = new object[parameters.Length]; + + int i = 0; + ParameterInfo? inputResolved = null; + ParameterInfo? contextResolved = null; + ParameterInfo? operationResolved = null; + foreach (ParameterInfo parameter in parameters) + { + if (parameter.ParameterType == typeof(TaskEntityContext)) + { + ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation); + inputs[i] = operation.Context; + contextResolved = parameter; + } + else if (parameter.ParameterType == typeof(TaskEntityOperation)) + { + ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation); + inputs[i] = operation; + operationResolved = parameter; + } + else + { + ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation); + if (operation.TryGetInput(parameter, out object? input)) + { + inputs[i] = input; + inputResolved = parameter; + } + else + { + throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" + + $"There was an error binding parameter '{parameter}'. The operation expected an input value, " + + "but no input was provided by the caller."); + } + } + + i++; + } + + result = method.Invoke(target, inputs); + returnType = method.ReturnType; + return true; + + static void ThrowIfDuplicateBinding( + ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation) + { + if (existing is not null) + { + throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" + + $"Unable to bind {bindingConcept} to '{parameter}' because it has " + + $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " + + $"'{parameter.Member}'.\nEntity operation: {operation}."); + } + } + } + + static bool TryGetInput(this TaskEntityOperation operation, ParameterInfo parameter, out object? input) + { + if (!operation.HasInput) + { + if (parameter.HasDefaultValue) + { + input = parameter.DefaultValue; + return true; + } + + input = null; + return false; + } + + input = operation.GetInput(parameter.ParameterType); + return true; + } +} diff --git a/src/Shared/Core/TaskExtensions.cs b/src/Shared/Core/TaskExtensions.cs index 116c6623a..69c9356fb 100644 --- a/src/Shared/Core/TaskExtensions.cs +++ b/src/Shared/Core/TaskExtensions.cs @@ -24,7 +24,7 @@ static class TaskExtensions Type t = task.GetType(); if (t.IsGenericType) { - return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance)!.GetValue(task)!; } return default; diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs similarity index 75% rename from test/Abstractions.Tests/Entities/TaskEntityTests.cs rename to test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs index ea5129c83..ecc5541f1 100644 --- a/test/Abstractions.Tests/Entities/TaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs @@ -6,7 +6,7 @@ namespace Microsoft.DurableTask.Entities.Tests; -public class TaskEntityTests +public class EntityTaskEntityTests { [Theory] [InlineData("doesNotExist")] // method does not exist. @@ -14,7 +14,7 @@ public class TaskEntityTests [InlineData("staticMethod")] // public static methods are not supported. public async Task OperationNotSupported_Fails(string name) { - Operation operation = new(name, Mock.Of(), 10); + TestEntityOperation operation = new(name, 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -28,7 +28,7 @@ public async Task TaskOperation_Success( [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync) { object? expected = name.Contains("OfString") ? "success" : null; - Operation operation = new(name, Mock.Of(), sync); + TestEntityOperation operation = new(name, sync); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -41,15 +41,16 @@ public async Task TaskOperation_Success( public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase) { int start = Random.Shared.Next(0, 10); - int toAdd = Random.Shared.Next(0, 10); + int toAdd = Random.Shared.Next(1, 10); string opName = lowercase ? "add" : "Add"; - Operation operation = new($"{opName}{method}", Mock.Of(), toAdd); - TestEntity entity = new() { Value = start }; + TestEntityContext context = new(start); + TestEntityOperation operation = new($"{opName}{method}", context, toAdd); + TestEntity entity = new(); object? result = await entity.RunAsync(operation); int expected = start + toAdd; - entity.Value.Should().Be(expected); + context.GetState(typeof(int)).Should().BeOfType().Which.Should().Be(expected); result.Should().BeOfType().Which.Should().Be(expected); } @@ -59,19 +60,20 @@ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowerc { int expected = Random.Shared.Next(0, 10); string opName = lowercase ? "get" : "Get"; - Operation operation = new($"{opName}{method}", Mock.Of(), default); - TestEntity entity = new() { Value = expected }; + TestEntityContext context = new(expected); + TestEntityOperation operation = new($"{opName}{method}", context, default); + TestEntity entity = new(); object? result = await entity.RunAsync(operation); - entity.Value.Should().Be(expected); + context.GetState(typeof(int)).Should().BeOfType().Which.Should().Be(expected); result.Should().BeOfType().Which.Should().Be(expected); } [Fact] public async Task Add_NoInput_Fails() { - Operation operation = new("add0", Mock.Of(), default); + TestEntityOperation operation = new("add0", new TestEntityContext(null), default); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -83,7 +85,7 @@ public async Task Add_NoInput_Fails() [CombinatorialData] public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method) { - Operation operation = new($"ambiguousArgs{method}", Mock.Of(), 10); + TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -94,7 +96,7 @@ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int me [Fact] public async Task Dispatch_AmbiguousMatch_Fails() { - Operation operation = new("ambiguousMatch", Mock.Of(), 10); + TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -104,7 +106,7 @@ public async Task Dispatch_AmbiguousMatch_Fails() [Fact] public async Task DefaultValue_NoInput_Succeeds() { - Operation operation = new("defaultValue", Mock.Of(), default); + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -115,7 +117,7 @@ public async Task DefaultValue_NoInput_Succeeds() [Fact] public async Task DefaultValue_Input_Succeeds() { - Operation operation = new("defaultValue", Mock.Of(), "not-default"); + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default"); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -123,48 +125,10 @@ public async Task DefaultValue_Input_Succeeds() result.Should().BeOfType().Which.Should().Be("not-default"); } - class Operation : TaskEntityOperation - { - readonly Optional input; - - public Operation(string name, TaskEntityContext context, Optional input) - { - this.Name = name; - this.Context = context; - this.input = input; - } - - public override string Name { get; } - - public override TaskEntityContext Context { get; } - - public override bool HasInput => this.input.IsPresent; - - public override object? GetInput(Type inputType) - { - if (!this.input.IsPresent) - { - throw new InvalidOperationException("No input available."); - } - - if (this.input.Value is null) - { - return null; - } - - if (!inputType.IsAssignableFrom(this.input.Value.GetType())) - { - throw new InvalidCastException("Cannot convert input type."); - } - - return this.input.Value; - } - } - - class TestEntity : TaskEntity +#pragma warning disable CA1822 // Mark members as static +#pragma warning disable IDE0060 // Remove unused parameter + class TestEntity : TaskEntity { - public int Value { get; set; } - public static string StaticMethod() => throw new NotImplementedException(); // All possible permutations of the 3 inputs we support: object, context, operation @@ -210,9 +174,9 @@ public int Add13(TaskEntityContext context, TaskEntityOperation operation, int v public int Get1(TaskEntityContext context) => this.Get(context); - public int AmbiguousMatch(TaskEntityContext context) => this.Value; + public int AmbiguousMatch(TaskEntityContext context) => this.State; - public int AmbiguousMatch(TaskEntityOperation operation) => this.Value; + public int AmbiguousMatch(TaskEntityOperation operation) => this.State; public int AmbiguousArgs0(int value, object other) => this.Add0(value); @@ -267,33 +231,35 @@ static async Task Slow() int Add(int? value, Optional context, Optional operation) { - if (context.IsPresent) + if (context.HasValue) { context.Value.Should().NotBeNull(); } - if (operation.IsPresent) + if (operation.HasValue) { operation.Value.Should().NotBeNull(); } - if (!value.HasValue && operation.TryGet(out TaskEntityOperation op)) + if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op)) { value = (int)op.GetInput(typeof(int))!; } value.HasValue.Should().BeTrue(); - return this.Value += value!.Value; + return this.State += value!.Value; } int Get(Optional context) { - if (context.IsPresent) + if (context.HasValue) { context.Value.Should().NotBeNull(); } - return this.Value; + return this.State; } } +#pragma warning restore IDE0060 // Remove unused parameter +#pragma warning restore CA1822 // Mark members as static } diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs new file mode 100644 index 000000000..010cc70ec --- /dev/null +++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TestEntityContext : TaskEntityContext +{ + public TestEntityContext(object? state) + { + this.State = state; + } + + public object? State { get; private set; } + + public override EntityInstanceId Id { get; } + + public override object? GetState(Type type) + { + return this.State switch + { + null => null, + _ when type.IsAssignableFrom(this.State.GetType()) => this.State, + _ => throw new InvalidCastException() + }; + } + + public override void SetState(object? state) + { + this.State = state; + } + + public override void SignalEntity( + EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) + { + throw new NotImplementedException(); + } + + public override void StartOrchestration( + TaskName name, object? input = null, StartOrchestrationOptions? options = null) + { + throw new NotImplementedException(); + } +} diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs new file mode 100644 index 000000000..63c277a4f --- /dev/null +++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DotNext; + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TestEntityOperation : TaskEntityOperation +{ + readonly Optional input; + + public TestEntityOperation(string name, Optional input) + : this(name, new TestEntityContext(null), input) + { + } + + public TestEntityOperation(string name, object? state, Optional input) + : this(name, new TestEntityContext(state), input) + { + } + + public TestEntityOperation(string name, TaskEntityContext context, Optional input) + { + this.Name = name; + this.Context = context; + this.input = input; + } + + public override string Name { get; } + + public override TaskEntityContext Context { get; } + + public override bool HasInput => this.input.HasValue; + + public override object? GetInput(Type inputType) + { + if (this.input.IsUndefined) + { + throw new InvalidOperationException("No input available."); + } + + if (this.input.IsNull) + { + return null; + } + + if (!inputType.IsAssignableFrom(this.input.Value!.GetType())) + { + throw new InvalidCastException("Cannot convert input type."); + } + + return this.input.Value; + } +} diff --git a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs new file mode 100644 index 000000000..3227a5ae4 --- /dev/null +++ b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; +using DotNext; + +namespace Microsoft.DurableTask.Entities.Tests; + +public class StateTaskEntityTests +{ + [Fact] + public async Task Precedence_ChoosesEntity() + { + TestEntityOperation operation = new("Precedence", default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().Be(20); + } + + [Fact] + public async Task StateDispatchDisallowed_Throws() + { + TestEntityOperation operation = new("add0", 10); + TestEntity entity = new(false); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task StateDispatch_NullState_Throws() + { + TestEntityOperation operation = new("add0", 10); + NullStateEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [InlineData("doesNotExist")] // method does not exist. + [InlineData("add")] // private method, should not work. + [InlineData("staticMethod")] // public static methods are not supported. + public async Task OperationNotSupported_Fails(string name) + { + TestEntityOperation operation = new(name, 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [CombinatorialData] + public async Task TaskOperation_Success( + [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync) + { + object? expected = name.Contains("OfString") ? "success" : null; + TestEntityOperation operation = new(name, sync); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().Be(expected); + } + + [Theory] + [CombinatorialData] + public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase) + { + int start = Random.Shared.Next(0, 10); + int toAdd = Random.Shared.Next(1, 10); + string opName = lowercase ? "add" : "Add"; + TestEntityContext context = new(State(start)); + TestEntityOperation operation = new($"{opName}{method}", context, toAdd); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + int expected = start + toAdd; + context.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(expected); + result.Should().BeOfType().Which.Should().Be(expected); + } + + [Theory] + [CombinatorialData] + public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowercase) + { + int expected = Random.Shared.Next(0, 10); + string opName = lowercase ? "get" : "Get"; + TestEntityContext context = new(State(expected)); + TestEntityOperation operation = new($"{opName}{method}", context, default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + context.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(expected); + result.Should().BeOfType().Which.Should().Be(expected); + } + + [Fact] + public async Task Add_NoInput_Fails() + { + TestEntityOperation operation = new("add0", new TestEntityContext(null), default); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [CombinatorialData] + public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method) + { + TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task Dispatch_AmbiguousMatch_Fails() + { + TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task DefaultValue_NoInput_Succeeds() + { + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be("default"); + } + + [Fact] + public async Task DefaultValue_Input_Succeeds() + { + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default"); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be("not-default"); + } + + static TestState State(int value) => new() { Value = value }; + + class NullStateEntity : TestEntity + { + protected override TestState InitializeState() => null!; + } + + class TestEntity : TaskEntity + { + readonly bool allowStateDispatch; + + public TestEntity(bool allowStateDispatch = true) + { + this.allowStateDispatch = allowStateDispatch; + } + + protected override bool AllowStateDispatch => this.allowStateDispatch; + + public int Precedence() => this.State!.Precedence() * 2; + } + +#pragma warning disable CA1822 // Mark members as static +#pragma warning disable IDE0060 // Remove unused parameter + class TestState + { + public int Value { get; set; } + + public static string StaticMethod() => throw new NotImplementedException(); + + public int Precedence() => 10; + + // All possible permutations of the 3 inputs we support: object, context, operation + // 14 via Add, 2 via Get: 16 total. + public int Add0(int value) => this.Add(value, default, default); + + public int Add1(int value, TaskEntityContext context) => this.Add(value, context, default); + + public int Add2(int value, TaskEntityOperation operation) => this.Add(value, default, operation); + + public int Add3(int value, TaskEntityContext context, TaskEntityOperation operation) + => this.Add(value, context, operation); + + public int Add4(int value, TaskEntityOperation operation, TaskEntityContext context) + => this.Add(value, context, operation); + + public int Add5(TaskEntityOperation operation) => this.Add(default, default, operation); + + public int Add6(TaskEntityOperation operation, int value) => this.Add(value, default, operation); + + public int Add7(TaskEntityOperation operation, TaskEntityContext context) + => this.Add(default, context, operation); + + public int Add8(TaskEntityOperation operation, int value, TaskEntityContext context) + => this.Add(value, context, operation); + + public int Add9(TaskEntityOperation operation, TaskEntityContext context, int value) + => this.Add(value, context, operation); + + public int Add10(TaskEntityContext context, int value) + => this.Add(value, context, default); + + public int Add11(TaskEntityContext context, TaskEntityOperation operation) + => this.Add(default, context, operation); + + public int Add12(TaskEntityContext context, int value, TaskEntityOperation operation) + => this.Add(value, context, operation); + + public int Add13(TaskEntityContext context, TaskEntityOperation operation, int value) + => this.Add(value, context, operation); + + public int Get0() => this.Get(default); + + public int Get1(TaskEntityContext context) => this.Get(context); + + public int AmbiguousMatch(TaskEntityContext context) => this.Value; + + public int AmbiguousMatch(TaskEntityOperation operation) => this.Value; + + public int AmbiguousArgs0(int value, object other) => this.Add0(value); + + public int AmbiguousArgs1(int value, TaskEntityContext context, TaskEntityContext context2) => this.Add0(value); + + public int AmbiguousArgs2(int value, TaskEntityOperation operation, TaskEntityOperation operation2) + => this.Add0(value); + + public string DefaultValue(string toReturn = "default") => toReturn; + + public Task TaskOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + } + + return sync ? Task.CompletedTask : Slow(); + } + + public Task TaskOfStringOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + return "success"; + } + + return sync ? Task.FromResult("success") : Slow(); + } + + public ValueTask ValueTaskOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + } + + return sync ? default : new(Slow()); + } + + public ValueTask ValueTaskOfStringOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + return "success"; + } + + return sync ? new("success") : new(Slow()); + } + + int Add(int? value, Optional context, Optional operation) + { + if (context.HasValue) + { + context.Value.Should().NotBeNull(); + } + + if (operation.HasValue) + { + operation.Value.Should().NotBeNull(); + } + + if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op)) + { + value = (int)op.GetInput(typeof(int))!; + } + + value.HasValue.Should().BeTrue(); + return this.Value += value!.Value; + } + + int Get(Optional context) + { + if (context.HasValue) + { + context.Value.Should().NotBeNull(); + } + + return this.Value; + } + } +#pragma warning restore IDE0060 // Remove unused parameter +#pragma warning restore CA1822 // Mark members as static +} diff --git a/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs b/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs new file mode 100644 index 000000000..2521ddb39 --- /dev/null +++ b/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs @@ -0,0 +1,238 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TaskEntityHelpersTests +{ + [Fact] + public async Task UnwrapAsync_Void() + { + int state = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + object? result = await TaskEntityHelpers.UnwrapAsync(context, () => state, null, typeof(void)); + + result.Should().BeNull(); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Fact] + public async Task UnwrapAsync_Object() + { + int state = Random.Shared.Next(1, 10); + int value = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + object? result = await TaskEntityHelpers.UnwrapAsync(context, () => state, value, typeof(int)); + + result.Should().BeOfType().Which.Should().Be(value); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_Task(bool async) + { + TaskCompletionSource tcs = new(); + int state = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync(context, () => state, tcs.Task, typeof(Task)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(); + } + + object? result = await task; + + result.Should().BeNull(); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_Task_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync(context, () => 0, tcs.Task, typeof(Task)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_TaskOfInt(bool async) + { + TaskCompletionSource tcs = new(); + + int state = Random.Shared.Next(1, 10); + int value = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(value); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync(context, () => state, tcs.Task, typeof(Task)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(value); + } + + object? result = await task; + + result.Should().BeOfType().Which.Should().Be(value); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_TaskOfInt_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync( + context, () => 0, tcs.Task, typeof(Task)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } + + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTask(bool async) + { + TaskCompletionSource tcs = new(); + + int state = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync( + context, () => state, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(); + } + + object? result = await task; + result.Should().BeNull(); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTask_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync( + context, () => 0, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } + + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTaskOfInt(bool async) + { + TaskCompletionSource tcs = new(); + int state = Random.Shared.Next(1, 10); + int value = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(value); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync( + context, () => state, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(value); + } + + object? result = await task; + + result.Should().BeOfType().Which.Should().Be(value); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTaskOfInt_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync( + context, () => 0, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } +} diff --git a/test/Directory.Build.targets b/test/Directory.Build.targets index 294f47da1..c0558551f 100644 --- a/test/Directory.Build.targets +++ b/test/Directory.Build.targets @@ -4,7 +4,7 @@ Condition=" '$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory)../, $(_DirectoryBuildTargetsFile)))' != '' " /> - + runtime; build; native; contentfiles; analyzers; buildtransitive