Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 74 additions & 30 deletions src/buster/resources/airflow/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import traceback
from datetime import datetime, timedelta
from typing import Any, Dict
from typing import Any, Dict, Optional

from buster.types import ApiVersion, Environment
from buster.types.api import AirflowFlowVersion
from buster.utils import get_buster_url


def serialize_airflow_context(context: Dict[str, Any]) -> Dict[str, Any]:
def serialize_airflow_context(context: Dict[str, Any], max_depth: int = 5) -> Dict[str, Any]:
"""
Serialize an Airflow context dictionary to a JSON-safe format.

Expand All @@ -17,36 +17,50 @@ def serialize_airflow_context(context: Dict[str, Any]) -> Dict[str, Any]:
- Exception objects (converts to dict with type, message, and traceback)
- Sets (converts to lists)
- Other non-serializable objects (converts to string representation)
- Circular references (tracks visited objects to prevent infinite loops)
- Depth limiting (stops recursion at max_depth to prevent stack overflow)

Args:
context: The Airflow context dictionary to serialize
max_depth: Maximum recursion depth (default: 5)

Returns:
A JSON-serializable dictionary
"""
serialized: Dict[str, Any] = {}
visited: set[int] = set() # Track visited objects by id to detect circular references

for key, value in context.items():
serialized[key] = _serialize_value(value)
serialized[key] = _serialize_value(value, depth=0, max_depth=max_depth, visited=visited)

return serialized


def _serialize_value(value: Any) -> Any:
def _serialize_value(value: Any, depth: int = 0, max_depth: int = 5, visited: Optional[set[int]] = None) -> Any:
"""
Recursively serialize a value to a JSON-safe format.
Recursively serialize a value to a JSON-safe format with depth limiting and circular reference detection.

Args:
value: The value to serialize
depth: Current recursion depth
max_depth: Maximum allowed recursion depth
visited: Set of visited object ids to detect circular references

Returns:
A JSON-serializable version of the value
"""
if visited is None:
visited = set()

# Check depth limit
if depth > max_depth:
return f"<max depth {max_depth} reached>"

# Handle None
if value is None:
return None

# Handle primitives (str, int, float, bool)
# Handle primitives (str, int, float, bool) - these can't have circular refs
if isinstance(value, (str, int, float, bool)):
return value

Expand All @@ -58,48 +72,78 @@ def _serialize_value(value: Any) -> Any:
if isinstance(value, timedelta):
return str(value)

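# Check for circular reference using object id.
# Every value that reaches this point is checked; exceptions, lists, tuples, sets, dicts
# and objects with __dict__ are added to `visited` only while they are being serialized.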
obj_id = id(value)
if obj_id in visited:
return f"<circular reference to {type(value).__name__}>"

# Handle Exception objects - serialize with type, message, and traceback
if isinstance(value, BaseException):
return {
"_type": "exception",
"exception_type": type(value).__name__,
"exception_message": str(value),
"traceback": traceback.format_exception(type(value), value, value.__traceback__),
}
visited.add(obj_id)
try:
return {
"_type": "exception",
"exception_type": type(value).__name__,
"exception_message": str(value),
"traceback": traceback.format_exception(type(value), value, value.__traceback__),
}
finally:
visited.discard(obj_id)

# Handle lists
if isinstance(value, list):
return [_serialize_value(item) for item in value]
visited.add(obj_id)
try:
return [_serialize_value(item, depth + 1, max_depth, visited) for item in value]
finally:
visited.discard(obj_id)

# Handle tuples (convert to list)
if isinstance(value, tuple):
return [_serialize_value(item) for item in value]
visited.add(obj_id)
try:
return [_serialize_value(item, depth + 1, max_depth, visited) for item in value]
finally:
visited.discard(obj_id)

# Handle sets (convert to list)
if isinstance(value, set):
return [_serialize_value(item) for item in value]
visited.add(obj_id)
try:
return [_serialize_value(item, depth + 1, max_depth, visited) for item in value]
finally:
visited.discard(obj_id)

# Handle dictionaries
if isinstance(value, dict):
return {k: _serialize_value(v) for k, v in value.items()}
visited.add(obj_id)
try:
return {k: _serialize_value(v, depth + 1, max_depth, visited) for k, v in value.items()}
finally:
visited.discard(obj_id)

# Handle objects with __dict__ attribute (serialize their attributes)
if hasattr(value, "__dict__"):
obj_dict = {"_type": "object", "_class": type(value).__name__}
visited.add(obj_id)
try:
# Try to serialize the object's attributes
for attr_key, attr_value in value.__dict__.items():
# Skip private/protected attributes
if not attr_key.startswith("_"):
try:
obj_dict[attr_key] = _serialize_value(attr_value)
except Exception:
# If serialization fails for an attribute, convert to string
obj_dict[attr_key] = str(attr_value)
return obj_dict
except Exception:
# If we can't serialize the object's dict, fall back to string
return str(value)
obj_dict = {"_type": "object", "_class": type(value).__name__}
try:
# Try to serialize the object's attributes
for attr_key, attr_value in value.__dict__.items():
# Skip private/protected attributes
if not attr_key.startswith("_"):
try:
obj_dict[attr_key] = _serialize_value(attr_value, depth + 1, max_depth, visited)
except Exception:
# If serialization fails for an attribute, convert to string
obj_dict[attr_key] = str(attr_value)
return obj_dict
except Exception:
# If we can't serialize the object's dict, fall back to string
return str(value)
finally:
visited.discard(obj_id)

# Handle callable objects (functions, methods)
if callable(value):
Expand Down
1 change: 0 additions & 1 deletion tests/manual/test_before_after.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,3 @@ def setattr(self, module, name, value):

mp = SimpleMonkeypatch()
demo_before_after(mp)

1 change: 0 additions & 1 deletion tests/manual/test_payload_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,3 @@ def setattr(self, module, name, value):

mp = SimpleMonkeypatch()
inspect_payload(mp)

Loading