diff --git a/samples/AzureFunctionsApp/AzureFunctionsApp.csproj b/samples/AzureFunctionsApp/AzureFunctionsApp.csproj
index 49dc1988d..c2e463efa 100644
--- a/samples/AzureFunctionsApp/AzureFunctionsApp.csproj
+++ b/samples/AzureFunctionsApp/AzureFunctionsApp.csproj
@@ -8,10 +8,10 @@
-
-
+
+
-
+
diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs
index 29932f141..ec40154c9 100644
--- a/src/Abstractions/Entities/TaskEntity.cs
+++ b/src/Abstractions/Entities/TaskEntity.cs
@@ -1,11 +1,22 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
+using System.Reflection;
+using System.Threading.Tasks;
+
namespace Microsoft.DurableTask.Entities;
///
/// The task entity contract.
///
+///
+/// Entity State
+///
+/// All entity implementations are required to be serializable by the configured . An entity
+/// will have its state deserialized before executing an operation, and then the new state will be the serialized value
+/// of the implementation instance post-operation.
+///
+///
public interface ITaskEntity
{
///
@@ -15,3 +26,221 @@ public interface ITaskEntity
/// The response to the caller, if any.
ValueTask
public abstract TaskEntityContext Context { get; }
+ ///
+ /// Gets a value indicating whether this operation has input or not.
+ ///
+ public abstract bool HasInput { get; }
+
///
/// Gets the input for this operation.
///
/// The type to deserialize the input as.
/// The deserialized input type.
public abstract object? GetInput(Type inputType);
+
+ ///
+ public override string ToString()
+ {
+ return $"{this.Context.Id.Name}/{this.Name}";
+ }
}
diff --git a/src/Shared/Core/TaskExtensions.cs b/src/Shared/Core/TaskExtensions.cs
new file mode 100644
index 000000000..116c6623a
--- /dev/null
+++ b/src/Shared/Core/TaskExtensions.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Reflection;
+using Microsoft.DurableTask;
+
+namespace System;
+
+///
+/// Extensions for and .
+///
+static class TaskExtensions
+{
+ ///
+ /// Converts a to a generic .
+ ///
+ /// The generic type param to convert to.
+ /// The task to convert.
+ /// The converted task.
+ public static async Task ToGeneric(this Task task)
+ {
+ await Check.NotNull(task);
+
+ Type t = task.GetType();
+ if (t.IsGenericType)
+ {
+ return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task);
+ }
+
+ return default;
+ }
+
+ ///
+ /// Converts a to a .
+ ///
+ /// The generic type param to convert to.
+ /// The value task to convert.
+ /// The converted value task.
+ public static async ValueTask ToGeneric(this ValueTask task)
+ {
+ await task;
+ return default;
+ }
+}
diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/TaskEntityTests.cs
new file mode 100644
index 000000000..ea5129c83
--- /dev/null
+++ b/test/Abstractions.Tests/Entities/TaskEntityTests.cs
@@ -0,0 +1,299 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Reflection;
+using DotNext;
+
+namespace Microsoft.DurableTask.Entities.Tests;
+
+public class TaskEntityTests
+{
+ [Theory]
+ [InlineData("doesNotExist")] // method does not exist.
+ [InlineData("add")] // private method, should not work.
+ [InlineData("staticMethod")] // public static methods are not supported.
+ public async Task OperationNotSupported_Fails(string name)
+ {
+ Operation operation = new(name, Mock.Of(), 10);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task TaskOperation_Success(
+ [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync)
+ {
+ object? expected = name.Contains("OfString") ? "success" : null;
+ Operation operation = new(name, Mock.Of(), sync);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().Be(expected);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase)
+ {
+ int start = Random.Shared.Next(0, 10);
+ int toAdd = Random.Shared.Next(0, 10);
+ string opName = lowercase ? "add" : "Add";
+ Operation operation = new($"{opName}{method}", Mock.Of(), toAdd);
+ TestEntity entity = new() { Value = start };
+
+ object? result = await entity.RunAsync(operation);
+
+ int expected = start + toAdd;
+ entity.Value.Should().Be(expected);
+ result.Should().BeOfType().Which.Should().Be(expected);
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowercase)
+ {
+ int expected = Random.Shared.Next(0, 10);
+ string opName = lowercase ? "get" : "Get";
+ Operation operation = new($"{opName}{method}", Mock.Of(), default);
+ TestEntity entity = new() { Value = expected };
+
+ object? result = await entity.RunAsync(operation);
+
+ entity.Value.Should().Be(expected);
+ result.Should().BeOfType().Which.Should().Be(expected);
+ }
+
+ [Fact]
+ public async Task Add_NoInput_Fails()
+ {
+ Operation operation = new("add0", Mock.Of(), default);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Theory]
+ [CombinatorialData]
+ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method)
+ {
+ Operation operation = new($"ambiguousArgs{method}", Mock.Of(), 10);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+
+ await action.Should().ThrowAsync();
+ }
+
+ [Fact]
+ public async Task Dispatch_AmbiguousMatch_Fails()
+ {
+ Operation operation = new("ambiguousMatch", Mock.Of(), 10);
+ TestEntity entity = new();
+
+ Func> action = () => entity.RunAsync(operation).AsTask();
+ await action.Should().ThrowAsync();
+ }
+
+ [Fact]
+ public async Task DefaultValue_NoInput_Succeeds()
+ {
+ Operation operation = new("defaultValue", Mock.Of(), default);
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().BeOfType().Which.Should().Be("default");
+ }
+
+ [Fact]
+ public async Task DefaultValue_Input_Succeeds()
+ {
+ Operation operation = new("defaultValue", Mock.Of(), "not-default");
+ TestEntity entity = new();
+
+ object? result = await entity.RunAsync(operation);
+
+ result.Should().BeOfType().Which.Should().Be("not-default");
+ }
+
+ class Operation : TaskEntityOperation
+ {
+ readonly Optional input;
+
+ public Operation(string name, TaskEntityContext context, Optional input)
+ {
+ this.Name = name;
+ this.Context = context;
+ this.input = input;
+ }
+
+ public override string Name { get; }
+
+ public override TaskEntityContext Context { get; }
+
+ public override bool HasInput => this.input.IsPresent;
+
+ public override object? GetInput(Type inputType)
+ {
+ if (!this.input.IsPresent)
+ {
+ throw new InvalidOperationException("No input available.");
+ }
+
+ if (this.input.Value is null)
+ {
+ return null;
+ }
+
+ if (!inputType.IsAssignableFrom(this.input.Value.GetType()))
+ {
+ throw new InvalidCastException("Cannot convert input type.");
+ }
+
+ return this.input.Value;
+ }
+ }
+
+ class TestEntity : TaskEntity
+ {
+ public int Value { get; set; }
+
+ public static string StaticMethod() => throw new NotImplementedException();
+
+ // All possible permutations of the 3 inputs we support: object, context, operation
+ // 14 via Add, 2 via Get: 16 total.
+ public int Add0(int value) => this.Add(value, default, default);
+
+ public int Add1(int value, TaskEntityContext context) => this.Add(value, context, default);
+
+ public int Add2(int value, TaskEntityOperation operation) => this.Add(value, default, operation);
+
+ public int Add3(int value, TaskEntityContext context, TaskEntityOperation operation)
+ => this.Add(value, context, operation);
+
+ public int Add4(int value, TaskEntityOperation operation, TaskEntityContext context)
+ => this.Add(value, context, operation);
+
+ public int Add5(TaskEntityOperation operation) => this.Add(default, default, operation);
+
+ public int Add6(TaskEntityOperation operation, int value) => this.Add(value, default, operation);
+
+ public int Add7(TaskEntityOperation operation, TaskEntityContext context)
+ => this.Add(default, context, operation);
+
+ public int Add8(TaskEntityOperation operation, int value, TaskEntityContext context)
+ => this.Add(value, context, operation);
+
+ public int Add9(TaskEntityOperation operation, TaskEntityContext context, int value)
+ => this.Add(value, context, operation);
+
+ public int Add10(TaskEntityContext context, int value)
+ => this.Add(value, context, default);
+
+ public int Add11(TaskEntityContext context, TaskEntityOperation operation)
+ => this.Add(default, context, operation);
+
+ public int Add12(TaskEntityContext context, int value, TaskEntityOperation operation)
+ => this.Add(value, context, operation);
+
+ public int Add13(TaskEntityContext context, TaskEntityOperation operation, int value)
+ => this.Add(value, context, operation);
+
+ public int Get0() => this.Get(default);
+
+ public int Get1(TaskEntityContext context) => this.Get(context);
+
+ public int AmbiguousMatch(TaskEntityContext context) => this.Value;
+
+ public int AmbiguousMatch(TaskEntityOperation operation) => this.Value;
+
+ public int AmbiguousArgs0(int value, object other) => this.Add0(value);
+
+ public int AmbiguousArgs1(int value, TaskEntityContext context, TaskEntityContext context2) => this.Add0(value);
+
+ public int AmbiguousArgs2(int value, TaskEntityOperation operation, TaskEntityOperation operation2)
+ => this.Add0(value);
+
+ public string DefaultValue(string toReturn = "default") => toReturn;
+
+ public Task TaskOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ }
+
+ return sync ? Task.CompletedTask : Slow();
+ }
+
+ public Task TaskOfStringOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ return "success";
+ }
+
+ return sync ? Task.FromResult("success") : Slow();
+ }
+
+ public ValueTask ValueTaskOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ }
+
+ return sync ? default : new(Slow());
+ }
+
+ public ValueTask ValueTaskOfStringOp(bool sync)
+ {
+ static async Task Slow()
+ {
+ await Task.Yield();
+ return "success";
+ }
+
+ return sync ? new("success") : new(Slow());
+ }
+
+ int Add(int? value, Optional context, Optional operation)
+ {
+ if (context.IsPresent)
+ {
+ context.Value.Should().NotBeNull();
+ }
+
+ if (operation.IsPresent)
+ {
+ operation.Value.Should().NotBeNull();
+ }
+
+ if (!value.HasValue && operation.TryGet(out TaskEntityOperation op))
+ {
+ value = (int)op.GetInput(typeof(int))!;
+ }
+
+ value.HasValue.Should().BeTrue();
+ return this.Value += value!.Value;
+ }
+
+ int Get(Optional context)
+ {
+ if (context.IsPresent)
+ {
+ context.Value.Should().NotBeNull();
+ }
+
+ return this.Value;
+ }
+ }
+}
diff --git a/test/Directory.Build.targets b/test/Directory.Build.targets
index d4e59594a..294f47da1 100644
--- a/test/Directory.Build.targets
+++ b/test/Directory.Build.targets
@@ -4,6 +4,7 @@
Condition=" '$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory)../, $(_DirectoryBuildTargetsFile)))' != '' " />
+
runtime; build; native; contentfiles; analyzers; buildtransitive
@@ -13,6 +14,7 @@
+
runtime; build; native; contentfiles; analyzers; buildtransitive
all