diff --git a/eng/proto b/eng/proto index 7d6826889..b9f0e97b1 160000 --- a/eng/proto +++ b/eng/proto @@ -1 +1 @@ -Subproject commit 7d6826889eb9b104592ab1020c648517a155ba79 +Subproject commit b9f0e97b1db298f5f3ae777a7b6092b6b1f86807 diff --git a/src/Abstractions/DurableTaskRegistry.cs b/src/Abstractions/DurableTaskRegistry.cs index 28422dfab..eefd96680 100644 --- a/src/Abstractions/DurableTaskRegistry.cs +++ b/src/Abstractions/DurableTaskRegistry.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using Microsoft.DurableTask.Entities; + namespace Microsoft.DurableTask; /// @@ -22,6 +24,12 @@ public sealed partial class DurableTaskRegistry internal IDictionary> Orchestrators { get; } = new Dictionary>(); + /// + /// Gets the currently registered entities. + /// + internal IDictionary> Entities { get; } + = new Dictionary>(); + /// /// Registers an activity factory. /// @@ -76,4 +84,31 @@ public DurableTaskRegistry AddOrchestrator(TaskName name, Func factory()); return this; } + + /// + /// Registers an entity factory. + /// + /// The name of the entity. + /// The entity factory. + /// This registry instance, for call chaining. + /// + /// Thrown if any of the following are true: + /// + /// If is default. + /// If is already registered. + /// If is null. + /// + /// + public DurableTaskRegistry AddEntity(TaskName name, Func factory) + { + Check.NotDefault(name); + Check.NotNull(factory); + if (this.Entities.ContainsKey(name)) + { + throw new ArgumentException($"An {nameof(ITaskEntity)} named '{name}' is already added.", nameof(name)); + } + + this.Entities.Add(name, factory); + return this; + } } diff --git a/src/Shared/Grpc/ProtoUtils.cs b/src/Shared/Grpc/ProtoUtils.cs index 40a604778..63c0e1307 100644 --- a/src/Shared/Grpc/ProtoUtils.cs +++ b/src/Shared/Grpc/ProtoUtils.cs @@ -7,6 +7,7 @@ using System.Text; using DurableTask.Core; using DurableTask.Core.Command; +using DurableTask.Core.Entities.OperationFormat; using DurableTask.Core.History; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; @@ -385,6 +386,251 @@ internal static OrchestrationStatus ToCore(this P.OrchestrationStatus status) }; } + /// + /// Converts a to a . + /// + /// The entity batch request to convert. + /// The converted entity batch request. + [return: NotNullIfNotNull("entityBatchRequest")] + internal static EntityBatchRequest? ToEntityBatchRequest(this P.EntityBatchRequest? entityBatchRequest) + { + if (entityBatchRequest == null) + { + return null; + } + + return new EntityBatchRequest() + { + EntityState = entityBatchRequest.EntityState, + InstanceId = entityBatchRequest.InstanceId, + Operations = entityBatchRequest.Operations.Select(r => r.ToOperationRequest()).ToList(), + }; + } + + /// + /// Converts a to a . + /// + /// The operation request to convert. + /// The converted operation request. + [return: NotNullIfNotNull("operationRequest")] + internal static OperationRequest? ToOperationRequest(this P.OperationRequest? operationRequest) + { + if (operationRequest == null) + { + return null; + } + + return new OperationRequest() + { + Operation = operationRequest.Operation, + Input = operationRequest.Input, + Id = Guid.Parse(operationRequest.RequestId), + }; + } + + /// + /// Converts a to a . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull("operationResult")] + internal static OperationResult? ToOperationResult(this P.OperationResult? operationResult) + { + if (operationResult == null) + { + return null; + } + + switch (operationResult.ResultTypeCase) + { + case P.OperationResult.ResultTypeOneofCase.Success: + return new OperationResult() + { + Result = operationResult.Success.Result, + }; + + case P.OperationResult.ResultTypeOneofCase.Failure: + return new OperationResult() + { + FailureDetails = operationResult.Failure.FailureDetails.ToCore(), + }; + + default: + throw new NotSupportedException($"Deserialization of {operationResult.ResultTypeCase} is not supported."); + } + } + + /// + /// Converts a to . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull("operationResult")] + internal static P.OperationResult? ToOperationResult(this OperationResult? operationResult) + { + if (operationResult == null) + { + return null; + } + + if (operationResult.FailureDetails == null) + { + return new P.OperationResult() + { + Success = new P.OperationResultSuccess() + { + Result = operationResult.Result, + }, + }; + } + else + { + return new P.OperationResult() + { + Failure = new P.OperationResultFailure() + { + FailureDetails = ToProtobuf(operationResult.FailureDetails), + }, + }; + } + } + + /// + /// Converts a to a . + /// + /// The operation action to convert. + /// The converted operation action. + [return: NotNullIfNotNull("operationAction")] + internal static OperationAction? ToOperationAction(this P.OperationAction? operationAction) + { + if (operationAction == null) + { + return null; + } + + switch (operationAction.OperationActionTypeCase) + { + case P.OperationAction.OperationActionTypeOneofCase.SendSignal: + + return new SendSignalOperationAction() + { + Name = operationAction.SendSignal.Name, + Input = operationAction.SendSignal.Input, + InstanceId = operationAction.SendSignal.InstanceId, + ScheduledTime = operationAction.SendSignal.ScheduledTime?.ToDateTime(), + }; + + case P.OperationAction.OperationActionTypeOneofCase.StartNewOrchestration: + + return new StartNewOrchestrationOperationAction() + { + Name = operationAction.StartNewOrchestration.Name, + Input = operationAction.StartNewOrchestration.Input, + InstanceId = operationAction.StartNewOrchestration.InstanceId, + Version = operationAction.StartNewOrchestration.Version, + }; + default: + throw new NotSupportedException($"Deserialization of {operationAction.OperationActionTypeCase} is not supported."); + } + } + + /// + /// Converts a to . + /// + /// The operation action to convert. + /// The converted operation action. + [return: NotNullIfNotNull("operationAction")] + internal static P.OperationAction? ToOperationAction(this OperationAction? operationAction) + { + if (operationAction == null) + { + return null; + } + + var action = new P.OperationAction(); + + switch (operationAction) + { + case SendSignalOperationAction sendSignalAction: + + action.SendSignal = new P.SendSignalAction() + { + Name = sendSignalAction.Name, + Input = sendSignalAction.Input, + InstanceId = sendSignalAction.InstanceId, + ScheduledTime = sendSignalAction.ScheduledTime?.ToTimestamp(), + }; + break; + + case StartNewOrchestrationOperationAction startNewOrchestrationAction: + + action.StartNewOrchestration = new P.StartNewOrchestrationAction() + { + Name = startNewOrchestrationAction.Name, + Input = startNewOrchestrationAction.Input, + Version = startNewOrchestrationAction.Version, + InstanceId = startNewOrchestrationAction.InstanceId, + }; + break; + } + + return action; + } + + /// + /// Converts a to a . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull("entityBatchResult")] + internal static EntityBatchResult? ToEntityBatchResult(this P.EntityBatchResult? entityBatchResult) + { + if (entityBatchResult == null) + { + return null; + } + + return new EntityBatchResult() + { + Actions = entityBatchResult.Actions.Select(operationAction => operationAction!.ToOperationAction()).ToList(), + EntityState = entityBatchResult.EntityState, + Results = entityBatchResult.Results.Select(operationResult => operationResult!.ToOperationResult()).ToList(), + FailureDetails = entityBatchResult.FailureDetails.ToCore(), + }; + } + + /// + /// Converts a to . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull("entityBatchResult")] + internal static P.EntityBatchResult? ToEntityBatchResult(this EntityBatchResult? entityBatchResult) + { + if (entityBatchResult == null) + { + return null; + } + + var batchResult = new P.EntityBatchResult() + { + EntityState = entityBatchResult.EntityState, + FailureDetails = entityBatchResult.FailureDetails.ToProtobuf(), + }; + + foreach (OperationAction action in entityBatchResult.Actions!) + { + batchResult.Actions.Add(action.ToOperationAction()); + } + + foreach (OperationResult result in entityBatchResult.Results!) + { + batchResult.Results.Add(result.ToOperationResult()); + } + + return batchResult; + } + /// /// Gets the approximate byte count for a . /// diff --git a/src/Worker/Core/DurableTaskFactory.cs b/src/Worker/Core/DurableTaskFactory.cs index 8d0fc368a..0e77a584a 100644 --- a/src/Worker/Core/DurableTaskFactory.cs +++ b/src/Worker/Core/DurableTaskFactory.cs @@ -2,28 +2,33 @@ // Licensed under the MIT License. using System.Diagnostics.CodeAnalysis; +using Microsoft.DurableTask.Entities; namespace Microsoft.DurableTask.Worker; /// /// A factory for creating orchestrators and activities. /// -sealed class DurableTaskFactory : IDurableTaskFactory +sealed class DurableTaskFactory : IDurableTaskFactory2 { readonly IDictionary> activities; readonly IDictionary> orchestrators; + readonly IDictionary> entities; /// /// Initializes a new instance of the class. /// /// The activity factories. /// The orchestrator factories. + /// The entity factories. internal DurableTaskFactory( IDictionary> activities, - IDictionary> orchestrators) + IDictionary> orchestrators, + IDictionary> entities) { this.activities = Check.NotNull(activities); this.orchestrators = Check.NotNull(orchestrators); + this.entities = Check.NotNull(entities); } /// @@ -54,4 +59,18 @@ public bool TryCreateOrchestrator( orchestrator = null; return false; } + + /// + public bool TryCreateEntity( + TaskName name, IServiceProvider serviceProvider, [NotNullWhen(true)] out ITaskEntity? entity) + { + if (this.entities.TryGetValue(name, out Func? factory)) + { + entity = factory.Invoke(serviceProvider); + return true; + } + + entity = null; + return false; + } } diff --git a/src/Worker/Core/DurableTaskRegistryExtensions.cs b/src/Worker/Core/DurableTaskRegistryExtensions.cs index f46f9e71e..f89288315 100644 --- a/src/Worker/Core/DurableTaskRegistryExtensions.cs +++ b/src/Worker/Core/DurableTaskRegistryExtensions.cs @@ -16,6 +16,6 @@ static class DurableTaskRegistryExtensions public static IDurableTaskFactory BuildFactory(this DurableTaskRegistry registry) { Check.NotNull(registry); - return new DurableTaskFactory(registry.Activities, registry.Orchestrators); + return new DurableTaskFactory(registry.Activities, registry.Orchestrators, registry.Entities); } } diff --git a/src/Worker/Core/IDurableTaskFactory.cs b/src/Worker/Core/IDurableTaskFactory.cs index 56257beb9..43d96f4da 100644 --- a/src/Worker/Core/IDurableTaskFactory.cs +++ b/src/Worker/Core/IDurableTaskFactory.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Diagnostics.CodeAnalysis; +using Microsoft.DurableTask.Entities; namespace Microsoft.DurableTask.Worker; @@ -35,3 +36,24 @@ bool TryCreateActivity( bool TryCreateOrchestrator( TaskName name, IServiceProvider serviceProvider, [NotNullWhen(true)] out ITaskOrchestrator? orchestrator); } + +/// +/// A newer version of that adds support for entities. +/// +public interface IDurableTaskFactory2 : IDurableTaskFactory +{ + /// + /// Tries to create an entity given a name. + /// + /// The name of the orchestrator. + /// The service provider. + /// The entity or null if it does not exist. + /// True if entity was created, false otherwise. + /// + /// While is provided here, it is not required to be used to construct + /// orchestrators. Individual implementations of this contract may use it in different ways. The default + /// implementation does not use it. + /// + bool TryCreateEntity( + TaskName name, IServiceProvider serviceProvider, [NotNullWhen(true)] out ITaskEntity? entity); +} diff --git a/src/Worker/Core/Shims/DurableTaskShimFactory.cs b/src/Worker/Core/Shims/DurableTaskShimFactory.cs index 8ed8e6b1f..26e33cb73 100644 --- a/src/Worker/Core/Shims/DurableTaskShimFactory.cs +++ b/src/Worker/Core/Shims/DurableTaskShimFactory.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using DurableTask.Core; +using DurableTask.Core.Entities; +using Microsoft.DurableTask.Entities; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -107,4 +109,24 @@ public TaskOrchestration CreateOrchestration( Check.NotNull(implementation); return this.CreateOrchestration(name, FuncTaskOrchestrator.Create(implementation), parent); } + + /// + /// Creates a from a . + /// + /// + /// The name of the entity. This should be the name the entity was invoked with. + /// + /// The entity to wrap. + /// The entity id for the shim. + /// A new . + public TaskEntity CreateEntity(TaskName name, ITaskEntity entity, EntityId entityId) + { + Check.NotDefault(name); + Check.NotNull(entity); + + // For now, we simply create a new shim for each entity batch operation. + // In the future we may consider caching those shims and reusing them, which can reduce + // deserialization and allocation overheads. + return new TaskEntityShim(this.options.DataConverter, entity, entityId); + } } diff --git a/src/Worker/Core/Shims/TaskEntityShim.cs b/src/Worker/Core/Shims/TaskEntityShim.cs new file mode 100644 index 000000000..a87fd9e19 --- /dev/null +++ b/src/Worker/Core/Shims/TaskEntityShim.cs @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.Entities; +using DurableTask.Core.Entities.OperationFormat; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using DTCore = DurableTask.Core; + +namespace Microsoft.DurableTask.Worker.Shims; + +/// +/// Shim that provides the entity context and implements batched execution. +/// +class TaskEntityShim : DTCore.Entities.TaskEntity +{ + readonly DataConverter dataConverter; + readonly ITaskEntity taskEntity; + readonly EntityInstanceId entityId; + + readonly StateShim state; + readonly ContextShim context; + readonly OperationShim operation; + + /// + /// Initializes a new instance of the class. + /// + /// The data converter. + /// The task entity. + /// The entity ID. + public TaskEntityShim(DataConverter dataConverter, ITaskEntity taskEntity, EntityId entityId) + { + this.dataConverter = Check.NotNull(dataConverter); + this.taskEntity = Check.NotNull(taskEntity); + this.entityId = new EntityInstanceId(entityId.Name, entityId.Key); + this.state = new StateShim(dataConverter); + this.context = new ContextShim(this.entityId, dataConverter); + this.operation = new OperationShim(this); + } + + /// + public override async Task ExecuteOperationBatchAsync(EntityBatchRequest operations) + { + // set the current state, and commit it so we can roll back to it later. + // The commit/rollback mechanism is needed since we treat entity operations transactionally. + // This means that if an operation throws an unhandled exception, all its effects are rolled back. + // In particular, (1) the entity state is reverted to what it was prior to the operation, and + // (2) all of the messages sent by the operation (e.g. if it started a new orchestrations) are discarded. + this.state.CurrentState = operations.EntityState; + this.state.Commit(); + + List results = new(); + + foreach (OperationRequest current in operations.Operations!) + { + this.operation.SetNameAndInput(current.Operation!, current.Input); + + try + { + object? result = await this.taskEntity.RunAsync(this.operation); + string? serializedResult = this.dataConverter.Serialize(result); + results.Add(new OperationResult() { Result = serializedResult }); + + // the user code completed without exception, so we commit the current state and actions. + this.state.Commit(); + this.context.Commit(); + } + catch (Exception applicationException) + { + results.Add(new OperationResult() + { + FailureDetails = new FailureDetails(applicationException), + }); + + // the user code threw an unhandled exception, so we roll back the state and the actions. + this.state.Rollback(); + this.context.Rollback(); + } + } + + var batchResult = new EntityBatchResult() + { + Results = results, + Actions = this.context.Actions, + EntityState = this.state.CurrentState, + FailureDetails = null, + }; + + // we reset only the context, but keep the current state. + // this makes it possible to reuse the cached state object if the TaskEntityShim is reused. + this.context.Reset(); + + return batchResult; + } + + class StateShim : TaskEntityState + { + readonly DataConverter dataConverter; + + string? value; + object? cachedValue; + bool cacheValid; + string? checkpointValue; + + public StateShim(DataConverter dataConverter) + { + this.dataConverter = dataConverter; + } + + public string? CurrentState + { + get => this.value; + set + { + if (this.value != value) + { + this.value = value; + this.cachedValue = null; + this.cacheValid = false; + } + } + } + + public void Commit() + { + this.checkpointValue = this.value; + } + + public void Rollback() + { + this.CurrentState = this.checkpointValue; + } + + public void Reset() + { + this.CurrentState = default; + } + + public override object? GetState(Type type) + { + if (!this.cacheValid) + { + this.cachedValue = this.dataConverter.Deserialize(this.value, type); + this.cacheValid = true; + } + + return this.cachedValue; + } + + public override void SetState(object? state) + { + this.value = this.dataConverter.Serialize(state); + this.cachedValue = state; + this.cacheValid = true; + } + } + + class ContextShim : TaskEntityContext + { + readonly EntityInstanceId entityInstanceId; + readonly DataConverter dataConverter; + + List operationActions; + int checkpointPosition; + + public ContextShim(EntityInstanceId entityInstanceId, DataConverter dataConverter) + { + this.entityInstanceId = entityInstanceId; + this.dataConverter = dataConverter; + this.operationActions = new List(); + } + + public List Actions => this.operationActions; + + public int CurrentPosition => this.operationActions.Count; + + public override EntityInstanceId Id => this.entityInstanceId; + + public void Commit() + { + this.checkpointPosition = this.CurrentPosition; + } + + public void Rollback() + { + this.operationActions.RemoveRange(this.checkpointPosition, this.operationActions.Count - this.checkpointPosition); + } + + public void Reset() + { + this.operationActions = new List(); + this.checkpointPosition = 0; + } + + public override void SignalEntity(EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) + { + this.operationActions.Add(new SendSignalOperationAction() + { + InstanceId = id.ToString(), + Name = operationName, + Input = this.dataConverter.Serialize(input), + ScheduledTime = options?.SignalTime?.UtcDateTime, + }); + } + + public override void StartOrchestration(TaskName name, object? input = null, StartOrchestrationOptions? options = null) + { + this.operationActions.Add(new StartNewOrchestrationOperationAction() + { + Name = name.Name, + Version = name.Version, + InstanceId = Guid.NewGuid().ToString("N"), + Input = this.dataConverter.Serialize(input), + }); + } + } + + class OperationShim : TaskEntityOperation + { + readonly TaskEntityShim taskEntityShim; + + string? name; + string? input; + + public OperationShim(TaskEntityShim taskEntityShim) + { + this.taskEntityShim = taskEntityShim; + } + + public override string Name => this.name!; // name is always set before user code can access this property + + public override TaskEntityContext Context => this.taskEntityShim.context; + + public override TaskEntityState State => this.taskEntityShim.state; + + public override bool HasInput => this.input != null; + + public override object? GetInput(Type inputType) + { + return this.taskEntityShim.dataConverter.Deserialize(this.input, inputType); + } + + public void SetNameAndInput(string name, string? input) + { + this.name = name; + this.input = input; + } + } +} diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs index 8a027bf0e..bd5daa635 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs @@ -3,12 +3,16 @@ using System.Text; using DurableTask.Core; +using DurableTask.Core.Entities; +using DurableTask.Core.Entities.OperationFormat; using DurableTask.Core.History; using Grpc.Core; +using Microsoft.DurableTask.Entities; using Microsoft.DurableTask.Worker.Shims; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using static Microsoft.DurableTask.Protobuf.TaskHubSidecarService; +using DTCore = DurableTask.Core; using P = Microsoft.DurableTask.Protobuf; namespace Microsoft.DurableTask.Worker.Grpc; @@ -148,6 +152,10 @@ async Task ProcessWorkItemsAsync(AsyncServerStreamingCall stream, Ca { this.RunBackgroundTask(workItem, () => this.OnRunActivityAsync(workItem.ActivityRequest)); } + else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequest) + { + this.RunBackgroundTask(workItem, () => this.OnRunEntityBatchAsync(workItem.EntityRequest)); + } else { this.Logger.UnexpectedWorkItemType(workItem.RequestCase.ToString()); @@ -337,5 +345,68 @@ async Task OnRunActivityAsync(P.ActivityRequest request) await this.sidecar.CompleteActivityTaskAsync(response); } + + async Task OnRunEntityBatchAsync(P.EntityBatchRequest request) + { + var coreEntityId = DTCore.Entities.EntityId.FromString(request.InstanceId); + EntityId entityId = new(coreEntityId.Name, coreEntityId.Key); + + TaskName name = new(entityId.Name); + + EntityBatchRequest batchRequest = request.ToEntityBatchRequest(); + EntityBatchResult? batchResult; + + try + { + await using AsyncServiceScope scope = this.worker.services.CreateAsyncScope(); + IDurableTaskFactory2 factory = (IDurableTaskFactory2)this.worker.Factory; + + if (factory.TryCreateEntity(name, scope.ServiceProvider, out ITaskEntity? entity)) + { + // Both the factory invocation and the RunAsync could involve user code and need to be handled as + // part of try/catch. + TaskEntity shim = this.shimFactory.CreateEntity(name, entity, entityId); + batchResult = await shim.ExecuteOperationBatchAsync(batchRequest); + } + else + { + // we could not find the entity. This is considered an application error, + // so we return a non-retriable error-OperationResult for each operation in the batch. + batchResult = new EntityBatchResult() + { + Actions = new List(), // no actions + EntityState = batchRequest.EntityState, // state is unmodified + Results = Enumerable.Repeat( + new OperationResult() + { + FailureDetails = new FailureDetails( + errorType: "EntityTaskNotFound", + errorMessage: $"No entity task named '{name}' was found.", + stackTrace: null, + innerFailure: null, + isNonRetriable: true), + }, + batchRequest.Operations!.Count).ToList(), + FailureDetails = null, + }; + } + } + catch (Exception frameworkException) + { + // return a result with no results, same state, + // and which contains failure details + batchResult = new EntityBatchResult() + { + Actions = new List(), + EntityState = batchRequest.EntityState, + Results = new List(), + FailureDetails = new FailureDetails(frameworkException), + }; + } + + // convert the result to protobuf format and send it back + P.EntityBatchResult response = batchResult.ToEntityBatchResult(); + await this.sidecar.CompleteEntityTaskAsync(response); + } } }