diff --git a/src/Abstractions/Entities/EntityInstanceId.cs b/src/Abstractions/Entities/EntityInstanceId.cs index d0f18a41a..52712a24a 100644 --- a/src/Abstractions/Entities/EntityInstanceId.cs +++ b/src/Abstractions/Entities/EntityInstanceId.cs @@ -1,15 +1,80 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Text.Json; +using System.Text.Json.Serialization; + namespace Microsoft.DurableTask.Entities; /// /// Represents the ID of an entity. /// -/// The name of the entity. -/// The key for this entity. -public readonly record struct EntityInstanceId(string Name, string Key) +[JsonConverter(typeof(EntityInstanceId.JsonConverter))] +public readonly record struct EntityInstanceId { + /// + /// Initializes a new instance of the class. + /// + /// The entity name. + /// The entity key. + public EntityInstanceId(string name, string key) + { + Check.NotNullOrEmpty(name); + if (name.Contains('@')) + { + throw new ArgumentException("entity names may not contain `@` characters.", nameof(name)); + } + + Check.NotNull(key); + this.Name = name.ToLowerInvariant(); + this.Key = key; + } + + /// + /// Gets the entity name. Entity names are normalized to lower case. + /// + public string Name { get; } + + /// + /// Gets the entity key. + /// + public string Key { get; } + + /// + /// Constructs a from a string containing the instance ID. + /// + /// The string representation of the entity ID. + /// the constructed entity instance ID. + public static EntityInstanceId FromString(string instanceId) + { + Check.NotNullOrEmpty(instanceId); + var pos = instanceId.IndexOf('@', 1); + if (pos <= 0 || instanceId[0] != '@') + { + throw new ArgumentException($"Instance ID '{instanceId}' is not a valid entity ID.", nameof(instanceId)); + } + + var entityName = instanceId.Substring(1, pos - 1); + var entityKey = instanceId.Substring(pos + 1); + return new EntityInstanceId(entityName, entityKey); + } + /// public override string ToString() => $"@{this.Name}@{this.Key}"; + + /// + /// We override the default json conversion so we can use a more compact string representation for entity instance ids. + /// + class JsonConverter : JsonConverter + { + public override EntityInstanceId Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return EntityInstanceId.FromString(reader.GetString()!); + } + + public override void Write(Utf8JsonWriter writer, EntityInstanceId value, JsonSerializerOptions options) + { + writer.WriteStringValue(value.ToString()!); + } + } } diff --git a/src/Client/Core/DurableTaskClientOptions.cs b/src/Client/Core/DurableTaskClientOptions.cs index 1c84eb949..142be7e52 100644 --- a/src/Client/Core/DurableTaskClientOptions.cs +++ b/src/Client/Core/DurableTaskClientOptions.cs @@ -46,6 +46,12 @@ public DataConverter DataConverter } } + /// + /// Gets or sets a value indicating whether this client should support entities. If true, all instance ids starting with '@' are reserved for entities, + /// and validation checks are performed where appropriate. + /// + public bool EnableEntitySupport { get; set; } + /// /// Gets a value indicating whether was explicitly set or not. /// @@ -67,6 +73,7 @@ internal void ApplyTo(DurableTaskClientOptions other) { // Make sure to keep this up to date as values are added. other.DataConverter = this.DataConverter; + other.EnableEntitySupport = this.EnableEntitySupport; } } } diff --git a/src/Client/Core/Entities/EntityQuery.cs b/src/Client/Core/Entities/EntityQuery.cs index 15ef7a644..a356ebdef 100644 --- a/src/Client/Core/Entities/EntityQuery.cs +++ b/src/Client/Core/Entities/EntityQuery.cs @@ -39,9 +39,24 @@ public string? InstanceIdStartsWith get => this.instanceIdStartsWith; init { - // prefix '@' if filter value provided and not already prefixed with '@'. - this.instanceIdStartsWith = value?.Length > 0 && value[0] != '@' - ? $"@{value}" : value; + if (value != null) + { + // prefix '@' if filter value provided and not already prefixed with '@'. + string prefix = value.Length == 0 || value[0] != '@' ? $"@{value}" : value; + + // check if there is a name-key separator in the string + int pos = prefix.IndexOf('@', 1); + if (pos != -1) + { + // selectively normalize only the part up until that separator + this.instanceIdStartsWith = prefix.Substring(0, pos).ToLowerInvariant() + prefix.Substring(pos); + } + else + { + // normalize the entire prefix + this.instanceIdStartsWith = prefix.ToLowerInvariant(); + } + } } } diff --git a/src/Client/Grpc/GrpcDurableTaskClient.cs b/src/Client/Grpc/GrpcDurableTaskClient.cs index a6ad1b12e..b0efbb933 100644 --- a/src/Client/Grpc/GrpcDurableTaskClient.cs +++ b/src/Client/Grpc/GrpcDurableTaskClient.cs @@ -20,7 +20,7 @@ public sealed class GrpcDurableTaskClient : DurableTaskClient readonly ILogger logger; readonly TaskHubSidecarServiceClient sidecarClient; readonly GrpcDurableTaskClientOptions options; - readonly DurableEntityClient entityClient; + readonly DurableEntityClient? entityClient; AsyncDisposable asyncDisposable; /// @@ -49,11 +49,16 @@ public GrpcDurableTaskClient(string name, GrpcDurableTaskClientOptions options, this.options = Check.NotNull(options); this.asyncDisposable = GetCallInvoker(options, out CallInvoker callInvoker); this.sidecarClient = new TaskHubSidecarServiceClient(callInvoker); - this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger); + + if (this.options.EnableEntitySupport) + { + this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger); + } } /// - public override DurableEntityClient Entities => this.entityClient; + public override DurableEntityClient Entities => this.entityClient + ?? throw new NotSupportedException($"Durable entities are disabled because {nameof(DurableTaskClientOptions)}.{nameof(DurableTaskClientOptions.EnableEntitySupport)}=false"); DataConverter DataConverter => this.options.DataConverter; @@ -70,6 +75,8 @@ public override async Task ScheduleNewOrchestrationInstanceAsync( StartOrchestrationOptions? options = null, CancellationToken cancellation = default) { + Check.NotEntity(this.options.EnableEntitySupport, options?.InstanceId); + var request = new P.CreateInstanceRequest { Name = orchestratorName.Name, @@ -103,6 +110,8 @@ public override async Task RaiseEventAsync( Check.NotNullOrEmpty(instanceId); Check.NotNullOrEmpty(eventName); + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + P.RaiseEventRequest request = new() { InstanceId = instanceId, @@ -118,6 +127,8 @@ public override async Task TerminateInstanceAsync( string instanceId, object? output = null, CancellationToken cancellation = default) { Check.NotNullOrEmpty(instanceId); + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + this.logger.TerminatingInstance(instanceId); string? serializedOutput = this.DataConverter.Serialize(output); @@ -134,6 +145,8 @@ await this.sidecarClient.TerminateInstanceAsync( public override async Task SuspendInstanceAsync( string instanceId, string? reason = null, CancellationToken cancellation = default) { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + P.SuspendRequest request = new() { InstanceId = instanceId, @@ -155,6 +168,8 @@ public override async Task SuspendInstanceAsync( public override async Task ResumeInstanceAsync( string instanceId, string? reason = null, CancellationToken cancellation = default) { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + P.ResumeRequest request = new() { InstanceId = instanceId, @@ -176,6 +191,8 @@ public override async Task ResumeInstanceAsync( public override async Task GetInstancesAsync( string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + if (string.IsNullOrEmpty(instanceId)) { throw new ArgumentNullException(nameof(instanceId)); @@ -201,6 +218,8 @@ public override async Task ResumeInstanceAsync( /// public override AsyncPageable GetAllInstancesAsync(OrchestrationQuery? filter = null) { + Check.NotEntity(this.options.EnableEntitySupport, filter?.InstanceIdPrefix); + return Pageable.Create(async (continuation, pageSize, cancellation) => { P.QueryInstancesRequest request = new() @@ -250,6 +269,8 @@ public override AsyncPageable GetAllInstancesAsync(Orches public override async Task WaitForInstanceStartAsync( string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + this.logger.WaitingForInstanceStart(instanceId, getInputsAndOutputs); P.GetInstanceRequest request = new() @@ -275,6 +296,8 @@ public override async Task WaitForInstanceStartAsync( public override async Task WaitForInstanceCompletionAsync( string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + this.logger.WaitingForInstanceCompletion(instanceId, getInputsAndOutputs); P.GetInstanceRequest request = new() diff --git a/src/Shared/Core/Validation/Check.cs b/src/Shared/Core/Validation/Check.cs index f7acd310f..76c154749 100644 --- a/src/Shared/Core/Validation/Check.cs +++ b/src/Shared/Core/Validation/Check.cs @@ -94,6 +94,20 @@ public static string NotNullOrEmpty( return argument; } + /// + /// Checks that, if entity support is enabled, the given string is not an entity instance id, and throws an otherwise. + /// + /// Whether entity support is enabled. + /// The instance id. + /// The name of the argument. + public static void NotEntity(bool entitySupportEnabled, string? instanceId, [CallerArgumentExpression("instanceId")] string? argument = default) + { + if (entitySupportEnabled && instanceId?.Length > 0 && instanceId[0] == '@') + { + throw new ArgumentException("Instance IDs starting with '@' are reserved for entities, and must not be used for orchestrations, when entity support is enabled.", argument); + } + } + /// /// Checks if the supplied type is a concrete non-abstract type and implements the provided generic type. /// Throws if the conditions are not met. diff --git a/src/Worker/Core/DurableTaskWorkerOptions.cs b/src/Worker/Core/DurableTaskWorkerOptions.cs index 4a5b7c98e..bc1e7fc4f 100644 --- a/src/Worker/Core/DurableTaskWorkerOptions.cs +++ b/src/Worker/Core/DurableTaskWorkerOptions.cs @@ -44,6 +44,12 @@ public DataConverter DataConverter } } + /// + /// Gets or sets a value indicating whether this client should support entities. If true, all instance ids starting with '@' are reserved for entities, + /// and validation checks are performed where appropriate. + /// + public bool EnableEntitySupport { get; set; } + /// /// Gets or sets the maximum timer interval for the /// method. @@ -99,6 +105,7 @@ internal void ApplyTo(DurableTaskWorkerOptions other) // Make sure to keep this up to date as values are added. other.DataConverter = this.DataConverter; other.MaximumTimerInterval = this.MaximumTimerInterval; + other.EnableEntitySupport = this.EnableEntitySupport; } } } diff --git a/src/Worker/Core/Shims/TaskEntityShim.cs b/src/Worker/Core/Shims/TaskEntityShim.cs index 23d76c1af..3be50b56c 100644 --- a/src/Worker/Core/Shims/TaskEntityShim.cs +++ b/src/Worker/Core/Shims/TaskEntityShim.cs @@ -195,8 +195,7 @@ public void Reset() public override void SignalEntity(EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) { - Check.NotNullOrEmpty(id.Name); - Check.NotNull(id.Key); + Check.NotDefault(id); this.operationActions.Add(new SendSignalOperationAction() { @@ -209,6 +208,8 @@ public override void SignalEntity(EntityInstanceId id, string operationName, obj public override void StartOrchestration(TaskName name, object? input = null, StartOrchestrationOptions? options = null) { + Check.NotEntity(true, options?.InstanceId); + this.operationActions.Add(new StartNewOrchestrationOperationAction() { Name = name.Name, diff --git a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs index b9e1484fb..08a7ae824 100644 --- a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs +++ b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs @@ -61,7 +61,24 @@ public TaskOrchestrationContextWrapper( /// public override TaskOrchestrationEntityFeature Entities - => this.entityFeature ??= new TaskOrchestrationEntityContext(this); + { + get + { + if (this.entityFeature == null) + { + if (this.invocationContext.Options.EnableEntitySupport) + { + this.entityFeature = new TaskOrchestrationEntityContext(this); + } + else + { + throw new NotSupportedException($"Durable entities are disabled because {nameof(DurableTaskWorkerOptions)}.{nameof(DurableTaskWorkerOptions.EnableEntitySupport)}=false"); + } + } + + return this.entityFeature; + } + } /// /// Gets the DataConverter to use for inputs, outputs, and entity states. @@ -134,6 +151,8 @@ public override async Task CallSubOrchestratorAsync( => options is SubOrchestrationOptions derived ? derived.InstanceId : null; string instanceId = GetInstanceId(options) ?? this.NewGuid().ToString("N"); + Check.NotEntity(this.invocationContext.Options.EnableEntitySupport, instanceId); + // if this orchestration uses entities, first validate that the suborchsestration call is allowed in the current context if (this.entityFeature != null && !this.entityFeature.EntityContext.ValidateSuborchestrationTransition(out string? errorMsg)) { @@ -233,6 +252,8 @@ public override Task WaitForExternalEvent(string eventName, CancellationTo /// public override void SendEvent(string instanceId, string eventName, object eventData) { + Check.NotEntity(this.invocationContext.Options.EnableEntitySupport, instanceId); + this.innerContext.SendEvent(new OrchestrationInstance { InstanceId = instanceId }, eventName, eventData); } diff --git a/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs b/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs index 5f47b43ae..514abbe5e 100644 --- a/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs +++ b/src/Worker/Core/Shims/TaskOrchestrationEntityContext.cs @@ -89,9 +89,7 @@ public override async Task LockEntitiesAsync(IEnumerable public override async Task CallEntityAsync(EntityInstanceId id, string operationName, object? input = null, CallEntityOptions? options = null) { - Check.NotNullOrEmpty(id.Name); - Check.NotNull(id.Key); - + Check.NotDefault(id); OperationResult operationResult = await this.CallEntityInternalAsync(id, operationName, input); if (operationResult.IsError) @@ -107,9 +105,7 @@ public override async Task CallEntityAsync(EntityInstanceId id /// public override async Task CallEntityAsync(EntityInstanceId id, string operationName, object? input = null, CallEntityOptions? options = null) { - Check.NotNullOrEmpty(id.Name); - Check.NotNull(id.Key); - + Check.NotDefault(id); OperationResult operationResult = await this.CallEntityInternalAsync(id, operationName, input); if (operationResult.IsError) @@ -121,9 +117,7 @@ public override async Task CallEntityAsync(EntityInstanceId id, string operation /// public override Task SignalEntityAsync(EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) { - Check.NotNullOrEmpty(id.Name); - Check.NotNull(id.Key); - + Check.NotDefault(id); this.SendOperationMessage(id.ToString(), operationName, input, oneWay: true, scheduledTime: options?.SignalTime); return Task.CompletedTask; }