Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1c68a0f
Harden Python checkpoint persistence defaults
Mar 27, 2026
747b1e0
fix: resolve mypy no-any-return error in checkpoint encoding
Mar 27, 2026
204b117
Simplify find_class return in _RestrictedUnpickler (#4894)
Mar 27, 2026
3e40ec4
Address review feedback for #4894: Python: Harden Python checkpoint p…
Mar 27, 2026
0a6a169
Restore # noqa: S301 on line 102 of _checkpoint_encoding.py (#4894)
Mar 27, 2026
6f9bb58
Address review feedback for #4894: Python: Harden Python checkpoint p…
Mar 27, 2026
8e4da79
Address PR review comments on checkpoint encoding (#4894)
Mar 27, 2026
dfc8579
Address PR #4941 review comments: fix docstring position and return type
Mar 27, 2026
96afb37
Address review feedback for #4894: review comment fixes
Mar 27, 2026
089fc3d
Merge remote-tracking branch 'upstream/main' into agent/fix-4894-1
Apr 8, 2026
ad2a9f6
fix: use pickle.UnpicklingError in RestrictedUnpickler and improve do…
Apr 8, 2026
f0a990b
fix: address PR #4941 review comments for checkpoint encoding
Apr 8, 2026
363a096
fix: replace deprecated 'builtin' repo with pre-commit-hooks in pre-c…
Apr 8, 2026
889ad10
style: apply pyupgrade formatting to docstring example
Apr 8, 2026
5689638
fix: resolve pre-commit hook paths for monorepo git root
Apr 8, 2026
0263504
Merge remote-tracking branch 'upstream/main' into agent/fix-4894-1
Apr 9, 2026
151ab6b
Fix pre-commit config paths for prek --cd python execution
Apr 9, 2026
160fc2b
fix: add builtins:getattr to checkpoint deserialization allowlist
Apr 9, 2026
83bca52
Address review feedback for #4894: review comment fixes
Apr 9, 2026
1576a46
Merge branch 'main' into agent/fix-4894-1
moonbox3 Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
fail_fast: true
exclude: ^scripts/
repos:
- repo: builtin
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-toml
name: Check TOML files
Expand Down Expand Up @@ -34,9 +35,6 @@ repos:
- id: no-commit-to-branch
name: Protect main branch
args: [--branch, main]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-ast
name: Check Valid Python Samples
types: ["python"]
Expand Down
39 changes: 33 additions & 6 deletions python/packages/core/agent_framework/_workflows/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,39 @@ class FileCheckpointStorage:
is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
for human-readable checkpoint files while preserving the ability to store complex Python objects.

SECURITY WARNING: Checkpoints use pickle for data serialization. Only load checkpoints
from trusted sources. Loading a malicious checkpoint file can execute arbitrary code.
By default, checkpoint deserialization is restricted to a built-in set of safe
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. To allow additional application-specific types, pass them via
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.

Example::

storage = FileCheckpointStorage(
"/tmp/checkpoints",
allowed_checkpoint_types=[
"my_app.models:MyState",
],
)
"""

def __init__(
    self,
    storage_path: str | Path,
    *,
    allowed_checkpoint_types: list[str] | None = None,
) -> None:
    """Initialize the file storage.

    Args:
        storage_path: Directory path where checkpoint files will be stored.
            The directory is created (including parents) if it does not exist.
        allowed_checkpoint_types: Additional types (beyond the built-in safe set
            and framework types) that are permitted during checkpoint
            deserialization. Each entry should be a ``"module:qualname"``
            string (e.g., ``"my_app.models:MyState"``).
    """
    self.storage_path = Path(storage_path)
    self.storage_path.mkdir(parents=True, exist_ok=True)
    # Extra "module:qualname" keys honored by the restricted unpickler when
    # checkpoints are decoded; frozen so it cannot drift after construction.
    self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or [])
    logger.info(f"Initialized file checkpoint storage at {self.storage_path}")

def _validate_file_path(self, checkpoint_id: CheckpointID) -> Path:
Expand Down Expand Up @@ -327,7 +352,7 @@ def _read() -> dict[str, Any]:
from ._checkpoint_encoding import decode_checkpoint_value

try:
decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint, allowed_types=self._allowed_types)
except WorkflowCheckpointException:
raise
checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
Expand All @@ -352,7 +377,9 @@ def _list_checkpoints() -> list[WorkflowCheckpoint]:
encoded_checkpoint = json.load(f)
from ._checkpoint_encoding import decode_checkpoint_value

decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
decoded_checkpoint_dict = decode_checkpoint_value(
encoded_checkpoint, allowed_types=self._allowed_types
)
checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
if checkpoint.workflow_name == workflow_name:
checkpoints.append(checkpoint)
Expand Down
164 changes: 137 additions & 27 deletions python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import base64
import logging
import pickle # nosec # noqa: S403
from typing import Any

from ..exceptions import WorkflowCheckpointException

"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.

This hybrid approach provides:
- Human-readable JSON structure for debugging and inspection of primitives and collections
- Full Python object fidelity via pickle for data values (non-JSON-native types)
- Base64 encoding to embed binary pickle data in JSON strings

SECURITY WARNING: Checkpoints use pickle for data serialization. Only load checkpoints
from trusted sources. Loading a malicious checkpoint file can execute arbitrary code.
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
``RestrictedUnpickler`` is used that limits which classes may be instantiated
during deserialization. The default built-in safe set covers common Python
value types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. Callers can extend the set by passing additional
``"module:qualname"`` strings.
"""

from __future__ import annotations

import base64
import io
import logging
import pickle # nosec # noqa: S403
from typing import Any

from ..exceptions import WorkflowCheckpointException

logger = logging.getLogger("agent_framework")

Expand All @@ -30,6 +34,82 @@
# Types that are natively JSON-serializable and don't need pickling
_JSON_NATIVE_TYPES = (str, int, float, bool, type(None))

# Module prefix for framework-internal types that are always allowed
_FRAMEWORK_MODULE_PREFIX = "agent_framework."

# Built-in types considered safe for checkpoint deserialization.
# Each entry is a ``module:qualname`` string matching the format produced by
# :func:`_type_to_key`. These are the classes for which pickle's
# ``find_class`` will be called when unpickling common Python value types.
_BUILTIN_ALLOWED_TYPE_KEYS: frozenset[str] = frozenset({
# builtins
"builtins:object",
"builtins:complex",
"builtins:range",
"builtins:slice",
"builtins:int",
"builtins:float",
"builtins:str",
"builtins:bytes",
"builtins:bytearray",
"builtins:bool",
"builtins:set",
"builtins:frozenset",
"builtins:list",
"builtins:dict",
"builtins:tuple",
"builtins:type",
# getattr is used by pickle to reconstruct enum members
"builtins:getattr",
# copyreg helpers used by pickle for object reconstruction
"copyreg:_reconstructor",
# datetime
"datetime:datetime",
"datetime:date",
"datetime:time",
"datetime:timedelta",
"datetime:timezone",
# uuid
"uuid:UUID",
# decimal
"decimal:Decimal",
# collections
"collections:OrderedDict",
"collections:defaultdict",
"collections:deque",
})


class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
"""Unpickler that restricts which classes may be instantiated.

Only classes whose ``module:qualname`` key appears in the combined allow
set (built-in safe types + framework types + caller-specified extras) are
permitted. All other classes raise :class:`pickle.UnpicklingError`.
"""

def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
super().__init__(io.BytesIO(data))
self._allowed_types = allowed_types

def find_class(self, module: str, name: str) -> type:
type_key = f"{module}:{name}"

if (
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
or type_key in self._allowed_types
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
):
return super().find_class(module, name) # type: ignore[no-any-return] # nosec
Comment thread
moonbox3 marked this conversation as resolved.

Comment thread
moonbox3 marked this conversation as resolved.
raise pickle.UnpicklingError(
f"Checkpoint deserialization blocked for type '{type_key}'. "
f"To allow this type, either include its 'module:qualname' key in the "
f"'allowed_types' set passed to 'decode_checkpoint_value', or add it to "
f"'allowed_checkpoint_types' on your checkpoint storage "
f"(for example, 'FileCheckpointStorage.allowed_checkpoint_types')."
)


def encode_checkpoint_value(value: Any) -> Any:
"""Encode a Python value for checkpoint storage.
Expand All @@ -48,29 +128,51 @@ def encode_checkpoint_value(value: Any) -> Any:
return _encode(value)


def decode_checkpoint_value(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
    """Decode a value from checkpoint storage.

    Reverses the encoding performed by encode_checkpoint_value.
    Pickled values (identified by _PICKLE_MARKER) are decoded and unpickled.

    Args:
        value: A JSON-deserialized value from checkpoint storage.
        allowed_types: If not ``None``, restrict pickle deserialization to the
            built-in safe set, framework types, and the types listed here.
            Each entry should use ``"module:qualname"`` format — that is, the
            dotted module path followed by a colon and the class
            ``__qualname__``. For example, given a user-defined class::

                # my_app/models.py
                class MyState: ...

            the corresponding entry would be ``"my_app.models:MyState"``::

                decode_checkpoint_value(
                    data,
                    allowed_types=frozenset({"my_app.models:MyState"}),
                )

            When using :class:`FileCheckpointStorage`, pass the same strings
            via ``allowed_checkpoint_types``::

                storage = FileCheckpointStorage(
                    "/tmp/checkpoints",
                    allowed_checkpoint_types=["my_app.models:MyState"],
                )

            If ``None``, no restriction is applied (backward-compatible
            behavior), so only call with trusted data in that mode.

    Returns:
        The original Python value.

    Raises:
        WorkflowCheckpointException: If the unpickled object's type doesn't match
            the recorded type, indicating corruption, if the base64/pickle
            data is malformed, or if a disallowed type is encountered during
            restricted deserialization.
    """
    return _decode(value, allowed_types=allowed_types)


def _encode(value: Any) -> Any:
Expand All @@ -94,7 +196,7 @@ def _encode(value: Any) -> Any:
}


def _decode(value: Any) -> Any:
def _decode(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
"""Recursively decode a value from JSON storage."""
# JSON-native types pass through
if isinstance(value, _JSON_NATIVE_TYPES):
Expand All @@ -104,16 +206,16 @@ def _decode(value: Any) -> Any:
if isinstance(value, dict):
# Pickled value: decode, unpickle, and verify type
if _PICKLE_MARKER in value and _TYPE_MARKER in value:
obj = _base64_to_unpickle(value[_PICKLE_MARKER]) # type: ignore
obj = _base64_to_unpickle(value[_PICKLE_MARKER], allowed_types=allowed_types) # type: ignore
_verify_type(obj, value.get(_TYPE_MARKER)) # type: ignore
return obj

# Regular dict: decode values recursively
return {k: _decode(v) for k, v in value.items()} # type: ignore
return {k: _decode(v, allowed_types=allowed_types) for k, v in value.items()} # type: ignore

# Handle encoded lists
if isinstance(value, list):
return [_decode(item) for item in value] # type: ignore
return [_decode(item, allowed_types=allowed_types) for item in value] # type: ignore

return value

Expand Down Expand Up @@ -148,15 +250,23 @@ def _pickle_to_base64(value: Any) -> str:
return base64.b64encode(pickled).decode("ascii")


def _base64_to_unpickle(encoded: str) -> Any:
def _base64_to_unpickle(encoded: str, *, allowed_types: frozenset[str] | None = None) -> Any:
"""Decode base64 string and unpickle.

Args:
encoded: Base64-encoded pickle data.
allowed_types: If not ``None``, use restricted unpickling that only
permits built-in safe types, framework types, and the specified
extra types.

Raises:
WorkflowCheckpointException: If the base64 data is corrupted or the pickle
format is incompatible.
WorkflowCheckpointException: If the base64 data is corrupted, the pickle
format is incompatible, or a disallowed type is encountered.
"""
try:
pickled = base64.b64decode(encoded.encode("ascii"))
if allowed_types is not None:
return _RestrictedUnpickler(pickled, allowed_types).load()
return pickle.loads(pickled) # nosec # noqa: S301
except Exception as exc:
raise WorkflowCheckpointException(f"Failed to decode pickled checkpoint data: {exc}") from exc
Expand Down
18 changes: 15 additions & 3 deletions python/packages/core/tests/workflow/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,10 @@ async def test_file_checkpoint_storage_roundtrip_datetime():
async def test_file_checkpoint_storage_roundtrip_dataclass():
"""Test that dataclass objects roundtrip correctly via pickle encoding."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=["tests.workflow.test_checkpoint:_TestCustomData"],
)

custom_obj = _TestCustomData(name="test", value=42, tags=["a", "b", "c"])

Expand Down Expand Up @@ -1238,7 +1241,10 @@ async def test_file_checkpoint_storage_roundtrip_messages_with_complex_data():
async def test_file_checkpoint_storage_roundtrip_pending_request_info_events():
"""Test that pending_request_info_events with WorkflowEvent objects roundtrip correctly."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=["tests.workflow.test_checkpoint:_TestToolApprovalRequest"],
)

# Create request_info events using the proper WorkflowEvent factory
event1 = WorkflowEvent.request_info(
Expand Down Expand Up @@ -1300,7 +1306,13 @@ async def test_file_checkpoint_storage_roundtrip_pending_request_info_events():
async def test_file_checkpoint_storage_roundtrip_full_checkpoint():
"""Test complete WorkflowCheckpoint roundtrip with all fields populated using proper types."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=[
"tests.workflow.test_checkpoint:_TestApprovalRequest",
"tests.workflow.test_checkpoint:_TestExecutorState",
],
)

# Create proper WorkflowMessage objects
msg1 = WorkflowMessage(data="msg1", source_id="s", target_id="t")
Expand Down
Loading
Loading