Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions dotnet/src/Microsoft.Agents.AI.Workflows/Execution/StateManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ public async ValueTask<HashSet<string>> ReadKeysAsync(ScopeId scopeId)

public ValueTask<T?> ReadStateAsync<T>(ScopeId scopeId, string key)
{
if (typeof(T) == typeof(object))
{
// Reading as object will break across serialize/deserialize boundaries, e.g. checkpointing, distributed runtime, etc.
// Disabled pending upstream updates for this change; see https://github.com/microsoft/agent-framework/issues/1369
//throw new NotSupportedException("Reading state as 'object' is not supported. Use 'PortableValue' instead for variants.");
}

Throw.IfNullOrEmpty(key);

UpdateKey stateKey = new(scopeId, key);
Expand All @@ -116,6 +123,16 @@ public async ValueTask<HashSet<string>> ReadKeysAsync(ScopeId scopeId)
{
return new((T?)result.Value);
}
else if (result.Value == null)
{
// Technically should only happen if T is nullable, but we don't have the ability to express that
// so we cannot `return new((T?)null);` directly.
return new((T?)default);
}
else if (typeof(T) == typeof(PortableValue))
{
return new((T)(object)new PortableValue(result.Value));
}

throw new InvalidOperationException($"State for key '{key}' in scope '{scopeId}' is not of type '{typeof(T).Name}'.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ public bool ContainsKey(string key)
Throw.IfNullOrEmpty(key);
if (this._stateData.TryGetValue(key, out PortableValue? value))
{
if (typeof(T) == typeof(PortableValue) && !value.TypeId.IsMatch(typeof(PortableValue)))
{
// value is PortableValue, and we do not need to unwrap a PortableValue instance inside of it
// Unfortunately we need to cast through object here.
return new((T)(object)value);
}

return new(value.As<T>());
}

Expand Down
48 changes: 40 additions & 8 deletions dotnet/src/Microsoft.Agents.AI.Workflows/PortableValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ public override int GetHashCode()
/// <returns>true if the current value can be represented as type TValue; otherwise, false.</returns>
public bool Is<TValue>([NotNullWhen(true)] out TValue? value)
{
if (this.Value is IDelayedDeserialization delayedDeserialization)
{
this._deserializedValueCache ??= delayedDeserialization.Deserialize<TValue>();
}
this.TryDeserializeAndUpdateCache(typeof(TValue), out _);

if (this.Value is TValue typedValue)
{
Expand Down Expand Up @@ -152,11 +149,9 @@ public bool Is<TValue>([NotNullWhen(true)] out TValue? value)
/// <returns>true if the current instance can be assigned to targetType; otherwise, false.</returns>
public bool IsType(Type targetType, [NotNullWhen(true)] out object? value)
{
// Unfortunately, there is no way to check that the TypeId specified is assignable to the provided type
Throw.IfNull(targetType);
if (this.Value is IDelayedDeserialization delayedDeserialization)
{
this._deserializedValueCache ??= delayedDeserialization.Deserialize(targetType);
}
this.TryDeserializeAndUpdateCache(targetType, out _);

if (this.Value is not null && targetType.IsInstanceOfType(this.Value))
{
Expand All @@ -167,4 +162,41 @@ public bool IsType(Type targetType, [NotNullWhen(true)] out object? value)
value = null;
return false;
}

private bool TryDeserializeAndUpdateCache(Type targetType, out object? replacedCacheValueOrNull)
{
replacedCacheValueOrNull = null;

// Explicitly use _value here since we do not want to be overridden by the cache, if any
if (this._value is not IDelayedDeserialization delayedDeserialization)
{
// Not a delayed deserialization; nothing to do
return false;
}

bool isCompatibleType = false;
if (this._deserializedValueCache == null || !(isCompatibleType = targetType.IsAssignableFrom(this._deserializedValueCache.GetType())))
{
// Either we have no cache, or the types are incompatible; see if we can deserialize
try
{
object? deserialized = delayedDeserialization.Deserialize(targetType);

if (deserialized != null && targetType.IsInstanceOfType(deserialized))
{
replacedCacheValueOrNull = this._deserializedValueCache;
this._deserializedValueCache = deserialized;

return true;
}
}
catch
{
isCompatibleType = false;
}
}

// The last possibility is that we already deserialized successfully
return isCompatibleType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private static JsonSerializerOptions TestCustomSerializedJsonOptions

private static EdgeId TakeEdgeId() => new(Interlocked.Increment(ref s_nextEdgeId));

private static T RunJsonRoundtrip<T>(T value, JsonSerializerOptions? externalOptions = null, Expression<Func<T, bool>>? predicate = null)
internal static T RunJsonRoundtrip<T>(T value, JsonSerializerOptions? externalOptions = null, Expression<Func<T, bool>>? predicate = null)
{
JsonMarshaller marshaller = new(externalOptions);

Expand Down Expand Up @@ -172,7 +172,7 @@ private static ValueTask<Workflow<string>> CreateTestWorkflowAsync()
return builder.BuildAsync<string>();
}

private static async ValueTask<WorkflowInfo> CreateTestWorkflowInfoAsync()
internal static async ValueTask<WorkflowInfo> CreateTestWorkflowInfoAsync()
{
Workflow<string> testWorkflow = await CreateTestWorkflowAsync().ConfigureAwait(false);
return testWorkflow.ToWorkflowInfo();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Agents.AI.Workflows.Checkpointing;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.UnitTests;

public class PortableValueTests
{
[SuppressMessage("Performance", "CA1812", Justification = "This is used as a Never/Bottom type.")]
private sealed class Never
{
private Never() { }
}

[Theory]
[InlineData("string")]
[InlineData(42)]
[InlineData(true)]
[InlineData(3.14)]
public async Task Test_PortableValueRoundtripAsync<T>(T value)
{
value.Should().NotBeNull();

PortableValue portableValue = new(value);

portableValue.Is<Never>(out _).Should().BeFalse();
portableValue.Is(out T? returnedValue).Should().BeTrue();
returnedValue.Should().Be(value);
}

[Fact]
public async Task Test_PortableValueRoundtripObjectAsync()
{
ChatMessage value = new(ChatRole.User, "Hello?");

PortableValue portableValue = new(value);

portableValue.Is<Never>(out _).Should().BeFalse();
portableValue.Is(out ChatMessage? returnedValue).Should().BeTrue();
returnedValue.Should().Be(value);
}

[Theory]
[InlineData("string")]
[InlineData(42)]
[InlineData(true)]
[InlineData(3.14)]
public async Task Test_DelayedSerializationRoundtripAsync<T>(T value)
{
value.Should().NotBeNull();

TestDelayedDeserialization<T> delayed = new(value);
PortableValue portableValue = new(delayed);

portableValue.Is<Never>(out _).Should().BeFalse();
portableValue.Is(out object? obj).Should().BeTrue();
obj.Should().NotBeOfType<T>();
obj.Should().BeOfType<PortableValue>()
.And.Subject.As<PortableValue>()
.As<T>().Should().Be(value);

portableValue.Is(out T? returnedValue).Should().BeTrue();
returnedValue.Should().Be(value);
}

[Fact]
public async Task Test_DelayedSerializationRoundtripObjectAsync()
{
ChatMessage value = new(ChatRole.User, "Hello?");

TestDelayedDeserialization<ChatMessage> delayed = new(value);
PortableValue portableValue = new(delayed);

portableValue.Is<Never>(out _).Should().BeFalse();
portableValue.Is(out object? obj).Should().BeTrue();
obj.Should().NotBeOfType<ChatMessage>();
obj.Should().BeOfType<PortableValue>()
.And.Subject.As<PortableValue>()
.As<ChatMessage>().Should().Be(value);

portableValue.Is(out ChatMessage? returnedValue).Should().BeTrue();
returnedValue.Should().Be(value);
}

private sealed class TestDelayedDeserialization<T> : IDelayedDeserialization
{
[NotNull]
public T Value { get; }

public TestDelayedDeserialization([DisallowNull] T value)
{
this.Value = value;
}

public TValue Deserialize<TValue>()
{
if (typeof(TValue) == typeof(object))
{
return (TValue)(object)new PortableValue(this.Value);
}

if (this.Value is TValue value)
{
return value;
}

throw new InvalidOperationException();
}

public object? Deserialize(Type targetType)
{
if (targetType == typeof(object))
{
return new PortableValue(this.Value);
}

if (targetType.IsInstanceOfType(this.Value))
{
return this.Value;
}

return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Agents.AI.Workflows.Checkpointing;
using Microsoft.Agents.AI.Workflows.Execution;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.UnitTests;

Expand Down Expand Up @@ -451,4 +453,119 @@ private static async Task RunConflictingUpdatesTest_WriteVsClearAsync(string? sc
await act.Should().NotThrowAsync("writes to private scopes should not be visible across executors");
}
}

private static void VerifyIs<TExpectedType>(PortableValue? candidatePV, TExpectedType value)
{
candidatePV.Should().NotBeNull();
candidatePV.Is(out TExpectedType? candidateValue).Should().BeTrue();
candidateValue.Should().Be(value);
}

private static void VerifyIsNot<TExpectedType>(PortableValue? candidatePV)
{
candidatePV.Should().NotBeNull();
candidatePV.Is(out TExpectedType? _).Should().BeFalse();
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Test_LoadPortableValueStateAsync(bool publishStateUpdates)
{
ScopeId scope = new("executor1");
const string StringValue = "string";
const int IntValue = 42;
ScopeKey ScopeKey = new("executor1", "scope", "key");
PortableValue PortableValueValue = new(StringValue);

// Arrange
StateManager manager = new();
await manager.WriteStateAsync(scope, nameof(StringValue), StringValue);
await manager.WriteStateAsync(scope, nameof(IntValue), IntValue);
await manager.WriteStateAsync(scope, nameof(ScopeKey), ScopeKey);
await manager.WriteStateAsync(scope, nameof(PortableValueValue), PortableValueValue);

if (publishStateUpdates)
{
await manager.PublishUpdatesAsync(tracer: null);
}

// Act & Assert - Read as the original types
PortableValue? stringAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(StringValue));
VerifyIs(stringAsPV, StringValue);
VerifyIsNot<int>(stringAsPV);
VerifyIsNot<ChatMessage>(stringAsPV);
VerifyIsNot<PortableValue>(stringAsPV);

PortableValue? intAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(IntValue));
VerifyIsNot<string>(intAsPV);
VerifyIs(intAsPV, IntValue);
VerifyIsNot<ChatMessage>(intAsPV);
VerifyIsNot<PortableValue>(intAsPV);

PortableValue? scopeKeyAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(ScopeKey));
VerifyIsNot<string>(scopeKeyAsPV);
VerifyIsNot<int>(scopeKeyAsPV);
VerifyIs(scopeKeyAsPV, ScopeKey);
VerifyIsNot<PortableValue>(scopeKeyAsPV);

PortableValue? pvAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(PortableValueValue));
VerifyIs(pvAsPV, StringValue);
VerifyIsNot<int>(pvAsPV);
VerifyIsNot<ChatMessage>(pvAsPV);

// Check that we don't double-wrap stored PortableValues on the out path
VerifyIsNot<PortableValue>(pvAsPV);
}

[Fact]
public async Task Test_LoadPortableValueState_AfterSerializationAsync()
{
ScopeId scope = new("executor1");
const string StringValue = "string";
const int IntValue = 42;
ScopeKey ScopeKey = new("executor1", "scope", "key");
PortableValue PortableValueValue = new(StringValue);

// Arrange
StateManager manager = new();
await manager.WriteStateAsync(scope, nameof(StringValue), StringValue);
await manager.WriteStateAsync(scope, nameof(IntValue), IntValue);
await manager.WriteStateAsync(scope, nameof(ScopeKey), ScopeKey);
await manager.WriteStateAsync(scope, nameof(PortableValueValue), PortableValueValue);

await manager.PublishUpdatesAsync(tracer: null);

Dictionary<ScopeKey, PortableValue> exportedState = await manager.ExportStateAsync();
Dictionary<ScopeKey, PortableValue> serializedState = JsonSerializationTests.RunJsonRoundtrip(exportedState);
Checkpoint testCheckpoint = new(0, await JsonSerializationTests.CreateTestWorkflowInfoAsync(), new([], [], []), serializedState, new());

manager = new();
await manager.ImportStateAsync(testCheckpoint);

// Act & Assert - Read as the original types
PortableValue? stringAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(StringValue));
VerifyIs(stringAsPV, StringValue);
VerifyIsNot<int>(stringAsPV);
VerifyIsNot<ChatMessage>(stringAsPV);

PortableValue? intAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(IntValue));
VerifyIsNot<string>(intAsPV);
VerifyIs(intAsPV, IntValue);
VerifyIsNot<ChatMessage>(intAsPV);

PortableValue? scopeKeyAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(ScopeKey));
VerifyIsNot<string>(scopeKeyAsPV);
VerifyIsNot<int>(scopeKeyAsPV);
VerifyIs(scopeKeyAsPV, ScopeKey);
VerifyIsNot<PortableValue>(scopeKeyAsPV);

PortableValue? pvAsPV = await manager.ReadStateAsync<PortableValue>(scope, nameof(PortableValueValue));
VerifyIs(pvAsPV, StringValue);
VerifyIsNot<int>(pvAsPV);
VerifyIsNot<ChatMessage>(pvAsPV);

// Check that we don't double-wrap stored PortableValues on the out path
VerifyIsNot<PortableValue>(pvAsPV);
}
}
Loading