diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs
index ec40154c9..837df7ab3 100644
--- a/src/Abstractions/Entities/TaskEntity.cs
+++ b/src/Abstractions/Entities/TaskEntity.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using System.Reflection;
-using System.Threading.Tasks;
namespace Microsoft.DurableTask.Entities;
@@ -12,9 +11,7 @@ namespace Microsoft.DurableTask.Entities;
///
/// Entity State
///
-/// All entity implementations are required to be serializable by the configured . An entity
-/// will have its state deserialized before executing an operation, and then the new state will be the serialized value
-/// of the implementation instance post-operation.
+/// The state of an entity can be retrieved and updated via .
///
///
public interface ITaskEntity
@@ -30,6 +27,7 @@ public interface ITaskEntity
///
/// An which dispatches its operations to public instance methods or properties.
///
+/// The state type held by this entity.
///
/// Method Binding
///
@@ -71,176 +69,99 @@ public interface ITaskEntity
///
/// Entity State
///
-/// Unchanged from . Entity state is the serialized value of the entity after an operation
-/// completes.
+/// Entity state will be hydrated into the property. See for
+/// more information.
///
///
-public abstract class TaskEntity : ITaskEntity
+public abstract class TaskEntity : ITaskEntity
{
- /**
- * TODO:
- * 1. Consider caching a compiled delegate for a given operation name.
- */
- static readonly BindingFlags InstanceBindingFlags
- = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase;
-
- ///
- public ValueTask RunAsync(TaskEntityOperation operation)
- {
- Check.NotNull(operation);
- if (!this.TryDispatchMethod(operation, out object? result, out Type returnType))
- {
- throw new NotSupportedException($"No suitable method found for entity operation '{operation}'.");
- }
-
- if (typeof(Task).IsAssignableFrom(returnType))
- {
- // Task or Task
- return new(AsGeneric((Task)result!, returnType)); // we assume a declared Task return type is never null.
- }
-
- if (returnType == typeof(ValueTask))
- {
- // ValueTask
- return AsGeneric((ValueTask)result!); // we assume a declared ValueTask return type is never null.
- }
-
- if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
- {
- // ValueTask
- return AsGeneric(result!, returnType); // No inheritance, have to do purely via reflection.
- }
-
- return new(result);
- }
+ ///
+ /// Gets a value indicating whether dispatching operations to is allowed. State dispatch
+ /// will only be attempted if entity-level dispatch does not succeed. Default is false . Dispatching to state
+ /// follows the same rules as dispatching to this entity.
+ ///
+ protected virtual bool AllowStateDispatch => false;
- static bool TryGetInput(ParameterInfo parameter, TaskEntityOperation operation, out object? input)
- {
- if (!operation.HasInput)
- {
- if (parameter.HasDefaultValue)
- {
- input = parameter.DefaultValue;
- return true;
- }
+ ///
+ /// Gets or sets the state for this entity.
+ ///
+ ///
+ /// Initialization
+ ///
+ /// This will be hydrated as part of . will
+ /// be called when state is null at the start of an operation only .
+ ///
+ /// Persistence
+ ///
+ /// The contents of this property will be persisted to at the end
+ /// of the operation.
+ ///
+ /// Deletion
+ ///
+ /// Deleting entity state is possible by setting this to null . Setting to default of a value-type will
+ /// not perform a delete. This means deleting entity state is only possible for reference types or using ?
+ /// on a value-type (ie: TaskEntity<int?> ).
+ ///
+ ///
+ protected TState State { get; set; } = default!;
- input = null;
- return false;
- }
+ ///
+ /// Gets the entity operation.
+ ///
+ protected TaskEntityOperation Operation { get; private set; } = null!;
- input = operation.GetInput(parameter.ParameterType);
- return true;
- }
+ ///
+ /// Gets the entity context.
+ ///
+ protected TaskEntityContext Context => this.Operation.Context;
- static async Task AsGeneric(Task task, Type declared)
+ ///
+ public ValueTask RunAsync(TaskEntityOperation operation)
{
- await task;
- if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>))
+ this.Operation = Check.NotNull(operation);
+ object? state = operation.Context.GetState(typeof(TState));
+ this.State = state is null ? this.InitializeState() : (TState)state;
+ if (!operation.TryDispatch(this, out object? result, out Type returnType)
+ && !this.TryDispatchState(out result, out returnType))
{
- return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task);
+ throw new NotSupportedException($"No suitable method found for entity operation '{operation}'.");
}
- return null;
+ return TaskEntityHelpers.UnwrapAsync(this.Context, () => this.State, result, returnType);
}
- static ValueTask AsGeneric(ValueTask t)
+ ///
+ /// Initializes the entity state. This is only called when there is no current state for this entity.
+ ///
+ /// The entity state.
+ /// The default implementation uses .
+ protected virtual TState InitializeState()
{
- static async Task Await(ValueTask t)
- {
- await t;
- return null;
- }
-
- if (t.IsCompletedSuccessfully)
+ if (Nullable.GetUnderlyingType(typeof(TState)) is Type t)
{
- return default;
+ // Activator.CreateInstance>() returns null. To avoid this, we will instantiate via underlying
+ // type if it is Nullable. This keeps the experience consistent between value and reference type. If an
+ // implementation wants null, they must override this method and explicitly provide null.
+ return (TState)Activator.CreateInstance(t);
}
- return new(Await(t));
+ return Activator.CreateInstance();
}
- static ValueTask AsGeneric(object result, Type type)
+ bool TryDispatchState(out object? result, out Type returnType)
{
- // result and type here must be some form of ValueTask.
- if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result))
- {
- return new(type.GetProperty("Result").GetValue(result));
- }
- else
- {
- Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public)
- .Invoke(result, null);
- return new(t.ToGeneric());
- }
- }
-
- bool TryDispatchMethod(TaskEntityOperation operation, out object? result, out Type returnType)
- {
- Type t = this.GetType();
-
- // Will throw AmbiguousMatchException if more than 1 overload for the method name exists.
- MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags);
- if (method is null)
+ if (!this.AllowStateDispatch)
{
result = null;
returnType = typeof(void);
return false;
}
- ParameterInfo[] parameters = method.GetParameters();
- object?[] inputs = new object[parameters.Length];
-
- int i = 0;
- ParameterInfo? inputResolved = null;
- ParameterInfo? contextResolved = null;
- ParameterInfo? operationResolved = null;
- foreach (ParameterInfo parameter in parameters)
+ if (this.State is null)
{
- if (parameter.ParameterType == typeof(TaskEntityContext))
- {
- ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation);
- inputs[i] = operation.Context;
- contextResolved = parameter;
- }
- else if (parameter.ParameterType == typeof(TaskEntityOperation))
- {
- ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation);
- inputs[i] = operation;
- operationResolved = parameter;
- }
- else
- {
- ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation);
- if (TryGetInput(parameter, operation, out object? input))
- {
- inputs[i] = input;
- inputResolved = parameter;
- }
- else
- {
- throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" +
- $"There was an error binding parameter '{parameter}'. The operation expected an input value, " +
- "but no input was provided by the caller.");
- }
- }
-
- i++;
+ throw new InvalidOperationException("Attempting to dispatch to state, but entity state is null.");
}
- result = method.Invoke(this, inputs);
- returnType = method.ReturnType;
- return true;
-
- static void ThrowIfDuplicateBinding(
- ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation)
- {
- if (existing is not null)
- {
- throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" +
- $"Unable to bind {bindingConcept} to '{parameter}' because it has " +
- $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " +
- $"'{parameter.Member}'.\nEntity operation: {operation}.");
- }
- }
+ return this.Operation.TryDispatch(this.State, out result, out returnType);
}
}
diff --git a/src/Abstractions/Entities/TaskEntityContext.cs b/src/Abstractions/Entities/TaskEntityContext.cs
index 1b5512567..8ce86fb94 100644
--- a/src/Abstractions/Entities/TaskEntityContext.cs
+++ b/src/Abstractions/Entities/TaskEntityContext.cs
@@ -53,11 +53,16 @@ public abstract void StartOrchestration(
TaskName name, object? input = null, StartOrchestrationOptions? options = null);
///
- /// Deletes the state of this entity after the current operation completes.
+ /// Gets the current state for the entity this context is for. This will return null if no state is present,
+ /// regardless if is a value-type or not.
///
- ///
- /// The state deletion only takes effect after the current operation completes. Any state changes made during the
- /// current operation will be ignored in favor of the deletion.
- ///
- public abstract void DeleteState();
+ /// The type to retrieve the state as.
+ /// The entity state.
+ public abstract object? GetState(Type type);
+
+ ///
+ /// Sets the entity state. Setting of null will delete entity state.
+ ///
+ /// The state to set.
+ public abstract void SetState(object? state);
}
diff --git a/src/Abstractions/Entities/TaskEntityHelpers.cs b/src/Abstractions/Entities/TaskEntityHelpers.cs
new file mode 100644
index 000000000..4437d981c
--- /dev/null
+++ b/src/Abstractions/Entities/TaskEntityHelpers.cs
@@ -0,0 +1,100 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Reflection;
+
+namespace Microsoft.DurableTask.Entities;
+
+///
+/// Helpers for task entities.
+///
+static class TaskEntityHelpers
+{
+ ///
+ /// Unwraps a dispatched result for a into a .
+ ///
+ /// The entity context.
+ /// Delegate to resolve new state for the entity.
+ /// The result of the operation.
+ /// The declared type of the result (may be different that actual type).
+ /// A value task which holds the result of the operation and sets state before it completes.
+ public static ValueTask UnwrapAsync(
+ TaskEntityContext context, Func state, object? result, Type resultType)
+ {
+ // NOTE: Func is used for state so that we can lazily resolve it AFTER the operation has ran.
+ Check.NotNull(context);
+ Check.NotNull(resultType);
+
+ if (typeof(Task).IsAssignableFrom(resultType))
+ {
+ // Task or Task
+ // We assume a declared Task return type is never null.
+ return new(UnwrapTask(context, state, (Task)result!, resultType));
+ }
+
+ if (resultType == typeof(ValueTask))
+ {
+ // ValueTask
+ // We assume a declared ValueTask return type is never null.
+ return UnwrapValueTask(context, state, (ValueTask)result!);
+ }
+
+ if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(ValueTask<>))
+ {
+ // ValueTask
+ // No inheritance, have to do purely via reflection.
+ return UnwrapValueTaskOfT(context, state, result!, resultType);
+ }
+
+ context.SetState(state());
+ return new(result);
+ }
+
+ static async Task UnwrapTask(TaskEntityContext context, Func state, Task task, Type declared)
+ {
+ await task;
+ context.SetState(state());
+ if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>))
+ {
+ return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task);
+ }
+
+ return null;
+ }
+
+ static ValueTask UnwrapValueTask(TaskEntityContext context, Func state, ValueTask t)
+ {
+ async Task Await(ValueTask t)
+ {
+ await t;
+ context.SetState(state());
+ return null;
+ }
+
+ if (t.IsCompletedSuccessfully)
+ {
+ context.SetState(state());
+ return default;
+ }
+
+ return new(Await(t));
+ }
+
+ static ValueTask UnwrapValueTaskOfT(
+ TaskEntityContext context, Func state, object result, Type type)
+ {
+ // Result and type here must be some form of ValueTask.
+ // TODO: can this amount of reflection be avoided?
+ if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result))
+ {
+ context.SetState(state());
+ return new(type.GetProperty("Result").GetValue(result));
+ }
+ else
+ {
+ Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public).Invoke(result, null);
+ Type taskType = typeof(Task<>).MakeGenericType(type.GetGenericArguments()[0]);
+ return new(UnwrapTask(context, state, t, taskType));
+ }
+ }
+}
diff --git a/src/Abstractions/Entities/TaskEntityOperationExtensions.cs b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs
new file mode 100644
index 000000000..303032da8
--- /dev/null
+++ b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs
@@ -0,0 +1,118 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Reflection;
+
+namespace Microsoft.DurableTask.Entities;
+
+///
+/// Extensions for .
+///
+public static class TaskEntityOperationExtensions
+{
+ /**
+ * TODO:
+ * 1. Consider caching a compiled delegate for a given operation name.
+ */
+ static readonly BindingFlags InstanceBindingFlags
+ = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase;
+
+ ///
+ /// Try to dispatch this operation via reflection to a method on .
+ ///
+ /// The operation that is being dispatched.
+ /// The target to dispatch to.
+ /// The result of the dispatch.
+ /// The declared return type of the dispatched method.
+ /// True if dispatch successful, false otherwise.
+ internal static bool TryDispatch(
+ this TaskEntityOperation operation, object target, out object? result, out Type returnType)
+ {
+ Check.NotNull(operation);
+ Check.NotNull(target);
+ Type t = target.GetType();
+
+ // Will throw AmbiguousMatchException if more than 1 overload for the method name exists.
+ MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags);
+ if (method is null)
+ {
+ result = null;
+ returnType = typeof(void);
+ return false;
+ }
+
+ ParameterInfo[] parameters = method.GetParameters();
+ object?[] inputs = new object[parameters.Length];
+
+ int i = 0;
+ ParameterInfo? inputResolved = null;
+ ParameterInfo? contextResolved = null;
+ ParameterInfo? operationResolved = null;
+ foreach (ParameterInfo parameter in parameters)
+ {
+ if (parameter.ParameterType == typeof(TaskEntityContext))
+ {
+ ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation);
+ inputs[i] = operation.Context;
+ contextResolved = parameter;
+ }
+ else if (parameter.ParameterType == typeof(TaskEntityOperation))
+ {
+ ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation);
+ inputs[i] = operation;
+ operationResolved = parameter;
+ }
+ else
+ {
+ ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation);
+ if (operation.TryGetInput(parameter, out object? input))
+ {
+ inputs[i] = input;
+ inputResolved = parameter;
+ }
+ else
+ {
+ throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" +
+ $"There was an error binding parameter '{parameter}'. The operation expected an input value, " +
+ "but no input was provided by the caller.");
+ }
+ }
+
+ i++;
+ }
+
+ result = method.Invoke(target, inputs);
+ returnType = method.ReturnType;
+ return true;
+
+ static void ThrowIfDuplicateBinding(
+ ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation)
+ {
+ if (existing is not null)
+ {
+ throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" +
+ $"Unable to bind {bindingConcept} to '{parameter}' because it has " +
+ $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " +
+ $"'{parameter.Member}'.\nEntity operation: {operation}.");
+ }
+ }
+ }
+
+ static bool TryGetInput(this TaskEntityOperation operation, ParameterInfo parameter, out object? input)
+ {
+ if (!operation.HasInput)
+ {
+ if (parameter.HasDefaultValue)
+ {
+ input = parameter.DefaultValue;
+ return true;
+ }
+
+ input = null;
+ return false;
+ }
+
+ input = operation.GetInput(parameter.ParameterType);
+ return true;
+ }
+}
diff --git a/src/Shared/Core/TaskExtensions.cs b/src/Shared/Core/TaskExtensions.cs
index 116c6623a..69c9356fb 100644
--- a/src/Shared/Core/TaskExtensions.cs
+++ b/src/Shared/Core/TaskExtensions.cs
@@ -24,7 +24,7 @@ static class TaskExtensions
Type t = task.GetType();
if (t.IsGenericType)
{
- return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task);
+ return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance)!.GetValue(task)!;
}
return default;
diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs
similarity index 75%
rename from test/Abstractions.Tests/Entities/TaskEntityTests.cs
rename to test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs
index ea5129c83..ecc5541f1 100644
--- a/test/Abstractions.Tests/Entities/TaskEntityTests.cs
+++ b/test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs
@@ -6,7 +6,7 @@
namespace Microsoft.DurableTask.Entities.Tests;
-public class TaskEntityTests
+public class EntityTaskEntityTests
{
[Theory]
[InlineData("doesNotExist")] // method does not exist.
@@ -14,7 +14,7 @@ public class TaskEntityTests
[InlineData("staticMethod")] // public static methods are not supported.
public async Task OperationNotSupported_Fails(string name)
{
- Operation operation = new(name, Mock.Of(), 10);
+ TestEntityOperation operation = new(name, 10);
TestEntity entity = new();
Func> action = () => entity.RunAsync(operation).AsTask();
@@ -28,7 +28,7 @@ public async Task TaskOperation_Success(
[CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync)
{
object? expected = name.Contains("OfString") ? "success" : null;
- Operation operation = new(name, Mock.Of(), sync);
+ TestEntityOperation operation = new(name, sync);
TestEntity entity = new();
object? result = await entity.RunAsync(operation);
@@ -41,15 +41,16 @@ public async Task TaskOperation_Success(
public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase)
{
int start = Random.Shared.Next(0, 10);
- int toAdd = Random.Shared.Next(0, 10);
+ int toAdd = Random.Shared.Next(1, 10);
string opName = lowercase ? "add" : "Add";
- Operation operation = new($"{opName}{method}", Mock.Of(), toAdd);
- TestEntity entity = new() { Value = start };
+ TestEntityContext context = new(start);
+ TestEntityOperation operation = new($"{opName}{method}", context, toAdd);
+ TestEntity entity = new();
object? result = await entity.RunAsync(operation);
int expected = start + toAdd;
- entity.Value.Should().Be(expected);
+ context.GetState(typeof(int)).Should().BeOfType().Which.Should().Be(expected);
result.Should().BeOfType().Which.Should().Be(expected);
}
@@ -59,19 +60,20 @@ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowerc
{
int expected = Random.Shared.Next(0, 10);
string opName = lowercase ? "get" : "Get";
- Operation operation = new($"{opName}{method}", Mock.Of(), default);
- TestEntity entity = new() { Value = expected };
+ TestEntityContext context = new(expected);
+ TestEntityOperation operation = new($"{opName}{method}", context, default);
+ TestEntity entity = new();
object? result = await entity.RunAsync(operation);
- entity.Value.Should().Be(expected);
+ context.GetState(typeof(int)).Should().BeOfType().Which.Should().Be(expected);
result.Should().BeOfType().Which.Should().Be(expected);
}
[Fact]
public async Task Add_NoInput_Fails()
{
- Operation operation = new("add0", Mock.Of(), default);
+ TestEntityOperation operation = new("add0", new TestEntityContext(null), default);
TestEntity entity = new();
Func> action = () => entity.RunAsync(operation).AsTask();
@@ -83,7 +85,7 @@ public async Task Add_NoInput_Fails()
[CombinatorialData]
public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method)
{
- Operation operation = new($"ambiguousArgs{method}", Mock.Of(), 10);
+ TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10);
TestEntity entity = new();
Func> action = () => entity.RunAsync(operation).AsTask();
@@ -94,7 +96,7 @@ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int me
[Fact]
public async Task Dispatch_AmbiguousMatch_Fails()
{
- Operation operation = new("ambiguousMatch", Mock.Of(), 10);
+ TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10);
TestEntity entity = new();
Func> action = () => entity.RunAsync(operation).AsTask();
@@ -104,7 +106,7 @@ public async Task Dispatch_AmbiguousMatch_Fails()
[Fact]
public async Task DefaultValue_NoInput_Succeeds()
{
- Operation operation = new("defaultValue", Mock.Of(), default);
+ TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default);
TestEntity entity = new();
object? result = await entity.RunAsync(operation);
@@ -115,7 +117,7 @@ public async Task DefaultValue_NoInput_Succeeds()
[Fact]
public async Task DefaultValue_Input_Succeeds()
{
- Operation operation = new("defaultValue", Mock.Of(), "not-default");
+ TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default");
TestEntity entity = new();
object? result = await entity.RunAsync(operation);
@@ -123,48 +125,10 @@ public async Task DefaultValue_Input_Succeeds()
result.Should().BeOfType().Which.Should().Be("not-default");
}
- class Operation : TaskEntityOperation
- {
- readonly Optional input;
-
- public Operation(string name, TaskEntityContext context, Optional input)
- {
- this.Name = name;
- this.Context = context;
- this.input = input;
- }
-
- public override string Name { get; }
-
- public override TaskEntityContext Context { get; }
-
- public override bool HasInput => this.input.IsPresent;
-
- public override object? GetInput(Type inputType)
- {
- if (!this.input.IsPresent)
- {
- throw new InvalidOperationException("No input available.");
- }
-
- if (this.input.Value is null)
- {
- return null;
- }
-
- if (!inputType.IsAssignableFrom(this.input.Value.GetType()))
- {
- throw new InvalidCastException("Cannot convert input type.");
- }
-
- return this.input.Value;
- }
- }
-
- class TestEntity : TaskEntity
+#pragma warning disable CA1822 // Mark members as static
+#pragma warning disable IDE0060 // Remove unused parameter
+ class TestEntity : TaskEntity
{
- public int Value { get; set; }
-
public static string StaticMethod() => throw new NotImplementedException();
// All possible permutations of the 3 inputs we support: object, context, operation
@@ -210,9 +174,9 @@ public int Add13(TaskEntityContext context, TaskEntityOperation operation, int v
public int Get1(TaskEntityContext context) => this.Get(context);
- public int AmbiguousMatch(TaskEntityContext context) => this.Value;
+ public int AmbiguousMatch(TaskEntityContext context) => this.State;
- public int AmbiguousMatch(TaskEntityOperation operation) => this.Value;
+ public int AmbiguousMatch(TaskEntityOperation operation) => this.State;
public int AmbiguousArgs0(int value, object other) => this.Add0(value);
@@ -267,33 +231,35 @@ static async Task Slow()
int Add(int? value, Optional context, Optional operation)
{
- if (context.IsPresent)
+ if (context.HasValue)
{
context.Value.Should().NotBeNull();
}
- if (operation.IsPresent)
+ if (operation.HasValue)
{
operation.Value.Should().NotBeNull();
}
- if (!value.HasValue && operation.TryGet(out TaskEntityOperation op))
+ if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op))
{
value = (int)op.GetInput(typeof(int))!;
}
value.HasValue.Should().BeTrue();
- return this.Value += value!.Value;
+ return this.State += value!.Value;
}
int Get(Optional context)
{
- if (context.IsPresent)
+ if (context.HasValue)
{
context.Value.Should().NotBeNull();
}
- return this.Value;
+ return this.State;
}
}
+#pragma warning restore IDE0060 // Remove unused parameter
+#pragma warning restore CA1822 // Mark members as static
}
diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs
new file mode 100644
index 000000000..010cc70ec
--- /dev/null
+++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DurableTask.Entities.Tests;
+
+public class TestEntityContext : TaskEntityContext
+{
+ public TestEntityContext(object? state)
+ {
+ this.State = state;
+ }
+
+ public object? State { get; private set; }
+
+ public override EntityInstanceId Id { get; }
+
+ public override object? GetState(Type type)
+ {
+ return this.State switch
+ {
+ null => null,
+ _ when type.IsAssignableFrom(this.State.GetType()) => this.State,
+ _ => throw new InvalidCastException()
+ };
+ }
+
+ public override void SetState(object? state)
+ {
+ this.State = state;
+ }
+
+ public override void SignalEntity(
+ EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null)
+ {
+ throw new NotImplementedException();
+ }
+
+ public override void StartOrchestration(
+ TaskName name, object? input = null, StartOrchestrationOptions? options = null)
+ {
+ throw new NotImplementedException();
+ }
+}
diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs
new file mode 100644
index 000000000..63c277a4f
--- /dev/null
+++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using DotNext;
+
+namespace Microsoft.DurableTask.Entities.Tests;
+
+public class TestEntityOperation : TaskEntityOperation
+{
+ readonly Optional input;
+
+ public TestEntityOperation(string name, Optional input)
+ : this(name, new TestEntityContext(null), input)
+ {
+ }
+
+ public TestEntityOperation(string name, object? state, Optional input)
+ : this(name, new TestEntityContext(state), input)
+ {
+ }
+
+ public TestEntityOperation(string name, TaskEntityContext context, Optional input)
+ {
+ this.Name = name;
+ this.Context = context;
+ this.input = input;
+ }
+
+ public override string Name { get; }
+
+ public override TaskEntityContext Context { get; }
+
+ public override bool HasInput => this.input.HasValue;
+
+ public override object? GetInput(Type inputType)
+ {
+ if (this.input.IsUndefined)
+ {
+ throw new InvalidOperationException("No input available.");
+ }
+
+ if (this.input.IsNull)
+ {
+ return null;
+ }
+
+ if (!inputType.IsAssignableFrom(this.input.Value!.GetType()))
+ {
+ throw new InvalidCastException("Cannot convert input type.");
+ }
+
+ return this.input.Value;
+ }
+}
diff --git a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs
new file mode 100644
index 000000000..3227a5ae4
--- /dev/null
+++ b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs
@@ -0,0 +1,323 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Reflection;
+using DotNext;
+
+namespace Microsoft.DurableTask.Entities.Tests;
+
+public class StateTaskEntityTests
+{
+ [Fact]
+ public async Task Precedence_ChoosesEntity()
+ {
+ TestEntityOperation operation = new("Precedence", default);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().Be(20);
+ }
+
+ [Fact]
+ public async Task StateDispatchDisallowed_Throws()
+ {
+ TestEntityOperation operation = new("add0", 10);
+ TestEntity entity = new(false);
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Fact]
+ public async Task StateDispatch_NullState_Throws()
+ {
+ TestEntityOperation operation = new("add0", 10);
+ NullStateEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Theory]
+ [InlineData("doesNotExist")] // method does not exist.
+ [InlineData("add")] // private method, should not work.
+ [InlineData("staticMethod")] // public static methods are not supported.
+ public async Task OperationNotSupported_Fails(string name)
+ {
+ TestEntityOperation operation = new(name, 10);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task TaskOperation_Success(
+ [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync)
+ {
+ object? expected = name.Contains("OfString") ? "success" : null;
+ TestEntityOperation operation = new(name, sync);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().Be(expected);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase)
+ {
+ int start = Random.Shared.Next(0, 10);
+ int toAdd = Random.Shared.Next(1, 10);
+ string opName = lowercase ? "add" : "Add";
+ TestEntityContext context = new(State(start));
+ TestEntityOperation operation = new($"{opName}{method}", context, toAdd);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ int expected = start + toAdd;
+ context.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(expected);
+ result.Should().BeOfType().Which.Should().Be(expected);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowercase)
+ {
+ int expected = Random.Shared.Next(0, 10);
+ string opName = lowercase ? "get" : "Get";
+ TestEntityContext context = new(State(expected));
+ TestEntityOperation operation = new($"{opName}{method}", context, default);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ context.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(expected);
+ result.Should().BeOfType().Which.Should().Be(expected);
+ }
+
+ [Fact]
+ public async Task Add_NoInput_Fails()
+ {
+ TestEntityOperation operation = new("add0", new TestEntityContext(null), default);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method)
+ {
+ TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Fact]
+ public async Task Dispatch_AmbiguousMatch_Fails()
+ {
+ TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+ await action.Should().ThrowAsync();
+ }
+
+ [Fact]
+ public async Task DefaultValue_NoInput_Succeeds()
+ {
+ TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().BeOfType().Which.Should().Be("default");
+ }
+
+ [Fact]
+ public async Task DefaultValue_Input_Succeeds()
+ {
+ TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default");
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().BeOfType().Which.Should().Be("not-default");
+ }
+
+ static TestState State(int value) => new() { Value = value };
+
+ class NullStateEntity : TestEntity
+ {
+ protected override TestState InitializeState() => null!;
+ }
+
+ class TestEntity : TaskEntity
+ {
+ readonly bool allowStateDispatch;
+
+ public TestEntity(bool allowStateDispatch = true)
+ {
+ this.allowStateDispatch = allowStateDispatch;
+ }
+
+ protected override bool AllowStateDispatch => this.allowStateDispatch;
+
+ public int Precedence() => this.State!.Precedence() * 2;
+ }
+
+#pragma warning disable CA1822 // Mark members as static
+#pragma warning disable IDE0060 // Remove unused parameter
+ class TestState
+ {
+ public int Value { get; set; }
+
+ public static string StaticMethod() => throw new NotImplementedException();
+
+ public int Precedence() => 10;
+
+ // All possible permutations of the 3 inputs we support: object, context, operation
+ // 14 via Add, 2 via Get: 16 total.
+ public int Add0(int value) => this.Add(value, default, default);
+
+ public int Add1(int value, TaskEntityContext context) => this.Add(value, context, default);
+
+ public int Add2(int value, TaskEntityOperation operation) => this.Add(value, default, operation);
+
+ public int Add3(int value, TaskEntityContext context, TaskEntityOperation operation)
+ => this.Add(value, context, operation);
+
+ public int Add4(int value, TaskEntityOperation operation, TaskEntityContext context)
+ => this.Add(value, context, operation);
+
+ public int Add5(TaskEntityOperation operation) => this.Add(default, default, operation);
+
+ public int Add6(TaskEntityOperation operation, int value) => this.Add(value, default, operation);
+
+ public int Add7(TaskEntityOperation operation, TaskEntityContext context)
+ => this.Add(default, context, operation);
+
+ public int Add8(TaskEntityOperation operation, int value, TaskEntityContext context)
+ => this.Add(value, context, operation);
+
+ public int Add9(TaskEntityOperation operation, TaskEntityContext context, int value)
+ => this.Add(value, context, operation);
+
+ public int Add10(TaskEntityContext context, int value)
+ => this.Add(value, context, default);
+
+ public int Add11(TaskEntityContext context, TaskEntityOperation operation)
+ => this.Add(default, context, operation);
+
+ public int Add12(TaskEntityContext context, int value, TaskEntityOperation operation)
+ => this.Add(value, context, operation);
+
+ public int Add13(TaskEntityContext context, TaskEntityOperation operation, int value)
+ => this.Add(value, context, operation);
+
+ public int Get0() => this.Get(default);
+
+ public int Get1(TaskEntityContext context) => this.Get(context);
+
+ public int AmbiguousMatch(TaskEntityContext context) => this.Value;
+
+ public int AmbiguousMatch(TaskEntityOperation operation) => this.Value;
+
+ public int AmbiguousArgs0(int value, object other) => this.Add0(value);
+
+ public int AmbiguousArgs1(int value, TaskEntityContext context, TaskEntityContext context2) => this.Add0(value);
+
+ public int AmbiguousArgs2(int value, TaskEntityOperation operation, TaskEntityOperation operation2)
+ => this.Add0(value);
+
+ public string DefaultValue(string toReturn = "default") => toReturn;
+
+ public Task TaskOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ }
+
+ return sync ? Task.CompletedTask : Slow();
+ }
+
+ public Task TaskOfStringOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ return "success";
+ }
+
+ return sync ? Task.FromResult("success") : Slow();
+ }
+
+ public ValueTask ValueTaskOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ }
+
+ return sync ? default : new(Slow());
+ }
+
+ public ValueTask ValueTaskOfStringOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ return "success";
+ }
+
+ return sync ? new("success") : new(Slow());
+ }
+
+ int Add(int? value, Optional context, Optional operation)
+ {
+ if (context.HasValue)
+ {
+ context.Value.Should().NotBeNull();
+ }
+
+ if (operation.HasValue)
+ {
+ operation.Value.Should().NotBeNull();
+ }
+
+ if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op))
+ {
+ value = (int)op.GetInput(typeof(int))!;
+ }
+
+ value.HasValue.Should().BeTrue();
+ return this.Value += value!.Value;
+ }
+
+ int Get(Optional context)
+ {
+ if (context.HasValue)
+ {
+ context.Value.Should().NotBeNull();
+ }
+
+ return this.Value;
+ }
+ }
+#pragma warning restore IDE0060 // Remove unused parameter
+#pragma warning restore CA1822 // Mark members as static
+}
diff --git a/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs b/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs
new file mode 100644
index 000000000..2521ddb39
--- /dev/null
+++ b/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs
@@ -0,0 +1,238 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DurableTask.Entities.Tests;
+
+public class TaskEntityHelpersTests
+{
+ [Fact]
+ public async Task UnwrapAsync_Void()
+ {
+ int state = Random.Shared.Next(1, 10);
+ TestEntityContext context = new(null);
+
+ object? result = await TaskEntityHelpers.UnwrapAsync(context, () => state, null, typeof(void));
+
+ result.Should().BeNull();
+ context.State.Should().BeOfType().Which.Should().Be(state);
+ }
+
+ [Fact]
+ public async Task UnwrapAsync_Object()
+ {
+ int state = Random.Shared.Next(1, 10);
+ int value = Random.Shared.Next(1, 10);
+ TestEntityContext context = new(null);
+
+ object? result = await TaskEntityHelpers.UnwrapAsync(context, () => state, value, typeof(int));
+
+ result.Should().BeOfType().Which.Should().Be(value);
+ context.State.Should().BeOfType().Which.Should().Be(state);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_Task(bool async)
+ {
+ TaskCompletionSource tcs = new();
+ int state = Random.Shared.Next(1, 10);
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.TrySetResult();
+ }
+
+ ValueTask task = TaskEntityHelpers.UnwrapAsync(context, () => state, tcs.Task, typeof(Task));
+
+ if (async)
+ {
+ state++; // Make sure state changes are captured
+ tcs.TrySetResult();
+ }
+
+ object? result = await task;
+
+ result.Should().BeNull();
+ context.State.Should().BeOfType().Which.Should().Be(state);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_Task_Throws(bool async)
+ {
+ TaskCompletionSource tcs = new();
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ Func throws = async () => await TaskEntityHelpers.UnwrapAsync(context, () => 0, tcs.Task, typeof(Task));
+
+ if (async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ await throws.Should().ThrowExactlyAsync();
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_TaskOfInt(bool async)
+ {
+ TaskCompletionSource tcs = new();
+
+ int state = Random.Shared.Next(1, 10);
+ int value = Random.Shared.Next(1, 10);
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.TrySetResult(value);
+ }
+
+ ValueTask task = TaskEntityHelpers.UnwrapAsync(context, () => state, tcs.Task, typeof(Task));
+
+ if (async)
+ {
+ state++; // Make sure state changes are captured
+ tcs.TrySetResult(value);
+ }
+
+ object? result = await task;
+
+ result.Should().BeOfType().Which.Should().Be(value);
+ context.State.Should().BeOfType().Which.Should().Be(state);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_TaskOfInt_Throws(bool async)
+ {
+ TaskCompletionSource tcs = new();
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ Func throws = async () => await TaskEntityHelpers.UnwrapAsync(
+ context, () => 0, tcs.Task, typeof(Task));
+
+ if (async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ await throws.Should().ThrowExactlyAsync();
+ }
+
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_ValueTask(bool async)
+ {
+ TaskCompletionSource tcs = new();
+
+ int state = Random.Shared.Next(1, 10);
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.TrySetResult();
+ }
+
+ ValueTask task = TaskEntityHelpers.UnwrapAsync(
+ context, () => state, new ValueTask(tcs.Task), typeof(ValueTask));
+
+ if (async)
+ {
+ state++; // Make sure state changes are captured
+ tcs.TrySetResult();
+ }
+
+ object? result = await task;
+ result.Should().BeNull();
+ context.State.Should().BeOfType().Which.Should().Be(state);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_ValueTask_Throws(bool async)
+ {
+ TaskCompletionSource tcs = new();
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ Func throws = async () => await TaskEntityHelpers.UnwrapAsync(
+ context, () => 0, new ValueTask(tcs.Task), typeof(ValueTask));
+
+ if (async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ await throws.Should().ThrowExactlyAsync();
+ }
+
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_ValueTaskOfInt(bool async)
+ {
+ TaskCompletionSource tcs = new();
+ int state = Random.Shared.Next(1, 10);
+ int value = Random.Shared.Next(1, 10);
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.TrySetResult(value);
+ }
+
+ ValueTask task = TaskEntityHelpers.UnwrapAsync(
+ context, () => state, new ValueTask(tcs.Task), typeof(ValueTask));
+
+ if (async)
+ {
+ state++; // Make sure state changes are captured
+ tcs.TrySetResult(value);
+ }
+
+ object? result = await task;
+
+ result.Should().BeOfType().Which.Should().Be(value);
+ context.State.Should().BeOfType().Which.Should().Be(state);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task UnwrapAsync_ValueTaskOfInt_Throws(bool async)
+ {
+ TaskCompletionSource tcs = new();
+ TestEntityContext context = new(null);
+
+ if (!async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ Func throws = async () => await TaskEntityHelpers.UnwrapAsync(
+ context, () => 0, new ValueTask(tcs.Task), typeof(ValueTask));
+
+ if (async)
+ {
+ tcs.SetException(new OperationCanceledException());
+ }
+
+ await throws.Should().ThrowExactlyAsync();
+ }
+}
diff --git a/test/Directory.Build.targets b/test/Directory.Build.targets
index 294f47da1..c0558551f 100644
--- a/test/Directory.Build.targets
+++ b/test/Directory.Build.targets
@@ -4,7 +4,7 @@
Condition=" '$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory)../, $(_DirectoryBuildTargetsFile)))' != '' " />
-
+
runtime; build; native; contentfiles; analyzers; buildtransitive