From a8a652834f889b8a6cd2ac78cded13d465fe2dd7 Mon Sep 17 00:00:00 2001 From: vizsatiz Date: Sun, 22 Dec 2024 14:57:04 +0530 Subject: [PATCH] feat: improved parser capabilities --- examples/python/output_parser_yaml.py | 7 +++ flo_ai/parsers/flo_json_parser.py | 83 +++++++++++++++++++++------ 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/examples/python/output_parser_yaml.py b/examples/python/output_parser_yaml.py index b350d330..c02a6660 100644 --- a/examples/python/output_parser_yaml.py +++ b/examples/python/output_parser_yaml.py @@ -43,6 +43,13 @@ - type: str description: The first name of the person name: last_name + - name: location + type: object + description: The details about birth location + fields: + - name: state + type: str + description: The Indian State in whihc the person was born data_collector: kv """ diff --git a/flo_ai/parsers/flo_json_parser.py b/flo_ai/parsers/flo_json_parser.py index adf6db40..4d196e5f 100644 --- a/flo_ai/parsers/flo_json_parser.py +++ b/flo_ai/parsers/flo_json_parser.py @@ -18,6 +18,7 @@ class ParseContract: class FloJsonParser(FloParser): def __init__(self, parse_contract: ParseContract): self.contract = parse_contract + self._cached_models = {} super().__init__() def __dict_list_to_csv_string(self, data): @@ -35,30 +36,76 @@ def __dict_list_to_csv_string(self, data): return f'```\n{csv_string}```' - def __create_contract_from_json(self) -> BaseModel: + def __create_nested_model( + self, field_def: Dict[str, Any], model_name: str + ) -> BaseModel: + """Creates a nested Pydantic model for object types""" + if model_name in self._cached_models: + return self._cached_models[model_name] + + nested_fields = {} + for nested_field in field_def['fields']: + nested_type = self.__get_field_type_annotation( + nested_field, f"{model_name}_{nested_field['name']}" + ) + field_description = nested_field['description'] + nested_fields[nested_field['name']] = ( + nested_type, + Field(..., description=field_description), + ) + + NestedModel = create_model(model_name, **nested_fields) + self._cached_models[model_name] = NestedModel + return NestedModel + + def __get_field_type_annotation( + self, field: Dict[str, Any], model_name: str + ) -> Any: + """Determines the type annotation for a field, handling nested objects""" type_mapping = { 'str': str, 'int': int, 'bool': bool, 'float': float, - 'literal': Literal, + 'literal': self.__create_literal_type, + 'object': lambda f: self.__create_nested_model(f, model_name), + 'array': lambda f: List[ + self.__get_field_type_annotation(f['items'], f'{model_name}_item') + ], } + + field_type = field['type'] + type_handler = type_mapping.get(field_type) + + if type_handler is None: + raise ValueError(f'Unsupported type: {field_type}') + + return ( + type_handler(field) + if field_type in ['literal', 'object', 'array'] + else type_handler + ) + + def __create_literal_type(self, field: Dict[str, Any]) -> Any: + """Creates a Literal type from field definition""" + literal_values = field.get('values', []) + if not literal_values: + raise ValueError( + f"Field '{field['name']}' of type 'literal' must specify 'values'." + ) + literals = [literal_value['value'] for literal_value in literal_values] + return Literal[tuple(literals)] + + def __create_contract_from_json(self) -> BaseModel: pydantic_fields = {} for field in self.contract.fields: - field_type = field['type'] - if field_type == 'literal': + field_type = self.__get_field_type_annotation( + field, f"{self.contract.name}_{field['name']}" + ) + + if field['type'] == 'literal': literal_values = field.get('values', []) - if not literal_values: - raise ValueError( - f"Field '{field['name']}' of type 'literal' must specify 'values'." - ) - literals = [literal_value['value'] for literal_value in literal_values] - field_type_annotation = Literal[tuple(literals)] - default_prompt = ( - field['default_value_prompt'] - if 'default_value_prompt' in field - else '' - ) + default_prompt = field.get('default_value_prompt', '') field_description = f""" {field['description']} Following are the list of possibles values and its correponding description: @@ -68,13 +115,10 @@ def __create_contract_from_json(self) -> BaseModel: {default_prompt} """ else: - field_type_annotation = type_mapping.get(field_type) - if field_type_annotation is None: - raise ValueError(f'Unsupported type: {field_type}') field_description = field['description'] pydantic_fields[field['name']] = ( - field_type_annotation, + field_type, Field(..., description=field_description), ) @@ -86,6 +130,7 @@ def get_format_instructions(self): pydantic_object=self.__create_contract_from_json() ).get_format_instructions() + @staticmethod def create(json_dict: Optional[Dict] = None, json_path: Optional[str] = None): return FloJsonParser.Builder(json_dict=json_dict, json_path=json_path).build()