diff --git a/src/Abstractions/Entities/EntityInstanceId.cs b/src/Abstractions/Entities/EntityInstanceId.cs
index d0f18a41a..52712a24a 100644
--- a/src/Abstractions/Entities/EntityInstanceId.cs
+++ b/src/Abstractions/Entities/EntityInstanceId.cs
@@ -1,15 +1,80 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
namespace Microsoft.DurableTask.Entities;
///
/// Represents the ID of an entity.
///
-/// The name of the entity.
-/// The key for this entity.
-public readonly record struct EntityInstanceId(string Name, string Key)
+[JsonConverter(typeof(EntityInstanceId.JsonConverter))]
+public readonly record struct EntityInstanceId
{
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The entity name.
+ /// The entity key.
+ public EntityInstanceId(string name, string key)
+ {
+ Check.NotNullOrEmpty(name);
+ if (name.Contains('@'))
+ {
+ throw new ArgumentException("entity names may not contain `@` characters.", nameof(name));
+ }
+
+ Check.NotNull(key);
+ this.Name = name.ToLowerInvariant();
+ this.Key = key;
+ }
+
+ ///
+ /// Gets the entity name. Entity names are normalized to lower case.
+ ///
+ public string Name { get; }
+
+ ///
+ /// Gets the entity key.
+ ///
+ public string Key { get; }
+
+ ///
+ /// Constructs a from a string containing the instance ID.
+ ///
+ /// The string representation of the entity ID.
+ /// the constructed entity instance ID.
+ public static EntityInstanceId FromString(string instanceId)
+ {
+ Check.NotNullOrEmpty(instanceId);
+ var pos = instanceId.IndexOf('@', 1);
+ if (pos <= 0 || instanceId[0] != '@')
+ {
+ throw new ArgumentException($"Instance ID '{instanceId}' is not a valid entity ID.", nameof(instanceId));
+ }
+
+ var entityName = instanceId.Substring(1, pos - 1);
+ var entityKey = instanceId.Substring(pos + 1);
+ return new EntityInstanceId(entityName, entityKey);
+ }
+
///
public override string ToString() => $"@{this.Name}@{this.Key}";
+
+ ///
+ /// We override the default json conversion so we can use a more compact string representation for entity instance ids.
+ ///
+ class JsonConverter : JsonConverter
+ {
+ public override EntityInstanceId Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ return EntityInstanceId.FromString(reader.GetString()!);
+ }
+
+ public override void Write(Utf8JsonWriter writer, EntityInstanceId value, JsonSerializerOptions options)
+ {
+ writer.WriteStringValue(value.ToString()!);
+ }
+ }
}
diff --git a/src/Client/Core/DurableTaskClientOptions.cs b/src/Client/Core/DurableTaskClientOptions.cs
index 1c84eb949..142be7e52 100644
--- a/src/Client/Core/DurableTaskClientOptions.cs
+++ b/src/Client/Core/DurableTaskClientOptions.cs
@@ -46,6 +46,12 @@ public DataConverter DataConverter
}
}
+ ///
+ /// Gets or sets a value indicating whether this client should support entities. If true, all instance ids starting with '@' are reserved for entities,
+ /// and validation checks are performed where appropriate.
+ ///
+ public bool EnableEntitySupport { get; set; }
+
///
/// Gets a value indicating whether was explicitly set or not.
///
@@ -67,6 +73,7 @@ internal void ApplyTo(DurableTaskClientOptions other)
{
// Make sure to keep this up to date as values are added.
other.DataConverter = this.DataConverter;
+ other.EnableEntitySupport = this.EnableEntitySupport;
}
}
}
diff --git a/src/Client/Core/Entities/EntityQuery.cs b/src/Client/Core/Entities/EntityQuery.cs
index 15ef7a644..a356ebdef 100644
--- a/src/Client/Core/Entities/EntityQuery.cs
+++ b/src/Client/Core/Entities/EntityQuery.cs
@@ -39,9 +39,24 @@ public string? InstanceIdStartsWith
get => this.instanceIdStartsWith;
init
{
- // prefix '@' if filter value provided and not already prefixed with '@'.
- this.instanceIdStartsWith = value?.Length > 0 && value[0] != '@'
- ? $"@{value}" : value;
+ if (value != null)
+ {
+ // prefix '@' if filter value provided and not already prefixed with '@'.
+ string prefix = value.Length == 0 || value[0] != '@' ? $"@{value}" : value;
+
+ // check if there is a name-key separator in the string
+ int pos = prefix.IndexOf('@', 1);
+ if (pos != -1)
+ {
+ // selectively normalize only the part up until that separator
+ this.instanceIdStartsWith = prefix.Substring(0, pos).ToLowerInvariant() + prefix.Substring(pos);
+ }
+ else
+ {
+ // normalize the entire prefix
+ this.instanceIdStartsWith = prefix.ToLowerInvariant();
+ }
+ }
}
}
diff --git a/src/Client/Grpc/GrpcDurableTaskClient.cs b/src/Client/Grpc/GrpcDurableTaskClient.cs
index a6ad1b12e..b0efbb933 100644
--- a/src/Client/Grpc/GrpcDurableTaskClient.cs
+++ b/src/Client/Grpc/GrpcDurableTaskClient.cs
@@ -20,7 +20,7 @@ public sealed class GrpcDurableTaskClient : DurableTaskClient
readonly ILogger logger;
readonly TaskHubSidecarServiceClient sidecarClient;
readonly GrpcDurableTaskClientOptions options;
- readonly DurableEntityClient entityClient;
+ readonly DurableEntityClient? entityClient;
AsyncDisposable asyncDisposable;
///
@@ -49,11 +49,16 @@ public GrpcDurableTaskClient(string name, GrpcDurableTaskClientOptions options,
this.options = Check.NotNull(options);
this.asyncDisposable = GetCallInvoker(options, out CallInvoker callInvoker);
this.sidecarClient = new TaskHubSidecarServiceClient(callInvoker);
- this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger);
+
+ if (this.options.EnableEntitySupport)
+ {
+ this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger);
+ }
}
///
- public override DurableEntityClient Entities => this.entityClient;
+ public override DurableEntityClient Entities => this.entityClient
+ ?? throw new NotSupportedException($"Durable entities are disabled because {nameof(DurableTaskClientOptions)}.{nameof(DurableTaskClientOptions.EnableEntitySupport)}=false");
DataConverter DataConverter => this.options.DataConverter;
@@ -70,6 +75,8 @@ public override async Task ScheduleNewOrchestrationInstanceAsync(
StartOrchestrationOptions? options = null,
CancellationToken cancellation = default)
{
+ Check.NotEntity(this.options.EnableEntitySupport, options?.InstanceId);
+
var request = new P.CreateInstanceRequest
{
Name = orchestratorName.Name,
@@ -103,6 +110,8 @@ public override async Task RaiseEventAsync(
Check.NotNullOrEmpty(instanceId);
Check.NotNullOrEmpty(eventName);
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
P.RaiseEventRequest request = new()
{
InstanceId = instanceId,
@@ -118,6 +127,8 @@ public override async Task TerminateInstanceAsync(
string instanceId, object? output = null, CancellationToken cancellation = default)
{
Check.NotNullOrEmpty(instanceId);
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
this.logger.TerminatingInstance(instanceId);
string? serializedOutput = this.DataConverter.Serialize(output);
@@ -134,6 +145,8 @@ await this.sidecarClient.TerminateInstanceAsync(
public override async Task SuspendInstanceAsync(
string instanceId, string? reason = null, CancellationToken cancellation = default)
{
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
P.SuspendRequest request = new()
{
InstanceId = instanceId,
@@ -155,6 +168,8 @@ public override async Task SuspendInstanceAsync(
public override async Task ResumeInstanceAsync(
string instanceId, string? reason = null, CancellationToken cancellation = default)
{
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
P.ResumeRequest request = new()
{
InstanceId = instanceId,
@@ -176,6 +191,8 @@ public override async Task ResumeInstanceAsync(
public override async Task GetInstancesAsync(
string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default)
{
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
if (string.IsNullOrEmpty(instanceId))
{
throw new ArgumentNullException(nameof(instanceId));
@@ -201,6 +218,8 @@ public override async Task ResumeInstanceAsync(
///
public override AsyncPageable GetAllInstancesAsync(OrchestrationQuery? filter = null)
{
+ Check.NotEntity(this.options.EnableEntitySupport, filter?.InstanceIdPrefix);
+
return Pageable.Create(async (continuation, pageSize, cancellation) =>
{
P.QueryInstancesRequest request = new()
@@ -250,6 +269,8 @@ public override AsyncPageable GetAllInstancesAsync(Orches
public override async Task WaitForInstanceStartAsync(
string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default)
{
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
this.logger.WaitingForInstanceStart(instanceId, getInputsAndOutputs);
P.GetInstanceRequest request = new()
@@ -275,6 +296,8 @@ public override async Task WaitForInstanceStartAsync(
public override async Task WaitForInstanceCompletionAsync(
string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default)
{
+ Check.NotEntity(this.options.EnableEntitySupport, instanceId);
+
this.logger.WaitingForInstanceCompletion(instanceId, getInputsAndOutputs);
P.GetInstanceRequest request = new()
diff --git a/src/Shared/Core/Validation/Check.cs b/src/Shared/Core/Validation/Check.cs
index f7acd310f..76c154749 100644
--- a/src/Shared/Core/Validation/Check.cs
+++ b/src/Shared/Core/Validation/Check.cs
@@ -94,6 +94,20 @@ public static string NotNullOrEmpty(
return argument;
}
+ ///
+ /// Checks that, if entity support is enabled, the given string is not an entity instance id, and throws an otherwise.
+ ///
+ /// Whether entity support is enabled.
+ /// The instance id.
+ /// The name of the argument.
+ public static void NotEntity(bool entitySupportEnabled, string? instanceId, [CallerArgumentExpression("instanceId")] string? argument = default)
+ {
+ if (entitySupportEnabled && instanceId?.Length > 0 && instanceId[0] == '@')
+ {
+ throw new ArgumentException("Instance IDs starting with '@' are reserved for entities, and must not be used for orchestrations, when entity support is enabled.", argument);
+ }
+ }
+
///
/// Checks if the supplied type is a concrete non-abstract type and implements the provided generic type.
/// Throws if the conditions are not met.
diff --git a/src/Worker/Core/DurableTaskWorkerOptions.cs b/src/Worker/Core/DurableTaskWorkerOptions.cs
index 4a5b7c98e..bc1e7fc4f 100644
--- a/src/Worker/Core/DurableTaskWorkerOptions.cs
+++ b/src/Worker/Core/DurableTaskWorkerOptions.cs
@@ -44,6 +44,12 @@ public DataConverter DataConverter
}
}
+ ///
+ /// Gets or sets a value indicating whether this client should support entities. If true, all instance ids starting with '@' are reserved for entities,
+ /// and validation checks are performed where appropriate.
+ ///
+ public bool EnableEntitySupport { get; set; }
+
///
/// Gets or sets the maximum timer interval for the
/// method.
@@ -99,6 +105,7 @@ internal void ApplyTo(DurableTaskWorkerOptions other)
// Make sure to keep this up to date as values are added.
other.DataConverter = this.DataConverter;
other.MaximumTimerInterval = this.MaximumTimerInterval;
+ other.EnableEntitySupport = this.EnableEntitySupport;
}
}
}
diff --git a/src/Worker/Core/Shims/TaskEntityShim.cs b/src/Worker/Core/Shims/TaskEntityShim.cs
index 23d76c1af..3be50b56c 100644
--- a/src/Worker/Core/Shims/TaskEntityShim.cs
+++ b/src/Worker/Core/Shims/TaskEntityShim.cs
@@ -195,8 +195,7 @@ public void Reset()
public override void SignalEntity(EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null)
{
- Check.NotNullOrEmpty(id.Name);
- Check.NotNull(id.Key);
+ Check.NotDefault(id);
this.operationActions.Add(new SendSignalOperationAction()
{
@@ -209,6 +208,8 @@ public override void SignalEntity(EntityInstanceId id, string operationName, obj
public override void StartOrchestration(TaskName name, object? input = null, StartOrchestrationOptions? options = null)
{
+ Check.NotEntity(true, options?.InstanceId);
+
this.operationActions.Add(new StartNewOrchestrationOperationAction()
{
Name = name.Name,
diff --git a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
index b9e1484fb..08a7ae824 100644
--- a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
+++ b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
@@ -61,7 +61,24 @@ public TaskOrchestrationContextWrapper(
///
public override TaskOrchestrationEntityFeature Entities
- => this.entityFeature ??= new TaskOrchestrationEntityContext(this);
+ {
+ get
+ {
+ if (this.entityFeature == null)
+ {
+ if (this.invocationContext.Options.EnableEntitySupport)
+ {
+ this.entityFeature = new TaskOrchestrationEntityContext(this);
+ }
+ else
+ {
+ throw new NotSupportedException($"Durable entities are disabled because {nameof(DurableTaskWorkerOptions)}.{nameof(DurableTaskWorkerOptions.EnableEntitySupport)}=false");
+ }
+ }
+
+ return this.entityFeature;
+ }
+ }
///
/// Gets the DataConverter to use for inputs, outputs, and entity states.
@@ -134,6 +151,8 @@ public override async Task CallSubOrchestratorAsync(
=> options is SubOrchestrationOptions derived ? derived.InstanceId : null;
string instanceId = GetInstanceId(options) ?? this.NewGuid().ToString("N");
+ Check.NotEntity(this.invocationContext.Options.EnableEntitySupport, instanceId);
+
// if this orchestration uses entities, first validate that the suborchsestration call is allowed in the current context
if (this.entityFeature != null && !this.entityFeature.EntityContext.ValidateSuborchestrationTransition(out string? errorMsg))
{
@@ -233,6 +252,8 @@ public override Task WaitForExternalEvent(string eventName, CancellationTo
///
public override void SendEvent(string instanceId, string eventName, object eventData)
{
+ Check.NotEntity(this.invocationContext.Options.EnableEntitySupport, instanceId);
+
this.innerContext.SendEvent(new OrchestrationInstance { InstanceId = instanceId }, eventName, eventData);
}
diff --git a/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs b/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs
index 5f47b43ae..514abbe5e 100644
--- a/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs
+++ b/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs
@@ -89,9 +89,7 @@ public override async Task LockEntitiesAsync(IEnumerable
public override async Task CallEntityAsync(EntityInstanceId id, string operationName, object? input = null, CallEntityOptions? options = null)
{
- Check.NotNullOrEmpty(id.Name);
- Check.NotNull(id.Key);
-
+ Check.NotDefault(id);
OperationResult operationResult = await this.CallEntityInternalAsync(id, operationName, input);
if (operationResult.IsError)
@@ -107,9 +105,7 @@ public override async Task CallEntityAsync(EntityInstanceId id
///
public override async Task CallEntityAsync(EntityInstanceId id, string operationName, object? input = null, CallEntityOptions? options = null)
{
- Check.NotNullOrEmpty(id.Name);
- Check.NotNull(id.Key);
-
+ Check.NotDefault(id);
OperationResult operationResult = await this.CallEntityInternalAsync(id, operationName, input);
if (operationResult.IsError)
@@ -121,9 +117,7 @@ public override async Task CallEntityAsync(EntityInstanceId id, string operation
///
public override Task SignalEntityAsync(EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null)
{
- Check.NotNullOrEmpty(id.Name);
- Check.NotNull(id.Key);
-
+ Check.NotDefault(id);
this.SendOperationMessage(id.ToString(), operationName, input, oneWay: true, scheduledTime: options?.SignalTime);
return Task.CompletedTask;
}