From 47fd758426499b4655e4701355de5d3d5f890321 Mon Sep 17 00:00:00 2001 From: Nate Kelley Date: Thu, 11 Dec 2025 09:24:32 -0700 Subject: [PATCH] fix: serialization depth limit --- src/buster/resources/airflow/utils.py | 104 +++++++++++++++++------- tests/manual/test_before_after.py | 1 - tests/manual/test_payload_inspection.py | 1 - 3 files changed, 74 insertions(+), 32 deletions(-) diff --git a/src/buster/resources/airflow/utils.py b/src/buster/resources/airflow/utils.py index 3b01265..848a889 100644 --- a/src/buster/resources/airflow/utils.py +++ b/src/buster/resources/airflow/utils.py @@ -1,13 +1,13 @@ import traceback from datetime import datetime, timedelta -from typing import Any, Dict +from typing import Any, Dict, Optional from buster.types import ApiVersion, Environment from buster.types.api import AirflowFlowVersion from buster.utils import get_buster_url -def serialize_airflow_context(context: Dict[str, Any]) -> Dict[str, Any]: +def serialize_airflow_context(context: Dict[str, Any], max_depth: int = 5) -> Dict[str, Any]: """ Serialize an Airflow context dictionary to a JSON-safe format. 
@@ -17,36 +17,50 @@ def serialize_airflow_context(context: Dict[str, Any]) -> Dict[str, Any]: - Exception objects (converts to dict with type, message, and traceback) - Sets (converts to lists) - Other non-serializable objects (converts to string representation) + - Circular references (tracks visited objects to prevent infinite loops) + - Depth limiting (stops recursion at max_depth to prevent stack overflow) Args: context: The Airflow context dictionary to serialize + max_depth: Maximum recursion depth (default: 5) Returns: A JSON-serializable dictionary """ serialized: Dict[str, Any] = {} + visited: set[int] = set() # Track visited objects by id to detect circular references for key, value in context.items(): - serialized[key] = _serialize_value(value) + serialized[key] = _serialize_value(value, depth=0, max_depth=max_depth, visited=visited) return serialized -def _serialize_value(value: Any) -> Any: +def _serialize_value(value: Any, depth: int = 0, max_depth: int = 5, visited: Optional[set[int]] = None) -> Any: """ - Recursively serialize a value to a JSON-safe format. + Recursively serialize a value to a JSON-safe format with depth limiting and circular reference detection. 
Args: value: The value to serialize + depth: Current recursion depth + max_depth: Maximum allowed recursion depth + visited: Set of visited object ids to detect circular references Returns: A JSON-serializable version of the value """ + if visited is None: + visited = set() + + # Check depth limit + if depth > max_depth: + return f"<max depth exceeded: {type(value).__name__}>" + # Handle None if value is None: return None - # Handle primitives (str, int, float, bool) + # Handle primitives (str, int, float, bool) - these can't have circular refs if isinstance(value, (str, int, float, bool)): return value @@ -58,48 +72,78 @@ def _serialize_value(value: Any) -> Any: if isinstance(value, timedelta): return str(value) + # Check for circular reference using object id + # Only check for mutable objects (lists, dicts, objects) + obj_id = id(value) + if obj_id in visited: + return f"<circular reference: {type(value).__name__}>" + # Handle Exception objects - serialize with type, message, and traceback if isinstance(value, BaseException): - return { - "_type": "exception", - "exception_type": type(value).__name__, - "exception_message": str(value), - "traceback": traceback.format_exception(type(value), value, value.__traceback__), - } + visited.add(obj_id) + try: + return { + "_type": "exception", + "exception_type": type(value).__name__, + "exception_message": str(value), + "traceback": traceback.format_exception(type(value), value, value.__traceback__), + } + finally: + visited.discard(obj_id) # Handle lists if isinstance(value, list): - return [_serialize_value(item) for item in value] + visited.add(obj_id) + try: + return [_serialize_value(item, depth + 1, max_depth, visited) for item in value] + finally: + visited.discard(obj_id) # Handle tuples (convert to list) if isinstance(value, tuple): - return [_serialize_value(item) for item in value] + visited.add(obj_id) + try: + return [_serialize_value(item, depth + 1, max_depth, visited) for item in value] + finally: + visited.discard(obj_id) # Handle sets (convert to list) if isinstance(value, set): - return
[_serialize_value(item) for item in value] + visited.add(obj_id) + try: + return [_serialize_value(item, depth + 1, max_depth, visited) for item in value] + finally: + visited.discard(obj_id) # Handle dictionaries if isinstance(value, dict): - return {k: _serialize_value(v) for k, v in value.items()} + visited.add(obj_id) + try: + return {k: _serialize_value(v, depth + 1, max_depth, visited) for k, v in value.items()} + finally: + visited.discard(obj_id) # Handle objects with __dict__ attribute (serialize their attributes) if hasattr(value, "__dict__"): - obj_dict = {"_type": "object", "_class": type(value).__name__} + visited.add(obj_id) try: - # Try to serialize the object's attributes - for attr_key, attr_value in value.__dict__.items(): - # Skip private/protected attributes - if not attr_key.startswith("_"): - try: - obj_dict[attr_key] = _serialize_value(attr_value) - except Exception: - # If serialization fails for an attribute, convert to string - obj_dict[attr_key] = str(attr_value) - return obj_dict - except Exception: - # If we can't serialize the object's dict, fall back to string - return str(value) + obj_dict = {"_type": "object", "_class": type(value).__name__} + try: + # Try to serialize the object's attributes + for attr_key, attr_value in value.__dict__.items(): + # Skip private/protected attributes + if not attr_key.startswith("_"): + try: + obj_dict[attr_key] = _serialize_value(attr_value, depth + 1, max_depth, visited) + except Exception: + # If serialization fails for an attribute, convert to string + obj_dict[attr_key] = str(attr_value) + return obj_dict + except Exception: + # If we can't serialize the object's dict, fall back to string + return str(value) + finally: + visited.discard(obj_id) # Handle callable objects (functions, methods) if callable(value): diff --git a/tests/manual/test_before_after.py b/tests/manual/test_before_after.py index 530c392..56784b3 100644 --- a/tests/manual/test_before_after.py +++ 
b/tests/manual/test_before_after.py @@ -125,4 +125,3 @@ def setattr(self, module, name, value): mp = SimpleMonkeypatch() demo_before_after(mp) - diff --git a/tests/manual/test_payload_inspection.py b/tests/manual/test_payload_inspection.py index 47c8aa8..7972151 100644 --- a/tests/manual/test_payload_inspection.py +++ b/tests/manual/test_payload_inspection.py @@ -99,4 +99,3 @@ def setattr(self, module, name, value): mp = SimpleMonkeypatch() inspect_payload(mp) -