Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eng/proto
35 changes: 35 additions & 0 deletions src/Abstractions/DurableTaskRegistry.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using Microsoft.DurableTask.Entities;

namespace Microsoft.DurableTask;

/// <summary>
Expand All @@ -22,6 +24,12 @@ public sealed partial class DurableTaskRegistry
internal IDictionary<TaskName, Func<IServiceProvider, ITaskOrchestrator>> Orchestrators { get; }
= new Dictionary<TaskName, Func<IServiceProvider, ITaskOrchestrator>>();

/// <summary>
/// Gets the currently registered entities.
/// </summary>
internal IDictionary<TaskName, Func<IServiceProvider, ITaskEntity>> Entities { get; }
= new Dictionary<TaskName, Func<IServiceProvider, ITaskEntity>>();

/// <summary>
/// Registers an activity factory.
/// </summary>
Expand Down Expand Up @@ -76,4 +84,31 @@ public DurableTaskRegistry AddOrchestrator(TaskName name, Func<ITaskOrchestrator
this.Orchestrators.Add(name, _ => factory());
return this;
}

/// <summary>
/// Registers an entity factory.
/// </summary>
/// <param name="name">The name of the entity.</param>
/// <param name="factory">The entity factory.</param>
/// <returns>This registry instance, for call chaining.</returns>
/// <exception cref="ArgumentException">
/// Thrown if any of the following are true:
/// <list type="bullet">
/// <item>If <paramref name="name"/> is <c>default</c>.</item>
/// <item>If <paramref name="name" /> is already registered.</item>
/// <item>If <paramref name="factory"/> is <c>null</c>.</item>
/// </list>
/// </exception>
public DurableTaskRegistry AddEntity(TaskName name, Func<IServiceProvider, ITaskEntity> factory)
{
Check.NotDefault(name);
Check.NotNull(factory);
if (this.Entities.ContainsKey(name))
{
throw new ArgumentException($"An {nameof(ITaskEntity)} named '{name}' is already added.", nameof(name));
}

this.Entities.Add(name, factory);
return this;
}
}
246 changes: 246 additions & 0 deletions src/Shared/Grpc/ProtoUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Text;
using DurableTask.Core;
using DurableTask.Core.Command;
using DurableTask.Core.Entities.OperationFormat;
using DurableTask.Core.History;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
Expand Down Expand Up @@ -385,6 +386,251 @@ internal static OrchestrationStatus ToCore(this P.OrchestrationStatus status)
};
}

/// <summary>
/// Converts a <see cref="P.EntityBatchRequest" /> to a <see cref="EntityBatchRequest" />.
/// </summary>
/// <param name="entityBatchRequest">The entity batch request to convert.</param>
/// <returns>The converted entity batch request.</returns>
[return: NotNullIfNotNull("entityBatchRequest")]
internal static EntityBatchRequest? ToEntityBatchRequest(this P.EntityBatchRequest? entityBatchRequest)
{
if (entityBatchRequest == null)
{
return null;
}

return new EntityBatchRequest()
{
EntityState = entityBatchRequest.EntityState,
InstanceId = entityBatchRequest.InstanceId,
Operations = entityBatchRequest.Operations.Select(r => r.ToOperationRequest()).ToList(),
};
}

/// <summary>
/// Converts a <see cref="P.OperationRequest" /> to a <see cref="OperationRequest" />.
/// </summary>
/// <param name="operationRequest">The operation request to convert.</param>
/// <returns>The converted operation request.</returns>
[return: NotNullIfNotNull("operationRequest")]
internal static OperationRequest? ToOperationRequest(this P.OperationRequest? operationRequest)
{
if (operationRequest == null)
{
return null;
}

return new OperationRequest()
{
Operation = operationRequest.Operation,
Input = operationRequest.Input,
Id = Guid.Parse(operationRequest.RequestId),
};
}

/// <summary>
/// Converts a <see cref="P.OperationResult" /> to a <see cref="OperationResult" />.
/// </summary>
/// <param name="operationResult">The operation result to convert.</param>
/// <returns>The converted operation result.</returns>
[return: NotNullIfNotNull("operationResult")]
internal static OperationResult? ToOperationResult(this P.OperationResult? operationResult)
{
if (operationResult == null)
{
return null;
}

switch (operationResult.ResultTypeCase)
{
case P.OperationResult.ResultTypeOneofCase.Success:
return new OperationResult()
{
Result = operationResult.Success.Result,
};

case P.OperationResult.ResultTypeOneofCase.Failure:
return new OperationResult()
{
FailureDetails = operationResult.Failure.FailureDetails.ToCore(),
};

default:
throw new NotSupportedException($"Deserialization of {operationResult.ResultTypeCase} is not supported.");
}
}

/// <summary>
/// Converts a <see cref="OperationResult" /> to <see cref="P.OperationResult" />.
/// </summary>
/// <param name="operationResult">The operation result to convert.</param>
/// <returns>The converted operation result.</returns>
[return: NotNullIfNotNull("operationResult")]
internal static P.OperationResult? ToOperationResult(this OperationResult? operationResult)
{
if (operationResult == null)
{
return null;
}

if (operationResult.FailureDetails == null)
{
return new P.OperationResult()
{
Success = new P.OperationResultSuccess()
{
Result = operationResult.Result,
},
};
}
else
{
return new P.OperationResult()
{
Failure = new P.OperationResultFailure()
{
FailureDetails = ToProtobuf(operationResult.FailureDetails),
},
};
}
}

/// <summary>
/// Converts a <see cref="P.OperationAction" /> to a <see cref="OperationAction" />.
/// </summary>
/// <param name="operationAction">The operation action to convert.</param>
/// <returns>The converted operation action.</returns>
[return: NotNullIfNotNull("operationAction")]
internal static OperationAction? ToOperationAction(this P.OperationAction? operationAction)
{
if (operationAction == null)
{
return null;
}

switch (operationAction.OperationActionTypeCase)
{
case P.OperationAction.OperationActionTypeOneofCase.SendSignal:

return new SendSignalOperationAction()
{
Name = operationAction.SendSignal.Name,
Input = operationAction.SendSignal.Input,
InstanceId = operationAction.SendSignal.InstanceId,
ScheduledTime = operationAction.SendSignal.ScheduledTime?.ToDateTime(),
};

case P.OperationAction.OperationActionTypeOneofCase.StartNewOrchestration:

return new StartNewOrchestrationOperationAction()
{
Name = operationAction.StartNewOrchestration.Name,
Input = operationAction.StartNewOrchestration.Input,
InstanceId = operationAction.StartNewOrchestration.InstanceId,
Version = operationAction.StartNewOrchestration.Version,
};
default:
throw new NotSupportedException($"Deserialization of {operationAction.OperationActionTypeCase} is not supported.");
}
}

/// <summary>
/// Converts a <see cref="OperationAction" /> to <see cref="P.OperationAction" />.
/// </summary>
/// <param name="operationAction">The operation action to convert.</param>
/// <returns>The converted operation action.</returns>
[return: NotNullIfNotNull("operationAction")]
internal static P.OperationAction? ToOperationAction(this OperationAction? operationAction)
{
if (operationAction == null)
{
return null;
}

var action = new P.OperationAction();

switch (operationAction)
{
case SendSignalOperationAction sendSignalAction:

action.SendSignal = new P.SendSignalAction()
{
Name = sendSignalAction.Name,
Input = sendSignalAction.Input,
InstanceId = sendSignalAction.InstanceId,
ScheduledTime = sendSignalAction.ScheduledTime?.ToTimestamp(),
};
break;

case StartNewOrchestrationOperationAction startNewOrchestrationAction:

action.StartNewOrchestration = new P.StartNewOrchestrationAction()
{
Name = startNewOrchestrationAction.Name,
Input = startNewOrchestrationAction.Input,
Version = startNewOrchestrationAction.Version,
InstanceId = startNewOrchestrationAction.InstanceId,
};
break;
}

return action;
}

/// <summary>
/// Converts a <see cref="P.EntityBatchResult" /> to a <see cref="EntityBatchResult" />.
/// </summary>
/// <param name="entityBatchResult">The operation result to convert.</param>
/// <returns>The converted operation result.</returns>
[return: NotNullIfNotNull("entityBatchResult")]
internal static EntityBatchResult? ToEntityBatchResult(this P.EntityBatchResult? entityBatchResult)
{
if (entityBatchResult == null)
{
return null;
}

return new EntityBatchResult()
{
Actions = entityBatchResult.Actions.Select(operationAction => operationAction!.ToOperationAction()).ToList(),
EntityState = entityBatchResult.EntityState,
Results = entityBatchResult.Results.Select(operationResult => operationResult!.ToOperationResult()).ToList(),
FailureDetails = entityBatchResult.FailureDetails.ToCore(),
};
}

/// <summary>
/// Converts a <see cref="EntityBatchResult" /> to <see cref="P.EntityBatchResult" />.
/// </summary>
/// <param name="entityBatchResult">The operation result to convert.</param>
/// <returns>The converted operation result.</returns>
[return: NotNullIfNotNull("entityBatchResult")]
internal static P.EntityBatchResult? ToEntityBatchResult(this EntityBatchResult? entityBatchResult)
{
if (entityBatchResult == null)
{
return null;
}

var batchResult = new P.EntityBatchResult()
{
EntityState = entityBatchResult.EntityState,
FailureDetails = entityBatchResult.FailureDetails.ToProtobuf(),
};

foreach (OperationAction action in entityBatchResult.Actions!)
{
batchResult.Actions.Add(action.ToOperationAction());
}

foreach (OperationResult result in entityBatchResult.Results!)
{
batchResult.Results.Add(result.ToOperationResult());
}

return batchResult;
}

/// <summary>
/// Gets the approximate byte count for a <see cref="P.TaskFailureDetails" />.
/// </summary>
Expand Down
23 changes: 21 additions & 2 deletions src/Worker/Core/DurableTaskFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,33 @@
// Licensed under the MIT License.

using System.Diagnostics.CodeAnalysis;
using Microsoft.DurableTask.Entities;

namespace Microsoft.DurableTask.Worker;

/// <summary>
/// A factory for creating orchestrators and activities.
/// </summary>
sealed class DurableTaskFactory : IDurableTaskFactory
sealed class DurableTaskFactory : IDurableTaskFactory2
{
readonly IDictionary<TaskName, Func<IServiceProvider, ITaskActivity>> activities;
readonly IDictionary<TaskName, Func<IServiceProvider, ITaskOrchestrator>> orchestrators;
readonly IDictionary<TaskName, Func<IServiceProvider, ITaskEntity>> entities;

/// <summary>
/// Initializes a new instance of the <see cref="DurableTaskFactory" /> class.
/// </summary>
/// <param name="activities">The activity factories.</param>
/// <param name="orchestrators">The orchestrator factories.</param>
/// <param name="entities">The entity factories.</param>
internal DurableTaskFactory(
IDictionary<TaskName, Func<IServiceProvider, ITaskActivity>> activities,
IDictionary<TaskName, Func<IServiceProvider, ITaskOrchestrator>> orchestrators)
IDictionary<TaskName, Func<IServiceProvider, ITaskOrchestrator>> orchestrators,
IDictionary<TaskName, Func<IServiceProvider, ITaskEntity>> entities)
{
this.activities = Check.NotNull(activities);
this.orchestrators = Check.NotNull(orchestrators);
this.entities = Check.NotNull(entities);
}

/// <inheritdoc/>
Expand Down Expand Up @@ -54,4 +59,18 @@ public bool TryCreateOrchestrator(
orchestrator = null;
return false;
}

/// <inheritdoc/>
public bool TryCreateEntity(
TaskName name, IServiceProvider serviceProvider, [NotNullWhen(true)] out ITaskEntity? entity)
{
if (this.entities.TryGetValue(name, out Func<IServiceProvider, ITaskEntity>? factory))
{
entity = factory.Invoke(serviceProvider);
return true;
}

entity = null;
return false;
}
}
2 changes: 1 addition & 1 deletion src/Worker/Core/DurableTaskRegistryExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ static class DurableTaskRegistryExtensions
public static IDurableTaskFactory BuildFactory(this DurableTaskRegistry registry)
{
Check.NotNull(registry);
return new DurableTaskFactory(registry.Activities, registry.Orchestrators);
return new DurableTaskFactory(registry.Activities, registry.Orchestrators, registry.Entities);
}
}
Loading