Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/python/output_parser_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
- type: str
description: The first name of the person
name: last_name
- name: location
type: object
description: The details about birth location
fields:
- name: state
type: str
description: The Indian State in which the person was born
data_collector: kv
"""

Expand Down
83 changes: 64 additions & 19 deletions flo_ai/parsers/flo_json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ParseContract:
class FloJsonParser(FloParser):
def __init__(self, parse_contract: ParseContract):
    """Keep the parse contract and start with an empty nested-model cache."""
    self.contract = parse_contract
    # Nested models generated for object fields are stored here by model
    # name so each one is only built once per parser instance.
    self._cached_models = {}
    super().__init__()

def __dict_list_to_csv_string(self, data):
Expand All @@ -35,30 +36,76 @@ def __dict_list_to_csv_string(self, data):

return f'```\n{csv_string}```'

def __create_contract_from_json(self) -> BaseModel:
def __create_nested_model(
    self, field_def: Dict[str, Any], model_name: str
) -> BaseModel:
    """Build (or fetch from cache) the Pydantic model for an ``object`` field.

    Args:
        field_def: Contract entry of the object field. Its ``fields`` list
            holds one definition per nested attribute, each providing
            ``name``, ``type`` and ``description``.
        model_name: Name given to the generated model; also the cache key.

    Returns:
        The generated model class (the class itself, not an instance).

    Raises:
        ValueError: If ``field_def`` does not provide ``fields``, or if a
            nested field declares an unsupported type.
    """
    cached_model = self._cached_models.get(model_name)
    if cached_model is not None:
        return cached_model

    child_defs = field_def.get('fields')
    if child_defs is None:
        # Fail with a contract-level error instead of a bare KeyError,
        # mirroring how literal fields report missing 'values'.
        raise ValueError(
            f"Field '{field_def.get('name', model_name)}' of type 'object' "
            "must specify 'fields'."
        )

    nested_fields = {}
    for child in child_defs:
        child_name = child['name']
        # The child's own model name extends the parent's so that models
        # generated at different nesting levels get distinct cache keys.
        child_type = self.__get_field_type_annotation(
            child, f'{model_name}_{child_name}'
        )
        nested_fields[child_name] = (
            child_type,
            Field(..., description=child['description']),
        )

    nested_model = create_model(model_name, **nested_fields)
    self._cached_models[model_name] = nested_model
    return nested_model

def __get_field_type_annotation(
    self, field: Dict[str, Any], model_name: str
) -> Any:
    """Resolve the type annotation for one contract field.

    Args:
        field: Field definition; ``type`` selects the annotation. ``literal``
            fields are delegated to the literal builder, ``object`` fields
            to the nested-model builder, and ``array`` fields describe their
            element type under ``items``.
        model_name: Name used for a nested model generated for this field;
            array elements use ``f'{model_name}_item'``.

    Returns:
        ``str``/``int``/``bool``/``float`` for primitive types, a ``Literal``
        annotation, a generated model class, or ``List[...]`` of the
        resolved element type.

    Raises:
        ValueError: If ``type`` is missing or unsupported, or an ``array``
            field does not specify ``items``.
    """
    primitive_types = {
        'str': str,
        'int': int,
        'bool': bool,
        'float': float,
    }

    field_type = field.get('type')
    if field_type in primitive_types:
        return primitive_types[field_type]

    # Composite types are built on demand from the field definition, so
    # resolving a primitive never touches the builders.
    if field_type == 'literal':
        return self.__create_literal_type(field)
    if field_type == 'object':
        return self.__create_nested_model(field, model_name)
    if field_type == 'array':
        if 'items' not in field:
            raise ValueError(
                f"Field '{field.get('name', model_name)}' of type 'array' "
                "must specify 'items'."
            )
        element_type = self.__get_field_type_annotation(
            field['items'], f'{model_name}_item'
        )
        return List[element_type]

    raise ValueError(f'Unsupported type: {field_type}')

def __create_literal_type(self, field: Dict[str, Any]) -> Any:
    """Build a ``Literal[...]`` annotation from a literal field definition.

    Args:
        field: Field definition whose ``values`` list holds one mapping per
            allowed option, each with the option under its ``value`` key.

    Returns:
        ``Literal`` parameterised with every listed ``value``, in order.

    Raises:
        ValueError: If ``values`` is missing or empty.
    """
    literal_values = field.get('values')
    if not literal_values:
        # A definition is not guaranteed to carry a 'name' (for instance one
        # reached through an array's 'items'), so look it up defensively:
        # the error report itself must not fail with a KeyError.
        field_name = field.get('name', '<unnamed>')
        raise ValueError(
            f"Field '{field_name}' of type 'literal' must specify 'values'."
        )
    # Subscripting Literal with a tuple is the same as listing the options.
    return Literal[tuple(option['value'] for option in literal_values)]

def __create_contract_from_json(self) -> BaseModel:
pydantic_fields = {}
for field in self.contract.fields:
field_type = field['type']
if field_type == 'literal':
field_type = self.__get_field_type_annotation(
field, f"{self.contract.name}_{field['name']}"
)

if field['type'] == 'literal':
literal_values = field.get('values', [])
if not literal_values:
raise ValueError(
f"Field '{field['name']}' of type 'literal' must specify 'values'."
)
literals = [literal_value['value'] for literal_value in literal_values]
field_type_annotation = Literal[tuple(literals)]
default_prompt = (
field['default_value_prompt']
if 'default_value_prompt' in field
else ''
)
default_prompt = field.get('default_value_prompt', '')
field_description = f"""
{field['description']}
Following are the list of possibles values and its correponding description:
Expand All @@ -68,13 +115,10 @@ def __create_contract_from_json(self) -> BaseModel:
{default_prompt}
"""
else:
field_type_annotation = type_mapping.get(field_type)
if field_type_annotation is None:
raise ValueError(f'Unsupported type: {field_type}')
field_description = field['description']

pydantic_fields[field['name']] = (
field_type_annotation,
field_type,
Field(..., description=field_description),
)

Expand All @@ -86,6 +130,7 @@ def get_format_instructions(self):
pydantic_object=self.__create_contract_from_json()
).get_format_instructions()

@staticmethod
def create(json_dict: Optional[Dict] = None, json_path: Optional[str] = None):
    """Build a parser via ``FloJsonParser.Builder``.

    Both arguments are forwarded unchanged to the builder as keyword
    arguments; whatever its ``build()`` call produces is returned.

    Args:
        json_dict: Contract definition passed to the builder as ``json_dict``.
        json_path: Value passed to the builder as ``json_path``.
    """
    builder = FloJsonParser.Builder(json_dict=json_dict, json_path=json_path)
    return builder.build()

Expand Down