Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flo_ai/state/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flo_ai.state.flo_json_output_collector import FloJsonOutputCollector
from flo_ai.state.flo_output_collector import FloOutputCollector
from flo_ai.state.flo_output_collector import FloOutputCollector, CollectionStatus

__all__ = ['FloJsonOutputCollector', 'FloOutputCollector']
__all__ = ['FloJsonOutputCollector', 'FloOutputCollector', 'CollectionStatus']
9 changes: 6 additions & 3 deletions flo_ai/state/flo_json_output_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from flo_ai.error.flo_exception import FloException
from typing import Dict, List, Any
from flo_ai.common.flo_logger import get_logger
from flo_ai.state.flo_output_collector import FloOutputCollector
from flo_ai.state.flo_output_collector import FloOutputCollector, CollectionStatus


class FloJsonOutputCollector(FloOutputCollector):
def __init__(self, strict: bool = False):
super().__init__()
self.strict = strict
self.status = CollectionStatus.success
self.data: List[Dict[str, Any]] = []

def append(self, agent_output):
Expand Down Expand Up @@ -70,9 +71,11 @@ def __extract_jsons(self, llm_response):
json_obj = json.loads(self.__strip_comments(json_str))
json_object.update(json_obj)
except json.JSONDecodeError as e:
get_logger().error(f'Invalid JSON in response: {json_str}')
raise e
self.status = CollectionStatus.partial
get_logger().error(f'Invalid JSON in response: {json_str}, {e}')
if self.strict and len(json_matches) == 0:
self.status = CollectionStatus.error
get_logger().error(f'Error while finding json in -- {llm_response}')
raise FloException(
'JSON response expected in collector model: strict', error_code=1099
)
Expand Down
7 changes: 7 additions & 0 deletions flo_ai/state/flo_output_collector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from enum import Enum
from abc import ABC, abstractmethod


class CollectionStatus(Enum):
success = 'success'
partial = 'partial'
error = 'error'


class FloOutputCollector(ABC):
@abstractmethod
def append():
Expand Down
6 changes: 0 additions & 6 deletions tests/test_json_output_collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import json
from flo_ai.error.flo_exception import FloException
from flo_ai.state.flo_output_collector import FloOutputCollector
from flo_ai.state.flo_json_output_collector import FloJsonOutputCollector
Expand Down Expand Up @@ -103,11 +102,6 @@ def test_fetch_with_overlapping_keys(self, collector: FloJsonOutputCollector):
result = collector.fetch()
assert result == {'key': 'value2'} # Later values should override earlier ones

def test_invalid_json(self, collector: FloJsonOutputCollector):
test_input = '{"key": "value",}' # Invalid JSON with trailing comma
with pytest.raises(json.JSONDecodeError):
collector.append(test_input)

def test_complex_nested_structure(self, collector: FloJsonOutputCollector):
test_input = """
{
Expand Down