From 98453e3c4e0af6fd1cc3c1ba6dc959f612ee93f4 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 9 Oct 2025 15:19:26 -0400 Subject: [PATCH 1/2] fix: Make State Persistence APIs work better with PortableValue --- .../Execution/StateManager.cs | 16 +++ .../Execution/StateScope.cs | 7 + .../PortableValue.cs | 48 +++++-- .../JsonSerializationTests.cs | 4 +- .../PortableValueTests.cs | 130 ++++++++++++++++++ .../StateManagerTests.cs | 117 ++++++++++++++++ 6 files changed, 312 insertions(+), 10 deletions(-) create mode 100644 dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PortableValueTests.cs diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs index 3e7ef91e0c..0158459ffb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs @@ -99,6 +99,12 @@ public async ValueTask> ReadKeysAsync(ScopeId scopeId) public ValueTask ReadStateAsync(ScopeId scopeId, string key) { + if (typeof(T) == typeof(object)) + { + // Reading as object will break across serialize/deserialize boundaries, e.g. checkpointing, distributed runtime, etc. + throw new NotSupportedException("Reading state as 'object' is not supported. Use 'PortableValue' instead for variants."); + } + Throw.IfNullOrEmpty(key); UpdateKey stateKey = new(scopeId, key); @@ -116,6 +122,16 @@ public async ValueTask> ReadKeysAsync(ScopeId scopeId) { return new((T?)result.Value); } + else if (result.Value == null) + { + // Technically should only happen if T is nullable, but we don't have the ability to express that + // so we cannot `return new((T?)null);` directly. + return new((T?)default); + } + else if (typeof(T) == typeof(PortableValue)) + { + return new((T)(object)new PortableValue(result.Value)); + } throw new InvalidOperationException($"State for key '{key}' in scope '{scopeId}' is not of type '{typeof(T).Name}'."); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateScope.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateScope.cs index 607f97c351..e1c50ab1a3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateScope.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateScope.cs @@ -51,6 +51,13 @@ public bool ContainsKey(string key) Throw.IfNullOrEmpty(key); if (this._stateData.TryGetValue(key, out PortableValue? value)) { + if (typeof(T) == typeof(PortableValue) && !value.TypeId.IsMatch(typeof(PortableValue))) + { + // value is PortableValue, and we do not need to unwrap a PortableValue instance inside of it + // Unfortunately we need to cast through object here. + return new((T)(object)value); + } + return new(value.As()); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/PortableValue.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/PortableValue.cs index 38865da089..3ca0fea0d0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/PortableValue.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/PortableValue.cs @@ -114,10 +114,7 @@ public override int GetHashCode() /// true if the current value can be represented as type TValue; otherwise, false. public bool Is([NotNullWhen(true)] out TValue? value) { - if (this.Value is IDelayedDeserialization delayedDeserialization) - { - this._deserializedValueCache ??= delayedDeserialization.Deserialize(); - } + this.TryDeserializeAndUpdateCache(typeof(TValue), out _); if (this.Value is TValue typedValue) { @@ -152,11 +149,9 @@ public bool Is([NotNullWhen(true)] out TValue? value) /// true if the current instance can be assigned to targetType; otherwise, false. public bool IsType(Type targetType, [NotNullWhen(true)] out object? value) { + // Unfortunately, there is no way to check that the TypeId specified is assignable to the provided type Throw.IfNull(targetType); - if (this.Value is IDelayedDeserialization delayedDeserialization) - { - this._deserializedValueCache ??= delayedDeserialization.Deserialize(targetType); - } + this.TryDeserializeAndUpdateCache(targetType, out _); if (this.Value is not null && targetType.IsInstanceOfType(this.Value)) { @@ -167,4 +162,41 @@ public bool IsType(Type targetType, [NotNullWhen(true)] out object? value) value = null; return false; } + + private bool TryDeserializeAndUpdateCache(Type targetType, out object? replacedCacheValueOrNull) + { + replacedCacheValueOrNull = null; + + // Explicitly use _value here since we do not want to be overridden by the cache, if any + if (this._value is not IDelayedDeserialization delayedDeserialization) + { + // Not a delayed deserialization; nothing to do + return false; + } + + bool isCompatibleType = false; + if (this._deserializedValueCache == null || !(isCompatibleType = targetType.IsAssignableFrom(this._deserializedValueCache.GetType()))) + { + // Either we have no cache, or the types are incompatible; see if we can deserialize + try + { + object? deserialized = delayedDeserialization.Deserialize(targetType); + + if (deserialized != null && targetType.IsInstanceOfType(deserialized)) + { + replacedCacheValueOrNull = this._deserializedValueCache; + this._deserializedValueCache = deserialized; + + return true; + } + } + catch + { + isCompatibleType = false; + } + } + + // The last possibility is that we already deserialized successfully + return isCompatibleType; + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs index 7447f46128..cd0f910ddb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs @@ -33,7 +33,7 @@ private static JsonSerializerOptions TestCustomSerializedJsonOptions private static EdgeId TakeEdgeId() => new(Interlocked.Increment(ref s_nextEdgeId)); - private static T RunJsonRoundtrip(T value, JsonSerializerOptions? externalOptions = null, Expression>? predicate = null) + internal static T RunJsonRoundtrip(T value, JsonSerializerOptions? externalOptions = null, Expression>? predicate = null) { JsonMarshaller marshaller = new(externalOptions); @@ -172,7 +172,7 @@ private static ValueTask> CreateTestWorkflowAsync() return builder.BuildAsync(); } - private static async ValueTask CreateTestWorkflowInfoAsync() + internal static async ValueTask CreateTestWorkflowInfoAsync() { Workflow testWorkflow = await CreateTestWorkflowAsync().ConfigureAwait(false); return testWorkflow.ToWorkflowInfo(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PortableValueTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PortableValueTests.cs new file mode 100644 index 0000000000..86ffed0ab4 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/PortableValueTests.cs @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Checkpointing; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +public class PortableValueTests +{ + [SuppressMessage("Performance", "CA1812", Justification = "This is used as a Never/Bottom type.")] + private sealed class Never + { + private Never() { } + } + + [Theory] + [InlineData("string")] + [InlineData(42)] + [InlineData(true)] + [InlineData(3.14)] + public async Task Test_PortableValueRoundtripAsync(T value) + { + value.Should().NotBeNull(); + + PortableValue portableValue = new(value); + + portableValue.Is(out _).Should().BeFalse(); + portableValue.Is(out T? returnedValue).Should().BeTrue(); + returnedValue.Should().Be(value); + } + + [Fact] + public async Task Test_PortableValueRoundtripObjectAsync() + { + ChatMessage value = new(ChatRole.User, "Hello?"); + + PortableValue portableValue = new(value); + + portableValue.Is(out _).Should().BeFalse(); + portableValue.Is(out ChatMessage? returnedValue).Should().BeTrue(); + returnedValue.Should().Be(value); + } + + [Theory] + [InlineData("string")] + [InlineData(42)] + [InlineData(true)] + [InlineData(3.14)] + public async Task Test_DelayedSerializationRoundtripAsync(T value) + { + value.Should().NotBeNull(); + + TestDelayedDeserialization delayed = new(value); + PortableValue portableValue = new(delayed); + + portableValue.Is(out _).Should().BeFalse(); + portableValue.Is(out object? obj).Should().BeTrue(); + obj.Should().NotBeOfType(); + obj.Should().BeOfType() + .And.Subject.As() + .As().Should().Be(value); + + portableValue.Is(out T? returnedValue).Should().BeTrue(); + returnedValue.Should().Be(value); + } + + [Fact] + public async Task Test_DelayedSerializationRoundtripObjectAsync() + { + ChatMessage value = new(ChatRole.User, "Hello?"); + + TestDelayedDeserialization delayed = new(value); + PortableValue portableValue = new(delayed); + + portableValue.Is(out _).Should().BeFalse(); + portableValue.Is(out object? obj).Should().BeTrue(); + obj.Should().NotBeOfType(); + obj.Should().BeOfType() + .And.Subject.As() + .As().Should().Be(value); + + portableValue.Is(out ChatMessage? returnedValue).Should().BeTrue(); + returnedValue.Should().Be(value); + } + + private sealed class TestDelayedDeserialization : IDelayedDeserialization + { + [NotNull] + public T Value { get; } + + public TestDelayedDeserialization([DisallowNull] T value) + { + this.Value = value; + } + + public TValue Deserialize() + { + if (typeof(TValue) == typeof(object)) + { + return (TValue)(object)new PortableValue(this.Value); + } + + if (this.Value is TValue value) + { + return value; + } + + throw new InvalidOperationException(); + } + + public object? Deserialize(Type targetType) + { + if (targetType == typeof(object)) + { + return new PortableValue(this.Value); + } + + if (targetType.IsInstanceOfType(this.Value)) + { + return this.Value; + } + + return null; + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/StateManagerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/StateManagerTests.cs index 4bb0746996..fc16fd6600 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/StateManagerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/StateManagerTests.cs @@ -4,7 +4,9 @@ using System.Collections.Generic; using System.Threading.Tasks; using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Checkpointing; using Microsoft.Agents.AI.Workflows.Execution; +using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Workflows.UnitTests; @@ -451,4 +453,119 @@ private static async Task RunConflictingUpdatesTest_WriteVsClearAsync(string? sc await act.Should().NotThrowAsync("writes to private scopes should not be visible across executors"); } } + + private static void VerifyIs(PortableValue? candidatePV, TExpectedType value) + { + candidatePV.Should().NotBeNull(); + candidatePV.Is(out TExpectedType? candidateValue).Should().BeTrue(); + candidateValue.Should().Be(value); + } + + private static void VerifyIsNot(PortableValue? candidatePV) + { + candidatePV.Should().NotBeNull(); + candidatePV.Is(out TExpectedType? _).Should().BeFalse(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Test_LoadPortableValueStateAsync(bool publishStateUpdates) + { + ScopeId scope = new("executor1"); + const string StringValue = "string"; + const int IntValue = 42; + ScopeKey ScopeKey = new("executor1", "scope", "key"); + PortableValue PortableValueValue = new(StringValue); + + // Arrange + StateManager manager = new(); + await manager.WriteStateAsync(scope, nameof(StringValue), StringValue); + await manager.WriteStateAsync(scope, nameof(IntValue), IntValue); + await manager.WriteStateAsync(scope, nameof(ScopeKey), ScopeKey); + await manager.WriteStateAsync(scope, nameof(PortableValueValue), PortableValueValue); + + if (publishStateUpdates) + { + await manager.PublishUpdatesAsync(tracer: null); + } + + // Act & Assert - Read as the original types + PortableValue? stringAsPV = await manager.ReadStateAsync(scope, nameof(StringValue)); + VerifyIs(stringAsPV, StringValue); + VerifyIsNot(stringAsPV); + VerifyIsNot(stringAsPV); + VerifyIsNot(stringAsPV); + + PortableValue? intAsPV = await manager.ReadStateAsync(scope, nameof(IntValue)); + VerifyIsNot(intAsPV); + VerifyIs(intAsPV, IntValue); + VerifyIsNot(intAsPV); + VerifyIsNot(intAsPV); + + PortableValue? scopeKeyAsPV = await manager.ReadStateAsync(scope, nameof(ScopeKey)); + VerifyIsNot(scopeKeyAsPV); + VerifyIsNot(scopeKeyAsPV); + VerifyIs(scopeKeyAsPV, ScopeKey); + VerifyIsNot(scopeKeyAsPV); + + PortableValue? pvAsPV = await manager.ReadStateAsync(scope, nameof(PortableValueValue)); + VerifyIs(pvAsPV, StringValue); + VerifyIsNot(pvAsPV); + VerifyIsNot(pvAsPV); + + // Check that we don't double-wrap stored PortableValues on the out path + VerifyIsNot(pvAsPV); + } + + [Fact] + public async Task Test_LoadPortableValueState_AfterSerializationAsync() + { + ScopeId scope = new("executor1"); + const string StringValue = "string"; + const int IntValue = 42; + ScopeKey ScopeKey = new("executor1", "scope", "key"); + PortableValue PortableValueValue = new(StringValue); + + // Arrange + StateManager manager = new(); + await manager.WriteStateAsync(scope, nameof(StringValue), StringValue); + await manager.WriteStateAsync(scope, nameof(IntValue), IntValue); + await manager.WriteStateAsync(scope, nameof(ScopeKey), ScopeKey); + await manager.WriteStateAsync(scope, nameof(PortableValueValue), PortableValueValue); + + await manager.PublishUpdatesAsync(tracer: null); + + Dictionary exportedState = await manager.ExportStateAsync(); + Dictionary serializedState = JsonSerializationTests.RunJsonRoundtrip(exportedState); + Checkpoint testCheckpoint = new(0, await JsonSerializationTests.CreateTestWorkflowInfoAsync(), new([], [], []), serializedState, new()); + + manager = new(); + await manager.ImportStateAsync(testCheckpoint); + + // Act & Assert - Read as the original types + PortableValue? stringAsPV = await manager.ReadStateAsync(scope, nameof(StringValue)); + VerifyIs(stringAsPV, StringValue); + VerifyIsNot(stringAsPV); + VerifyIsNot(stringAsPV); + + PortableValue? intAsPV = await manager.ReadStateAsync(scope, nameof(IntValue)); + VerifyIsNot(intAsPV); + VerifyIs(intAsPV, IntValue); + VerifyIsNot(intAsPV); + + PortableValue? scopeKeyAsPV = await manager.ReadStateAsync(scope, nameof(ScopeKey)); + VerifyIsNot(scopeKeyAsPV); + VerifyIsNot(scopeKeyAsPV); + VerifyIs(scopeKeyAsPV, ScopeKey); + VerifyIsNot(scopeKeyAsPV); + + PortableValue? pvAsPV = await manager.ReadStateAsync(scope, nameof(PortableValueValue)); + VerifyIs(pvAsPV, StringValue); + VerifyIsNot(pvAsPV); + VerifyIsNot(pvAsPV); + + // Check that we don't double-wrap stored PortableValues on the out path + VerifyIsNot(pvAsPV); + } } From 479fae332174bc6020f9a8b6c23d889cf533f5cb Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 9 Oct 2025 16:14:57 -0400 Subject: [PATCH 2/2] test: Temporarily disable checking for T=object in ReadStateAsync --- .../Microsoft.Agents.AI.Workflows/Execution/StateManager.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs index 0158459ffb..ffa289eaf9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs @@ -102,7 +102,8 @@ public async ValueTask> ReadKeysAsync(ScopeId scopeId) if (typeof(T) == typeof(object)) { // Reading as object will break across serialize/deserialize boundaries, e.g. checkpointing, distributed runtime, etc. - throw new NotSupportedException("Reading state as 'object' is not supported. Use 'PortableValue' instead for variants."); + // Disabled pending upstream updates for this change; see https://github.com/microsoft/agent-framework/issues/1369 + //throw new NotSupportedException("Reading state as 'object' is not supported. Use 'PortableValue' instead for variants."); } Throw.IfNullOrEmpty(key);