Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions livekit-agents/livekit/agents/llm/_strict.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def _ensure_strict_json_schema(
# Strip empty schema objects ({}) — they are JSON Schema's identity element
# for anyOf (match anything) and cause OpenAI strict mode to reject the schema.
# Common when Union[..., Any] or ForwardRef patterns produce bare {} entries.
# Also convert oneOf → anyOf because OpenAI strict mode does not permit oneOf.
# Pydantic emits oneOf for discriminated unions, but anyOf is semantically equivalent
# for the LLM's purposes and is accepted by the API.
for union_key in ("anyOf", "oneOf"):
variants = json_schema.get(union_key)
if is_list(variants):
Expand All @@ -89,8 +92,9 @@ def _ensure_strict_json_schema(
)
json_schema.pop(union_key, None)
elif len(variants) >= 2:
json_schema[union_key] = [
_ensure_strict_json_schema(variant, path=(*path, union_key, str(i)), root=root)
json_schema.pop(union_key, None)
json_schema["anyOf"] = [
_ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
for i, variant in enumerate(variants)
]
else:
Expand Down
72 changes: 71 additions & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
import json
from typing import Any, Literal
from typing import Annotated, Any, Literal

import pytest
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -450,6 +450,76 @@ def test_non_nullable_enum_excludes_null(self):
assert "null" not in status.get("type", []), f"type should not contain 'null': {status}"


class _CarModel(BaseModel):
vehicle: Literal["Car"]
brand: str
color: str


class _BikeModel(BaseModel):
vehicle: Literal["Bike"]
brand: str
color: str


class _DiscriminatedUnionModel(BaseModel):
item: Annotated[_CarModel | _BikeModel, Field(discriminator="vehicle")]


class _NestedDiscriminatedUnionModel(BaseModel):
items: list[Annotated[_CarModel | _BikeModel, Field(discriminator="vehicle")]]


def _has_one_of(schema: object) -> bool:
"""Recursively check if any dict in the schema tree contains 'oneOf'."""
if isinstance(schema, dict):
if "oneOf" in schema:
return True
return any(_has_one_of(v) for v in schema.values())
if isinstance(schema, list):
return any(_has_one_of(v) for v in schema)
return False


class TestDiscriminatedUnionSchema:
"""Test that discriminated unions use anyOf instead of oneOf in strict schema."""

def test_discriminated_union_uses_anyof_not_oneof(self):
"""Pydantic emits oneOf for discriminated unions, but OpenAI strict mode
rejects oneOf. Ensure to_strict_json_schema converts oneOf to anyOf."""
schema = to_strict_json_schema(_DiscriminatedUnionModel)
assert not _has_one_of(schema), (
f"schema should not contain oneOf: {json.dumps(schema, indent=2)}"
)
item = schema["properties"]["item"]
assert "anyOf" in item, f"item should have anyOf: {json.dumps(item, indent=2)}"
assert len(item["anyOf"]) == 2, f"item should have 2 variants: {json.dumps(item, indent=2)}"

def test_nested_discriminated_union_uses_anyof_not_oneof(self):
"""Nested discriminated unions should also convert oneOf to anyOf."""
schema = to_strict_json_schema(_NestedDiscriminatedUnionModel)
assert not _has_one_of(schema), (
f"nested schema should not contain oneOf: {json.dumps(schema, indent=2)}"
)

def test_discriminated_union_build_strict_openai_schema(self):
"""End-to-end: build_strict_openai_schema should not produce oneOf for
a function tool with a discriminated union parameter."""

@function_tool
async def lookup_vehicle(
item: Annotated[_CarModel | _BikeModel, Field(discriminator="vehicle")],
) -> str:
"""Look up a vehicle."""
return str(item)

schema = build_strict_openai_schema(lookup_vehicle)
schema_str = json.dumps(schema)
assert '"oneOf"' not in schema_str, (
f"strict openai schema should not contain oneOf: {json.dumps(schema, indent=2)}"
)


class _OpenEnumModel(BaseModel):
"""Simulates a codegen'd "open enum" pattern (e.g. Fern Python SDK).
Union[Literal["a", "b"], Any] produces an anyOf with a bare {} entry."""
Expand Down
Loading