From 37a21b2cadd7bcb8975ca00d4ff64b17bc4193bf Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 5 Feb 2025 12:28:59 -0500 Subject: [PATCH 1/2] fix: handle more types in tool defs --- workflowai/core/utils/_tools.py | 46 +++++++++++++++++----- workflowai/core/utils/_tools_test.py | 57 ++++++++++++++++++++++++++-- 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py index f16632e..0dda9bf 100644 --- a/workflowai/core/utils/_tools.py +++ b/workflowai/core/utils/_tools.py @@ -1,8 +1,10 @@ +import contextlib +import datetime import inspect from enum import Enum from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from workflowai.core.utils._schema_generator import JsonSchemaGenerator @@ -15,6 +17,14 @@ class SchemaDeserializer(NamedTuple): deserializer: Optional[Callable[[Any], Any]] = None +def _serialize_datetime(x: datetime.datetime) -> str: + return x.isoformat() + + +def _deserialize_datetime(x: str) -> datetime.datetime: + return datetime.datetime.fromisoformat(x) + + def _get_type_schema(param_type: type): """Convert a Python type to its corresponding JSON schema type. @@ -24,10 +34,6 @@ def _get_type_schema(param_type: type): Returns: A dictionary containing the JSON schema type definition """ - if issubclass(param_type, Enum): - if not issubclass(param_type, str): - raise ValueError(f"Non string enums are not supported: {param_type}") - return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]}) if param_type is str: return SchemaDeserializer({"type": "string"}) @@ -41,11 +47,33 @@ def _get_type_schema(param_type: type): if param_type is bool: return SchemaDeserializer({"type": "boolean"}) - if issubclass(param_type, BaseModel): + if param_type is datetime.datetime: + return SchemaDeserializer( + {"type": "string", "format": "date-time"}, + serializer=_serialize_datetime, + deserializer=_deserialize_datetime, + ) + + if inspect.isclass(param_type): + if issubclass(param_type, BaseModel): + return SchemaDeserializer( + schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator), + serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType] + deserializer=param_type.model_validate, + ) + + if issubclass(param_type, Enum): + if not issubclass(param_type, str): + raise ValueError(f"Non string enums are not supported: {param_type}") + return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]}) + + # Attempting to build a type adapter with pydantic + with contextlib.suppress(Exception): + adapter = TypeAdapter[Any](param_type) return SchemaDeserializer( - schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator), - serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType] - deserializer=param_type.model_validate, + schema=adapter.json_schema(), + deserializer=adapter.validate_python, # pyright: ignore [reportUnknownLambdaType] + serializer=lambda x: adapter.dump_python(x, mode="json"), # pyright: ignore [reportUnknownLambdaType] ) raise ValueError(f"Unsupported type: {param_type}") diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py index d6c04da..46d0567 100644 --- a/workflowai/core/utils/_tools_test.py +++ b/workflowai/core/utils/_tools_test.py @@ -1,9 +1,42 @@ +import json +from datetime import datetime from enum import Enum -from typing import Annotated +from typing import Annotated, Any +import pytest from pydantic import BaseModel - -from workflowai.core.utils._tools import tool_schema +from zoneinfo import ZoneInfo + +from workflowai.core.utils._tools import _get_type_schema, tool_schema # pyright: ignore [reportPrivateUsage] + + +class TestGetTypeSchema: + @pytest.mark.parametrize( + ("param_type", "value"), + [ + (int, 1), + (float, 1.0), + (bool, True), + (str, "test"), + (datetime, datetime.now(tz=ZoneInfo("UTC"))), + (ZoneInfo, ZoneInfo("UTC")), + (list[int], [1, 2, 3]), + (dict[str, int], {"a": 1, "b": 2}), + ], + ) + def test_get_type_schema(self, param_type: Any, value: Any): + schema = _get_type_schema(param_type) + if schema.serializer is None or schema.deserializer is None: + assert schema.serializer is None + assert schema.deserializer is None + + # Check that the value is serializable and deserializable with plain json + assert json.loads(json.dumps(value)) == value + return + + serialized = schema.serializer(value) + deserialized = schema.deserializer(serialized) + assert deserialized == value class TestToolSchema: @@ -123,3 +156,21 @@ def sample_func() -> TestModel: ... } assert output_schema.serializer is not None assert output_schema.serializer(TestModel(val=10)) == {"val": 10} + + def test_with_datetime_in_input(self): + def sample_func(time: datetime) -> str: ... + + input_schema, _ = tool_schema(sample_func) + + assert input_schema.deserializer is not None + assert input_schema.deserializer({"time": "2024-01-01T12:00:00+00:00"}) == { + "time": datetime( + 2024, + 1, + 1, + 12, + 0, + 0, + tzinfo=ZoneInfo("UTC"), + ), + } From 6ec44c391820622f439ae21163e50304e9dc5994 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 5 Feb 2025 14:21:04 -0500 Subject: [PATCH 2/2] fix: tests for python 3.9 --- workflowai/core/utils/_tools.py | 36 +++------------------------- workflowai/core/utils/_tools_test.py | 19 ++++++++++++++- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py index 0dda9bf..d0d4184 100644 --- a/workflowai/core/utils/_tools.py +++ b/workflowai/core/utils/_tools.py @@ -1,10 +1,8 @@ import contextlib -import datetime import inspect -from enum import Enum -from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints +from typing import Any, Callable, NamedTuple, Optional, get_type_hints -from pydantic import BaseModel, TypeAdapter +from pydantic import TypeAdapter from workflowai.core.utils._schema_generator import JsonSchemaGenerator @@ -17,14 +15,6 @@ class SchemaDeserializer(NamedTuple): deserializer: Optional[Callable[[Any], Any]] = None -def _serialize_datetime(x: datetime.datetime) -> str: - return x.isoformat() - - -def _deserialize_datetime(x: str) -> datetime.datetime: - return datetime.datetime.fromisoformat(x) - - def _get_type_schema(param_type: type): """Convert a Python type to its corresponding JSON schema type. @@ -47,31 +37,11 @@ def _get_type_schema(param_type: type): if param_type is bool: return SchemaDeserializer({"type": "boolean"}) - if param_type is datetime.datetime: - return SchemaDeserializer( - {"type": "string", "format": "date-time"}, - serializer=_serialize_datetime, - deserializer=_deserialize_datetime, - ) - - if inspect.isclass(param_type): - if issubclass(param_type, BaseModel): - return SchemaDeserializer( - schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator), - serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType] - deserializer=param_type.model_validate, - ) - - if issubclass(param_type, Enum): - if not issubclass(param_type, str): - raise ValueError(f"Non string enums are not supported: {param_type}") - return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]}) - # Attempting to build a type adapter with pydantic with contextlib.suppress(Exception): adapter = TypeAdapter[Any](param_type) return SchemaDeserializer( - schema=adapter.json_schema(), + schema=adapter.json_schema(schema_generator=JsonSchemaGenerator), deserializer=adapter.validate_python, # pyright: ignore [reportUnknownLambdaType] serializer=lambda x: adapter.dump_python(x, mode="json"), # pyright: ignore [reportUnknownLambdaType] ) diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py index 46d0567..2825374 100644 --- a/workflowai/core/utils/_tools_test.py +++ b/workflowai/core/utils/_tools_test.py @@ -11,6 +11,14 @@ class TestGetTypeSchema: + class _BasicEnum(str, Enum): + A = "a" + B = "b" + + class _BasicModel(BaseModel): + a: int + b: str + @pytest.mark.parametrize( ("param_type", "value"), [ @@ -22,6 +30,10 @@ class TestGetTypeSchema: (ZoneInfo, ZoneInfo("UTC")), (list[int], [1, 2, 3]), (dict[str, int], {"a": 1, "b": 2}), + (_BasicEnum, _BasicEnum.A), + (_BasicModel, _BasicModel(a=1, b="test")), + (list[_BasicModel], [_BasicModel(a=1, b="test"), _BasicModel(a=2, b="test2")]), + (tuple[int, str], (1, "test")), ], ) def test_get_type_schema(self, param_type: Any, value: Any): @@ -50,6 +62,7 @@ def sample_func( age: int, height: float, is_active: bool, + date: datetime, mode: TestMode = TestMode.FAST, ) -> bool: """Sample function for testing""" @@ -77,8 +90,12 @@ def sample_func( "type": "string", "enum": ["fast", "slow"], }, + "date": { + "type": "string", + "format": "date-time", + }, }, - "required": ["name", "age", "height", "is_active"], # 'mode' is not required + "required": ["name", "age", "height", "is_active", "date"], # 'mode' is not required } assert output_schema.schema == { "type": "boolean",