diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py
index f16632e..d0d4184 100644
--- a/workflowai/core/utils/_tools.py
+++ b/workflowai/core/utils/_tools.py
@@ -1,8 +1,7 @@
 import inspect
-from enum import Enum
-from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints
+from typing import Any, Callable, NamedTuple, Optional, get_type_hints
 
-from pydantic import BaseModel
+from pydantic import TypeAdapter
 
 from workflowai.core.utils._schema_generator import JsonSchemaGenerator
 
@@ -24,10 +23,6 @@ def _get_type_schema(param_type: type):
     Returns:
         A dictionary containing the JSON schema type definition
     """
-    if issubclass(param_type, Enum):
-        if not issubclass(param_type, str):
-            raise ValueError(f"Non string enums are not supported: {param_type}")
-        return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]})
 
     if param_type is str:
         return SchemaDeserializer({"type": "string"})
@@ -41,11 +36,14 @@ def _get_type_schema(param_type: type):
     if param_type is bool:
         return SchemaDeserializer({"type": "boolean"})
 
-    if issubclass(param_type, BaseModel):
+    # Every remaining type (enums, models, datetimes, generics, ...) is delegated to pydantic
+    try:
+        adapter = TypeAdapter[Any](param_type)
         return SchemaDeserializer(
-            schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator),
-            serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"),  # pyright: ignore [reportUnknownLambdaType]
-            deserializer=param_type.model_validate,
+            schema=adapter.json_schema(schema_generator=JsonSchemaGenerator),
+            deserializer=adapter.validate_python,  # pyright: ignore [reportUnknownLambdaType]
+            serializer=lambda x: adapter.dump_python(x, mode="json"),  # pyright: ignore [reportUnknownLambdaType]
         )
-
-    raise ValueError(f"Unsupported type: {param_type}")
+    except Exception as e:
+        # Keep pydantic's error as the cause so callers can tell why the type was rejected
+        raise ValueError(f"Unsupported type: {param_type}") from e
diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py
index d6c04da..2825374 100644
--- a/workflowai/core/utils/_tools_test.py
+++ b/workflowai/core/utils/_tools_test.py
@@ -1,9 +1,54 @@
+import json
+from datetime import datetime
 from enum import Enum
-from typing import Annotated
+from typing import Annotated, Any
 
+import pytest
 from pydantic import BaseModel
-
-from workflowai.core.utils._tools import tool_schema
+from zoneinfo import ZoneInfo
+
+from workflowai.core.utils._tools import _get_type_schema, tool_schema  # pyright: ignore [reportPrivateUsage]
+
+
+class TestGetTypeSchema:
+    class _BasicEnum(str, Enum):
+        A = "a"
+        B = "b"
+
+    class _BasicModel(BaseModel):
+        a: int
+        b: str
+
+    @pytest.mark.parametrize(
+        ("param_type", "value"),
+        [
+            (int, 1),
+            (float, 1.0),
+            (bool, True),
+            (str, "test"),
+            (datetime, datetime.now(tz=ZoneInfo("UTC"))),
+            (ZoneInfo, ZoneInfo("UTC")),
+            (list[int], [1, 2, 3]),
+            (dict[str, int], {"a": 1, "b": 2}),
+            (_BasicEnum, _BasicEnum.A),
+            (_BasicModel, _BasicModel(a=1, b="test")),
+            (list[_BasicModel], [_BasicModel(a=1, b="test"), _BasicModel(a=2, b="test2")]),
+            (tuple[int, str], (1, "test")),
+        ],
+    )
+    def test_get_type_schema(self, param_type: Any, value: Any):
+        schema = _get_type_schema(param_type)
+        if schema.serializer is None or schema.deserializer is None:
+            assert schema.serializer is None
+            assert schema.deserializer is None
+
+            # Check that the value is serializable and deserializable with plain json
+            assert json.loads(json.dumps(value)) == value
+            return
+
+        serialized = schema.serializer(value)
+        deserialized = schema.deserializer(serialized)
+        assert deserialized == value
 
 
 class TestToolSchema:
@@ -17,6 +62,7 @@ def sample_func(
             age: int,
             height: float,
             is_active: bool,
+            date: datetime,
             mode: TestMode = TestMode.FAST,
         ) -> bool:
             """Sample function for testing"""
@@ -44,8 +90,12 @@ def sample_func(
                     "type": "string",
                     "enum": ["fast", "slow"],
                 },
+                "date": {
+                    "type": "string",
+                    "format": "date-time",
+                },
             },
-            "required": ["name", "age", "height", "is_active"],  # 'mode' is not required
+            "required": ["name", "age", "height", "is_active", "date"],  # 'mode' is not required
         }
         assert output_schema.schema == {
             "type": "boolean",
@@ -123,3 +173,21 @@ def sample_func() -> TestModel: ...
         }
         assert output_schema.serializer is not None
         assert output_schema.serializer(TestModel(val=10)) == {"val": 10}
+
+    def test_with_datetime_in_input(self):
+        def sample_func(time: datetime) -> str: ...
+
+        input_schema, _ = tool_schema(sample_func)
+
+        assert input_schema.deserializer is not None
+        assert input_schema.deserializer({"time": "2024-01-01T12:00:00+00:00"}) == {
+            "time": datetime(
+                2024,
+                1,
+                1,
+                12,
+                0,
+                0,
+                tzinfo=ZoneInfo("UTC"),
+            ),
+        }