Patch summary (two files):
  * .gitignore: ignore scratch_pad.py.
  * flo_ai/state/flo_json_output_collector.py: FloJsonOutputCollector gains a
    `strict` constructor flag (default False). append() now stores the result
    of the new __extract_jsons(), which finds every balanced {...} span in the
    agent output with a recursive `regex` pattern, json.loads() each span, and
    merges them into one dict via dict.update (later keys overwrite earlier
    ones). A JSONDecodeError is logged and re-raised; in strict mode a
    FloException (error_code=1099) is raised when no span is found. The old
    __remove_after_braces() helper is removed.

diff --git a/.gitignore b/.gitignore
index 69cc4315..0787c186 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,4 +13,5 @@ bin
 *.sql
 *.log
 *.yaml
+scratch_pad.py
 examples/local/*
\ No newline at end of file
diff --git a/flo_ai/state/flo_json_output_collector.py b/flo_ai/state/flo_json_output_collector.py
index 6eaccb62..71eeaf18 100644
--- a/flo_ai/state/flo_json_output_collector.py
+++ b/flo_ai/state/flo_json_output_collector.py
@@ -1,24 +1,36 @@
 import json
+import regex
+from flo_ai.error.flo_exception import FloException
 from typing import Dict, List, Any
+from flo_ai.common.flo_logger import get_logger
 from flo_ai.state.flo_output_collector import FloOutputCollector
 
 
 class FloJsonOutputCollector(FloOutputCollector):
-    def __init__(self):
+    def __init__(self, strict: bool = False):
         super().__init__()
+        self.strict = strict
         self.data: List[Dict[str, Any]] = []
 
     def append(self, agent_output):
-        output_dict = json.loads(self.__remove_after_braces(agent_output))
-        self.data.append(output_dict)
-
-    def __remove_after_braces(self, s: str) -> str:
-        first_brace = s.find('{')
-        last_brace = s.rfind('}')
-
-        if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
-            return s[first_brace : last_brace + 1]
-        return s
+        self.data.append(self.__extract_jsons(agent_output))
+
+    def __extract_jsons(self, llm_response):
+        json_pattern = r'\{(?:[^{}]|(?R))*\}'
+        json_matches = regex.findall(json_pattern, llm_response)
+        json_object = {}
+        for json_str in json_matches:
+            try:
+                json_obj = json.loads(json_str)
+                json_object.update(json_obj)
+            except json.JSONDecodeError as e:
+                get_logger().error(f'Invalid JSON in response: {json_str}')
+                raise e
+        if self.strict and len(json_matches) == 0:
+            raise FloException(
+                'JSON response expected in collector model: strict', error_code=1099
+            )
+        return json_object
 
     def pop(self):
         return self.data.pop()