Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions workflowai/core/utils/_tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import contextlib
import inspect
from enum import Enum
from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints
from typing import Any, Callable, NamedTuple, Optional, get_type_hints

from pydantic import BaseModel
from pydantic import TypeAdapter

from workflowai.core.utils._schema_generator import JsonSchemaGenerator

Expand All @@ -24,10 +24,6 @@ def _get_type_schema(param_type: type):
Returns:
A dictionary containing the JSON schema type definition
"""
if issubclass(param_type, Enum):
if not issubclass(param_type, str):
raise ValueError(f"Non string enums are not supported: {param_type}")
return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]})

if param_type is str:
return SchemaDeserializer({"type": "string"})
Expand All @@ -41,11 +37,13 @@ def _get_type_schema(param_type: type):
if param_type is bool:
return SchemaDeserializer({"type": "boolean"})

if issubclass(param_type, BaseModel):
# Attempting to build a type adapter with pydantic
with contextlib.suppress(Exception):
adapter = TypeAdapter[Any](param_type)
return SchemaDeserializer(
schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator),
serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType]
deserializer=param_type.model_validate,
schema=adapter.json_schema(schema_generator=JsonSchemaGenerator),
deserializer=adapter.validate_python, # pyright: ignore [reportUnknownLambdaType]
serializer=lambda x: adapter.dump_python(x, mode="json"), # pyright: ignore [reportUnknownLambdaType]
)

raise ValueError(f"Unsupported type: {param_type}")
Expand Down
76 changes: 72 additions & 4 deletions workflowai/core/utils/_tools_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,54 @@
import json
from datetime import datetime
from enum import Enum
from typing import Annotated
from typing import Annotated, Any

import pytest
from pydantic import BaseModel

from workflowai.core.utils._tools import tool_schema
from zoneinfo import ZoneInfo

from workflowai.core.utils._tools import _get_type_schema, tool_schema # pyright: ignore [reportPrivateUsage]


class TestGetTypeSchema:
class _BasicEnum(str, Enum):
A = "a"
B = "b"

class _BasicModel(BaseModel):
a: int
b: str

@pytest.mark.parametrize(
("param_type", "value"),
[
(int, 1),
(float, 1.0),
(bool, True),
(str, "test"),
(datetime, datetime.now(tz=ZoneInfo("UTC"))),
(ZoneInfo, ZoneInfo("UTC")),
(list[int], [1, 2, 3]),
(dict[str, int], {"a": 1, "b": 2}),
(_BasicEnum, _BasicEnum.A),
(_BasicModel, _BasicModel(a=1, b="test")),
(list[_BasicModel], [_BasicModel(a=1, b="test"), _BasicModel(a=2, b="test2")]),
(tuple[int, str], (1, "test")),
],
)
def test_get_type_schema(self, param_type: Any, value: Any):
schema = _get_type_schema(param_type)
if schema.serializer is None or schema.deserializer is None:
assert schema.serializer is None
assert schema.deserializer is None

# Check that the value is serializable and deserializable with plain json
assert json.loads(json.dumps(value)) == value
return

serialized = schema.serializer(value)
deserialized = schema.deserializer(serialized)
assert deserialized == value


class TestToolSchema:
Expand All @@ -17,6 +62,7 @@ def sample_func(
age: int,
height: float,
is_active: bool,
date: datetime,
mode: TestMode = TestMode.FAST,
) -> bool:
"""Sample function for testing"""
Expand Down Expand Up @@ -44,8 +90,12 @@ def sample_func(
"type": "string",
"enum": ["fast", "slow"],
},
"date": {
"type": "string",
"format": "date-time",
},
},
"required": ["name", "age", "height", "is_active"], # 'mode' is not required
"required": ["name", "age", "height", "is_active", "date"], # 'mode' is not required
}
assert output_schema.schema == {
"type": "boolean",
Expand Down Expand Up @@ -123,3 +173,21 @@ def sample_func() -> TestModel: ...
}
assert output_schema.serializer is not None
assert output_schema.serializer(TestModel(val=10)) == {"val": 10}

def test_with_datetime_in_input(self):
def sample_func(time: datetime) -> str: ...

input_schema, _ = tool_schema(sample_func)

assert input_schema.deserializer is not None
assert input_schema.deserializer({"time": "2024-01-01T12:00:00+00:00"}) == {
"time": datetime(
2024,
1,
1,
12,
0,
0,
tzinfo=ZoneInfo("UTC"),
),
}